Source code for oumi.analyze.config

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration for the typed analyzer framework."""

from collections import Counter
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml

from oumi.core.configs.params.test_params import TestParams


@dataclass
class AnalyzerConfig:
    """Settings for one configured analyzer instance.

    An analyzer is identified by its registry type (``id``) together with a
    name that is unique within the configuration (``instance_id``). Keeping
    the two separate lets the same analyzer type appear several times with
    different parameters (for example, two length analyzers that use
    different tokenizers).

    Attributes:
        id: Registry identifier of the analyzer type (e.g. "length",
            "difficulty_judge").
        instance_id: Name of this particular instance; always required. It is
            used as the results key and inside test metric paths.
        params: Parameters specific to the analyzer type. Defaults to a new
            empty dict.
    """

    # Registry type of the analyzer.
    id: str
    # Unique per-configuration name of this instance.
    instance_id: str
    # default_factory gives every instance its own dict (no shared default).
    params: dict[str, Any] = field(default_factory=dict)
@dataclass class OutputFieldSchema: """Schema definition for a single output field. Attributes: name: Field name (key in the returned dict). type: Field type ("int", "float", "bool", "str", "list"). description: Description of the field. """ name: str type: str = "float" description: str = "" @dataclass class CustomMetricConfig: """Configuration for a custom user-defined metric. Custom metrics allow users to define Python functions that compute additional metrics. These are executed during the analysis phase and their results are cached. .. warning:: **Security Warning**: The ``function`` field contains arbitrary Python code that is executed dynamically. Only load configurations from trusted sources. Never load YAML configs from untrusted users or external sources without review, as they could execute malicious code. Example YAML:: custom_metrics: - id: word_to_char_ratio scope: conversation description: "Ratio of words to characters" output_schema: - name: ratio type: float description: "Words divided by characters (0.15-0.20 is typical)" function: | def compute(conversation): chars = sum(len(m.content) for m in conversation.messages) words = sum(len(m.content.split()) for m in conversation.messages) return {"ratio": words / chars if chars > 0 else 0.0} Attributes: id: Unique identifier for the metric. scope: Scope of the metric ("message", "conversation", or "dataset"). function: Python code defining a compute() function. description: Description of what the metric computes. output_schema: List of output field definitions. """ id: str scope: str = "conversation" # "message", "conversation", or "dataset" function: str = "" description: str | None = None output_schema: list[OutputFieldSchema] = field(default_factory=list) depends_on: list[str] = field(default_factory=list) def __post_init__(self): """Validate the configuration.""" if self.scope not in ("message", "conversation", "dataset"): raise ValueError( f"Invalid scope '{self.scope}'. 
" "Must be 'message', 'conversation', or 'dataset'." ) def get_metric_paths(self) -> list[str]: """Get full metric paths for all output fields. Returns: List of metric paths like ["metric_id.field_name", ...]. """ if self.output_schema: return [f"{self.id}.{f.name}" for f in self.output_schema] return [f"{self.id}.<field>"] def get_field_info(self) -> dict[str, dict[str, str]]: """Get field information for display. Returns: Dict mapping field names to {"type": ..., "description": ...}. """ return { f.name: {"type": f.type, "description": f.description} for f in self.output_schema }
@dataclass
class TypedAnalyzeConfig:
    """Configuration for the typed analyzer pipeline.

    This is the main configuration class for the new typed analyzer
    architecture. It supports both programmatic construction and loading from
    YAML files.

    Example YAML::

        dataset_path: /path/to/data.jsonl
        sample_count: 1000
        output_path: ./analysis_output

        analyzers:
          - id: length
            params:
              count_tokens: true
          - id: quality

        custom_metrics:
          - id: turn_pattern
            scope: conversation
            function: |
              def compute(conversation):
                  ...

        tests:
          - id: max_words
            type: threshold
            metric: length.total_words
            operator: ">"
            value: 10000
            max_percentage: 5.0

    Attributes:
        eval_name: Optional name for this run (stored and serialized as given).
        parent_eval_id: Optional identifier of a parent evaluation (stored and
            serialized as given).
        dataset_name: Name of the dataset (HuggingFace identifier).
        dataset_path: Path to local dataset file.
        split: Dataset split to use. Defaults to "train".
        subset: Optional dataset subset name.
        sample_count: Number of samples to analyze.
        output_path: Directory for output artifacts. Defaults to ".".
        analyzers: List of analyzer configurations.
        custom_metrics: List of custom metric configurations.
        tests: List of test configurations.
        tokenizer_name: Tokenizer for token counting.
        tokenizer_kwargs: Extra keyword arguments associated with
            ``tokenizer_name``.
        generate_report: Whether to generate HTML report.
        report_title: Custom title for the report.
    """

    eval_name: str | None = None
    parent_eval_id: str | None = None
    dataset_name: str | None = None
    dataset_path: str | None = None
    split: str = "train"
    subset: str | None = None
    sample_count: int | None = None
    output_path: str = "."
    analyzers: list[AnalyzerConfig] = field(default_factory=list)
    custom_metrics: list[CustomMetricConfig] = field(default_factory=list)
    tests: list[TestParams] = field(default_factory=list)
    tokenizer_name: str | None = None
    tokenizer_kwargs: dict[str, Any] = field(default_factory=dict)
    generate_report: bool = False
    report_title: str | None = None

    @classmethod
    def from_yaml(
        cls, path: str | Path, allow_custom_code: bool = False
    ) -> "TypedAnalyzeConfig":
        """Load configuration from a YAML file.

        .. warning::
            **Security Warning**: If the YAML file contains ``custom_metrics``
            with ``function`` fields, arbitrary Python code will be loaded.
            Only load configurations from trusted sources. Set
            ``allow_custom_code=True`` to explicitly acknowledge this risk.

        Args:
            path: Path to YAML configuration file.
            allow_custom_code: If True, allow loading custom_metrics with
                function code. If False (default) and the config contains
                custom metrics with code, raises ValueError.

        Returns:
            TypedAnalyzeConfig instance. An empty YAML file yields a
            configuration with all defaults.

        Raises:
            ValueError: If the top level of the YAML document is not a
                mapping, or if config contains custom code but
                allow_custom_code=False.
        """
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # safe_load returns None for an empty document; treat it as {}.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Expected a mapping at the top level of '{path}', "
                f"got {type(data).__name__}."
            )
        return cls.from_dict(data, allow_custom_code=allow_custom_code)

    @classmethod
    def _parse_analyzers(cls, data: dict[str, Any]) -> list[AnalyzerConfig]:
        """Parse analyzer configurations, raising on duplicate instance_ids.

        Each entry is either a mapping of AnalyzerConfig fields (where
        ``instance_id`` defaults to ``id``) or a bare string used as both the
        ``id`` and the ``instance_id``.

        Raises:
            ValueError: If an entry is neither a mapping nor a string, if a
                mapping entry lacks ``id``, or if two analyzers share an
                instance_id.
        """
        analyzers: list[AnalyzerConfig] = []
        # ``analyzers:`` with no value loads as None; treat it as absent.
        for analyzer_data in data.get("analyzers") or []:
            if isinstance(analyzer_data, dict):
                if "id" not in analyzer_data:
                    raise ValueError(
                        f"Analyzer entry {analyzer_data!r} is missing the "
                        "required key 'id'."
                    )
                # instance_id defaults to id if not provided in YAML
                if "instance_id" not in analyzer_data:
                    analyzer_data = {
                        **analyzer_data,
                        "instance_id": analyzer_data["id"],
                    }
                analyzers.append(AnalyzerConfig(**analyzer_data))
            elif isinstance(analyzer_data, str):
                analyzers.append(
                    AnalyzerConfig(id=analyzer_data, instance_id=analyzer_data)
                )
            else:
                # Previously such entries were silently dropped, which hid
                # typos in the config; fail loudly instead.
                raise ValueError(
                    f"Invalid analyzer entry {analyzer_data!r}: expected a "
                    "mapping with an 'id' key or a string analyzer id."
                )

        # Validate unique instance_ids; report duplicates in first-seen order
        # so the error message is deterministic.
        counts = Counter(a.instance_id for a in analyzers)
        duplicates = [iid for iid, n in counts.items() if n > 1]
        if duplicates:
            raise ValueError(
                f"Duplicate analyzer instance_id values: {duplicates}. "
                "Each analyzer must have a unique instance_id to avoid collisions."
            )
        return analyzers

    @classmethod
    def _parse_custom_metrics(
        cls, data: dict[str, Any], allow_custom_code: bool
    ) -> list[CustomMetricConfig]:
        """Parse custom metrics, raising if code is present and not allowed.

        Non-dict entries inside a metric's ``output_schema`` are skipped.

        Raises:
            ValueError: If any metric has a non-empty ``function`` and
                ``allow_custom_code`` is False.
        """
        custom_metrics: list[CustomMetricConfig] = []
        for metric_data in data.get("custom_metrics") or []:
            output_schema = [
                OutputFieldSchema(**f)
                for f in metric_data.get("output_schema") or []
                if isinstance(f, dict)
            ]
            remaining = {k: v for k, v in metric_data.items() if k != "output_schema"}
            custom_metrics.append(
                CustomMetricConfig(**remaining, output_schema=output_schema)
            )

        # Security check: reject custom code unless explicitly allowed.
        # ``function:`` with no value loads as None (no code); any other
        # value is stringified so non-string payloads still count as code.
        if not allow_custom_code:
            metrics_with_code = [
                m.id
                for m in custom_metrics
                if m.function is not None and str(m.function).strip()
            ]
            if metrics_with_code:
                raise ValueError(
                    f"Configuration contains custom metrics with executable code: "
                    f"{metrics_with_code}. This is a security risk if loading from "
                    f"untrusted sources. Set allow_custom_code=True to explicitly "
                    f"allow code execution, or remove the 'function' fields."
                )
        return custom_metrics

    @classmethod
    def _parse_tests(cls, data: dict[str, Any]) -> list[TestParams]:
        """Parse and validate test configurations.

        Each entry is passed as keyword arguments to TestParams, then
        ``finalize_and_validate()`` is called on it.
        """
        tests: list[TestParams] = []
        # ``tests:`` with no value loads as None; treat it as absent.
        for test_data in data.get("tests") or []:
            test_params = TestParams(**test_data)
            test_params.finalize_and_validate()
            tests.append(test_params)
        return tests

    @classmethod
    def from_dict(
        cls, data: dict[str, Any], allow_custom_code: bool = False
    ) -> "TypedAnalyzeConfig":
        """Create configuration from a dictionary.

        Args:
            data: Configuration dictionary.
            allow_custom_code: If True, allow custom_metrics with function code.
                If False (default) and the config contains custom metrics with
                code, raises ValueError.

        Returns:
            TypedAnalyzeConfig instance.

        Raises:
            ValueError: If config contains custom code but
                allow_custom_code=False, if an analyzer entry is malformed, or
                if duplicate analyzer instance_ids are found.
        """
        analyzers = cls._parse_analyzers(data)
        custom_metrics = cls._parse_custom_metrics(data, allow_custom_code)
        tests = cls._parse_tests(data)

        return cls(
            eval_name=data.get("eval_name"),
            parent_eval_id=data.get("parent_eval_id"),
            dataset_name=data.get("dataset_name"),
            dataset_path=data.get("dataset_path"),
            split=data.get("split", "train"),
            subset=data.get("subset"),
            sample_count=data.get("sample_count"),
            output_path=data.get("output_path", "."),
            analyzers=analyzers,
            custom_metrics=custom_metrics,
            tests=tests,
            tokenizer_name=data.get("tokenizer_name"),
            # ``tokenizer_kwargs:`` with no value loads as None; use {}.
            tokenizer_kwargs=data.get("tokenizer_kwargs") or {},
            generate_report=data.get("generate_report", False),
            report_title=data.get("report_title"),
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert configuration to a dictionary.

        Top-level keys mirror those read by :meth:`from_dict`. Nested
        containers (``params``, ``depends_on``, ``tokenizer_kwargs``) are not
        copied: they are the same objects held by this configuration.
        """
        return {
            "eval_name": self.eval_name,
            "parent_eval_id": self.parent_eval_id,
            "dataset_name": self.dataset_name,
            "dataset_path": self.dataset_path,
            "split": self.split,
            "subset": self.subset,
            "sample_count": self.sample_count,
            "output_path": self.output_path,
            "analyzers": [
                {"id": a.id, "instance_id": a.instance_id, "params": a.params}
                for a in self.analyzers
            ],
            "custom_metrics": [
                {
                    "id": m.id,
                    "scope": m.scope,
                    "function": m.function,
                    "description": m.description,
                    "output_schema": [
                        {
                            "name": f.name,
                            "type": f.type,
                            "description": f.description,
                        }
                        for f in m.output_schema
                    ],
                    "depends_on": m.depends_on,
                }
                for m in self.custom_metrics
            ],
            "tests": [
                {
                    "id": t.id,
                    "type": t.type,
                    "metric": t.metric,
                    "severity": t.severity,
                    "title": t.title,
                    "description": t.description,
                    "operator": t.operator,
                    "value": t.value,
                    "condition": t.condition,
                    "max_percentage": t.max_percentage,
                    "min_percentage": t.min_percentage,
                }
                for t in self.tests
            ],
            "tokenizer_name": self.tokenizer_name,
            "tokenizer_kwargs": self.tokenizer_kwargs,
            "generate_report": self.generate_report,
            "report_title": self.report_title,
        }