# Source code for oumi.core.configs.analyze_config

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional

from omegaconf import MISSING

from oumi.core.configs.base_config import BaseConfig
from oumi.core.configs.params.base_params import BaseParams


class DatasetSource(Enum):
    """Source of the dataset for analysis.

    .. deprecated::
        This enum is deprecated and will be removed in a future release.
        The dataset source is now automatically determined based on whether
        a dataset is passed directly to DatasetAnalyzer.__init__().
    """

    CONFIG = "config"
    """Load dataset from config parameters (dataset_name, dataset_path, etc.)"""

    DIRECT = "direct"
    """Pass dataset directly to DatasetAnalyzer.__init__()"""


@dataclass
class SampleAnalyzerParams(BaseParams):
    """Params for a single sample analyzer plugin."""

    id: str = MISSING
    """Unique identifier for the analyzer."""

    params: dict[str, Any] = field(default_factory=dict)
    """Analyzer-specific parameters passed to the analyzer constructor."""


@dataclass
class AnalyzeConfig(BaseConfig):
    """Configuration for dataset analysis and aggregation.

    Several fields are deprecated (``dataset_source``, ``dataset_format``,
    ``tokenizer_config``, ``is_multimodal``). Setting any of them emits a
    ``DeprecationWarning`` from ``__post_init__``; only ``tokenizer_config``
    still influences the resulting configuration (its values are migrated to
    ``tokenizer_name``, ``tokenizer_kwargs`` and ``trust_remote_code``).
    """

    dataset_source: Optional[DatasetSource] = None
    """Source of the dataset for analysis.

    .. deprecated::
        This field is deprecated and will be removed in a future release.
        The dataset source is now automatically determined based on whether
        a dataset is passed directly to DatasetAnalyzer.__init__().
    """

    dataset_format: Optional[str] = None
    """Format of the custom dataset.

    .. deprecated::
        This field is deprecated and will be removed in a future release.
        The dataset format is now automatically detected from the file contents.
    """

    dataset_name: Optional[str] = None
    """Dataset name."""

    dataset_path: Optional[str] = None
    """Path to a custom dataset file (JSON or JSONL format).

    If provided, this takes precedence over dataset_name for loading custom datasets.
    """

    split: str = "train"
    """The split of the dataset to load.

    This is typically one of "train", "test", or "validation". Defaults to "train".
    """

    subset: Optional[str] = None
    """The subset of the dataset to load. If None, uses the base dataset."""

    sample_count: Optional[int] = None
    """The number of examples to sample from the dataset.

    If None, uses the full dataset. If specified, must be greater than 0.
    """

    output_path: str = "."
    """Directory path where output files will be saved.

    Defaults to current directory ('.').
    """

    analyzers: list[SampleAnalyzerParams] = field(default_factory=list)
    """List of analyzer configurations (plugin-style)."""

    # Tokenizer configuration
    tokenizer_name: Optional[str] = None
    """The name or path of the tokenizer to use for token counting metrics.

    If None, no tokenizer will be used.
    This is typically a model identifier from HuggingFace Hub
    (e.g., "openai-community/gpt2").
    """

    tokenizer_kwargs: dict[str, Any] = field(default_factory=dict)
    """Additional keyword arguments to pass to the tokenizer constructor."""

    tokenizer_config: Optional[dict[str, Any]] = None
    """Tokenizer configuration for building a tokenizer.

    .. deprecated::
        This field is deprecated and will be removed in a future release.
        Use 'tokenizer_name' and 'tokenizer_kwargs' instead.
    """

    # Processor parameters for vision-language datasets
    processor_name: Optional[str] = None
    """Processor name for vision-language datasets.

    If provided, the dataset will be treated as multimodal (vision-language).
    """

    processor_kwargs: dict[str, Any] = field(default_factory=dict)
    """Processor-specific parameters."""

    trust_remote_code: bool = False
    """Whether to trust remote code for tokenizer/processor loading."""

    is_multimodal: Optional[bool] = None
    """Whether to treat the dataset as multimodal (vision-language).

    .. deprecated::
        This field is deprecated and will be removed in a future release.
        Multimodality is now automatically detected based on whether
        'processor_name' is provided.
    """

    def __post_init__(self) -> None:
        """Validates the configuration parameters.

        Emits a ``DeprecationWarning`` for every deprecated field that is set,
        migrates values from the deprecated ``tokenizer_config`` into the
        current tokenizer fields, and validates ``sample_count`` and the
        analyzer list.

        Raises:
            ValueError: If ``sample_count`` is set and is not greater than 0,
                if an analyzer has an empty or unset (``MISSING``) ``id``, or
                if two analyzers share the same ``id``.
        """
        # Emit deprecation warnings for deprecated fields
        if self.dataset_source is not None:
            warnings.warn(
                "The 'dataset_source' field is deprecated and will be removed in a "
                "future release. The dataset source is now automatically determined "
                "based on whether a dataset is passed directly to "
                "DatasetAnalyzer.__init__(). This field is ignored.",
                DeprecationWarning,
                stacklevel=2,
            )

        if self.dataset_format is not None:
            warnings.warn(
                "The 'dataset_format' field is deprecated and will be removed in a "
                "future release. The dataset format is now automatically detected "
                "from the file contents. This field is ignored.",
                DeprecationWarning,
                stacklevel=2,
            )

        # Handle deprecated tokenizer_config field
        if self.tokenizer_config is not None:
            warnings.warn(
                "The 'tokenizer_config' field is deprecated and will be removed in a "
                "future release. Use 'tokenizer_name' and 'tokenizer_kwargs' instead. "
                "Values from 'tokenizer_config' will be used for this run.",
                DeprecationWarning,
                stacklevel=2,
            )
            # Migrate values from tokenizer_config to new fields if not already set
            if self.tokenizer_name is None and "model_name" in self.tokenizer_config:
                self.tokenizer_name = self.tokenizer_config["model_name"]
            if (
                not self.tokenizer_kwargs
                and "tokenizer_kwargs" in self.tokenizer_config
            ):
                self.tokenizer_kwargs = self.tokenizer_config["tokenizer_kwargs"]
            # A truthy legacy trust_remote_code is adopted only while the
            # top-level flag is still False; it never turns the flag off.
            if (
                "trust_remote_code" in self.tokenizer_config
                and self.tokenizer_config["trust_remote_code"]
                and not self.trust_remote_code
            ):
                self.trust_remote_code = self.tokenizer_config["trust_remote_code"]

        # Handle deprecated is_multimodal field
        if self.is_multimodal is not None:
            warnings.warn(
                "The 'is_multimodal' field is deprecated and will be removed in a "
                "future release. Multimodality is now automatically detected based "
                "on whether 'processor_name' is provided. This field is ignored.",
                DeprecationWarning,
                stacklevel=2,
            )

        # Validate sample_count
        if self.sample_count is not None and self.sample_count <= 0:
            raise ValueError("`sample_count` must be greater than 0.")

        # Validate analyzer configurations
        analyzer_ids = set()
        for analyzer in self.analyzers:
            # Validate analyzer ID. The field default is OmegaConf's MISSING
            # sentinel, a non-empty string, so a plain truthiness test alone
            # would accept an analyzer whose id was never provided.
            if not analyzer.id or analyzer.id == MISSING:
                raise ValueError("Analyzer 'id' must be provided")
            if analyzer.id in analyzer_ids:
                raise ValueError(f"Duplicate analyzer ID found: '{analyzer.id}'")
            analyzer_ids.add(analyzer.id)