# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test engine for validating typed analysis results.
This module provides a test engine that operates on typed Pydantic results
instead of DataFrames. Tests are pure validation - no computation allowed.
"""

import logging
import operator
from collections.abc import Callable
from typing import Any

from pydantic import BaseModel

from oumi.analyze.testing.results import TestResult, TestSeverity, TestSummary
from oumi.core.configs.params.test_params import TestParams, TestType

logger = logging.getLogger(__name__)

# Maximum number of sample indices to include in test results
MAX_SAMPLE_INDICES = 50

# Maximum number of failure reasons to include in test details
MAX_FAILURE_REASONS = 50

OPERATORS: dict[str, Callable[[Any, Any], bool]] = {
"<": operator.lt,
">": operator.gt,
"<=": operator.le,
">=": operator.ge,
"==": operator.eq,
"!=": operator.ne,
}
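
# For example, OPERATORS[">"](12000, 10000) is True, so a sample whose metric
# value exceeds the configured threshold counts as "matching" in a threshold
# test.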


class TestEngine:
    """Engine for running tests on typed analysis results.

    Tests operate on typed Pydantic results, not DataFrames. This ensures
    tests are pure validation with no computation - all metrics must be
    pre-computed by analyzers.

    Example:
        >>> from oumi.analyze.testing import (
        ...     TestEngine, TestParams, TestSeverity, TestType
        ... )
        >>>
        >>> tests = [
        ...     TestParams(
        ...         id="max_tokens",
        ...         type=TestType.THRESHOLD,
        ...         metric="length.total_tokens",
        ...         operator=">",
        ...         value=10000,
        ...         max_percentage=5.0,
        ...         severity=TestSeverity.MEDIUM,
        ...     ),
        ... ]
        >>> engine = TestEngine(tests)
        >>> summary = engine.run(results)  # results: analyzer name -> typed results
        >>> print(f"Pass rate: {summary.pass_rate}%")

    Args:
        tests: List of test configurations.
    """

    def __init__(self, tests: list[TestParams]):
"""Initialize the test engine with test configurations."""
self.tests = tests

    def _create_error_result(self, test: TestParams, error: str) -> TestResult:
"""Create a TestResult for an error condition."""
return TestResult(
test_id=test.id,
passed=False,
severity=TestSeverity(test.severity),
title=test.title or test.id,
description=test.description or "",
metric=test.metric or "",
error=error,
)

    def _build_test_result(
self,
test: TestParams,
passed: bool,
total_count: int,
affected_indices: list[int],
affected_pct: float,
details: dict[str, Any],
actual_value: float | None = None,
) -> TestResult:
"""Build a TestResult from common fields."""
return TestResult(
test_id=test.id,
passed=passed,
severity=TestSeverity(test.severity),
title=test.title or test.id,
description=test.description or "",
metric=test.metric or "",
affected_count=len(affected_indices),
total_count=total_count,
affected_percentage=round(affected_pct, 2),
            threshold=(
                test.max_percentage
                if test.max_percentage is not None
                else test.min_percentage
            ),
actual_value=actual_value,
sample_indices=affected_indices[:MAX_SAMPLE_INDICES],
all_affected_indices=affected_indices,
details=details,
)

    def _get_actual_value(self, values: list[Any]) -> float | None:
"""Extract actual value for single-value numeric metrics."""
if len(values) == 1 and isinstance(values[0], int | float):
return float(values[0])
return None
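
    # For instance, a dataset-level metric that produced a single number,
    # values == [3.2], is reported back as actual_value 3.2; per-sample
    # metrics (multiple values) return None here.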

    def run(
self,
results: dict[str, list[BaseModel] | BaseModel],
) -> TestSummary:
"""Run all tests on the analysis results.
Args:
results: Dictionary mapping analyzer names to results.
Returns:
TestSummary containing all test results.
"""
test_results: list[TestResult] = []
logger.info(f"Running {len(self.tests)} tests...")
for test in self.tests:
try:
result = self._run_single_test(test, results)
test_results.append(result)
status = "PASSED" if result.passed else "FAILED"
logger.debug(
f" Test '{test.id}': {status} "
f"({result.affected_count}/{result.total_count} affected)"
)
except Exception as e:
error_result = self._create_error_result(
test, f"Test execution failed: {e}"
)
test_results.append(error_result)
logger.warning(f" Test '{test.id}': ERROR - {e}")

        summary = TestSummary.from_results(test_results)
logger.info(
f"Test results: {summary.passed_tests}/{summary.total_tests} passed "
f"({summary.pass_rate}%)"
)
if summary.high_severity_failures > 0:
logger.warning(f" {summary.high_severity_failures} high severity failures")
return summary

    def _run_single_test(
self,
test: TestParams,
results: dict[str, list[BaseModel] | BaseModel],
) -> TestResult:
"""Run a single test and return its result."""
if not test.metric:
return self._create_error_result(test, "Test requires 'metric' field")
values = self._extract_metric_values(test.metric, results)
if not values:
return self._create_error_result(
test, f"Metric '{test.metric}' not found in results"
)
if test.type == TestType.THRESHOLD:
return self._run_threshold_test(test, values)
else:
return self._create_error_result(test, f"Unknown test type: {test.type}")

    def _extract_metric_values(
self,
metric: str,
results: dict[str, list[BaseModel] | BaseModel],
) -> list[Any]:
"""Extract values for a metric path like "instance_id.field_name"."""
parts = metric.split(".")
if len(parts) < 2:
return []
analyzer_name = parts[0]
field_path = parts[1:]
if analyzer_name not in results:
return []
analyzer_results = results[analyzer_name]
if isinstance(analyzer_results, BaseModel):
value = self._get_nested_value(analyzer_results, field_path)
return [value] if value is not None else []
values = []
for result in analyzer_results:
value = self._get_nested_value(result, field_path)
if value is not None:
values.append(value)
return values
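
    # Sketch of the traversal (LengthResult is a hypothetical analyzer output
    # model, not part of this module): with
    #     results = {"length": [LengthResult(total_tokens=120), ...]}
    # the metric path "length.total_tokens" selects the "length" entry and
    # reads `total_tokens` from each item, yielding [120, ...].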

    def _get_nested_value(self, obj: Any, field_path: list[str]) -> Any:
"""Get a nested field value from a Pydantic model or dict."""
current: Any = obj
for i, field in enumerate(field_path):
if isinstance(current, BaseModel):
if field in type(current).model_fields:
current = getattr(current, field)
else:
# Check for CustomMetricResult with values dict
values = getattr(current, "values", None)
if isinstance(values, dict):
return self._traverse_dict(values, field_path[i:])
return None
elif isinstance(current, dict):
if field in current:
current = current[field]
else:
return None
else:
raise TypeError(
f"Cannot traverse type {type(current).__name__}. "
f"Expected BaseModel or dict, got {current!r}"
)
return current
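
    # Example: for field_path ["stats", "mean"], a model with a `stats`
    # sub-model returns stats.mean; if `stats` is not a declared field but the
    # model exposes a `values` dict (e.g. a CustomMetricResult), the remaining
    # path is resolved against that dict instead.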

    def _traverse_dict(self, d: dict, path: list[str]) -> Any | None:
"""Traverse a dict using a field path."""
current: Any = d
for field in path:
if isinstance(current, dict) and field in current:
current = current[field]
else:
return None
return current

    def _run_threshold_test(
self,
test: TestParams,
values: list[Any],
) -> TestResult:
"""Run a threshold test against metric values."""
if test.operator is None or test.value is None:
return self._create_error_result(
test, "Threshold test requires 'operator' and 'value'"
)
op_func = OPERATORS.get(test.operator)
if op_func is None:
return self._create_error_result(test, f"Unknown operator: {test.operator}")
matching_indices: list[int] = []
non_matching_indices: list[int] = []
matching_reasons: dict[int, str] = {}
non_matching_reasons: dict[int, str] = {}
for i, value in enumerate(values):
try:
if op_func(value, test.value):
matching_indices.append(i)
matching_reasons[i] = (
f"{value} satisfies {test.operator} {test.value}"
)
else:
non_matching_indices.append(i)
non_matching_reasons[i] = (
f"{value} does not satisfy {test.operator} {test.value}"
)
except (TypeError, ValueError):
non_matching_indices.append(i)
non_matching_reasons[i] = f"Cannot evaluate: {value}"

        total_count = len(values)
matching_count = len(matching_indices)
if total_count > 0:
matching_pct = 100.0 * matching_count / total_count
non_matching_pct = 100.0 * len(non_matching_indices) / total_count
else:
matching_pct = 0.0
non_matching_pct = 0.0

        passed = True
affected_indices: list[int] = []
affected_pct = 0.0
failure_reasons: dict[int, str] = {}
if test.max_percentage is not None and matching_pct > test.max_percentage:
passed = False
affected_indices = matching_indices
affected_pct = matching_pct
failure_reasons = matching_reasons
if test.min_percentage is not None and matching_pct < test.min_percentage:
passed = False
# If max also failed, keep matching; otherwise use non-matching as
# the affected set because they prevent meeting the minimum.
if not affected_indices:
affected_indices = non_matching_indices
affected_pct = non_matching_pct
failure_reasons = non_matching_reasons
# No percentage thresholds: test passes when no samples match.
if test.max_percentage is None and test.min_percentage is None:
passed = len(matching_indices) == 0
affected_indices = matching_indices
affected_pct = matching_pct
failure_reasons = matching_reasons
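
        # Worked example: with 100 values, operator ">" and value 10000, if 7
        # values exceed 10000 then matching_pct == 7.0. A max_percentage of
        # 5.0 fails the test and the 7 matching indices become the affected
        # set; with no max/min percentage configured, any match at all fails.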
return self._build_test_result(
test=test,
passed=passed,
total_count=total_count,
affected_indices=affected_indices,
affected_pct=affected_pct,
actual_value=self._get_actual_value(values),
details={
"operator": test.operator,
"value": test.value,
"max_percentage": test.max_percentage,
"min_percentage": test.min_percentage,
"matching_count": matching_count,
"matching_percentage": round(matching_pct, 2),
"failure_reasons": {
k: v for k, v in list(failure_reasons.items())[:MAX_FAILURE_REASONS]
},
},
)
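

# Minimal usage sketch (assumes the doctest example above reflects the real
# TestParams fields; LengthResult is a hypothetical analyzer output model):
#
#     from pydantic import BaseModel
#
#     class LengthResult(BaseModel):
#         total_tokens: int
#
#     results = {"length": [LengthResult(total_tokens=n) for n in (50, 20000)]}
#     tests = [
#         TestParams(
#             id="max_tokens",
#             type=TestType.THRESHOLD,
#             metric="length.total_tokens",
#             operator=">",
#             value=10000,
#             max_percentage=5.0,
#             severity=TestSeverity.MEDIUM,
#         )
#     ]
#     summary = TestEngine(tests).run(results)
#     # 1 of 2 samples matches "> 10000" (50% > max_percentage 5.0), so the
#     # test fails and summary reflects 0/1 tests passed.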