Source code for oumi.core.analyze.dataset_analyzer

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from dataclasses import asdict, dataclass
from typing import Any, Optional, Union, cast

import pandas as pd

from oumi.core.analyze.dataframe_analyzer import DataFrameAnalyzer, DataFrameWithSchema
from oumi.core.configs import AnalyzeConfig, DatasetSource
from oumi.core.datasets import BaseMapDataset
from oumi.core.datasets.base_iterable_dataset import BaseIterableDataset
from oumi.core.registry import REGISTRY
from oumi.utils.analysis_utils import (
    build_tokenizer_from_config,
    compute_statistics,
    convert_dataset_to_dataframes,
    get_schema_for_format,
    load_dataset_from_config,
)
from oumi.utils.logging import logger


@dataclass
class MessageAnalysisResult:
    """Result of analyzing a single message in a conversation.

    Attributes:
        message_index: Index of the message within the conversation
        role: Role of the message sender (e.g., 'user', 'assistant')
        message_id: Unique identifier for the message
        text_content: The text content of the message
        analyzer_metrics: Dictionary containing analyzer metrics for this message
    """

    message_index: int
    role: str
    message_id: str
    text_content: str
    analyzer_metrics: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Convert the analysis result to a dictionary with flattened analyzer metrics.

        Returns:
            Dictionary representation of the analysis result
        """
        return asdict(self)


@dataclass
class ConversationAnalysisResult:
    """Result of analyzing a conversation as a whole.

    Attributes:
        analyzer_metrics: Dictionary containing analyzer metrics for the conversation
    """

    analyzer_metrics: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Convert the analysis result to a dictionary.

        Returns:
            Dictionary representation of the analysis result
        """
        return asdict(self)


@dataclass
class DatasetAnalysisResult:
    """Complete result of dataset analysis.

    Attributes:
        dataset_name: Name of the analyzed dataset
        total_conversations: Total number of conversations in the dataset
        conversations_analyzed: Number of conversations actually analyzed
    """

    dataset_name: str
    total_conversations: int
    conversations_analyzed: int

    def to_dict(self) -> dict[str, Any]:
        """Convert the analysis result to a dictionary.

        Returns:
            Dictionary representation of the analysis result
        """
        return asdict(self)
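
# Illustrative serialization example (hypothetical values): each result dataclass
# converts to a plain dict via dataclasses.asdict, so nested analyzer metrics stay
# nested.
#
#   MessageAnalysisResult(
#       message_index=0,
#       role="user",
#       message_id="m0",
#       text_content="Hi there",
#       analyzer_metrics={"char_count": 8},
#   ).to_dict()
#   # -> {"message_index": 0, "role": "user", "message_id": "m0",
#   #     "text_content": "Hi there", "analyzer_metrics": {"char_count": 8}}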


class DatasetAnalyzer:
    """Orchestrates the analysis of datasets using multiple sample analyzers."""

    def __init__(
        self, config: AnalyzeConfig, dataset: Optional[BaseMapDataset] = None
    ):
        """Initialize the dataset analyzer with configuration.

        Args:
            config: AnalyzeConfig object containing all analysis parameters
            dataset: Optional pre-loaded dataset. If provided, this dataset will be
                used instead of loading from the config.
        """
        self.config = config
        self.dataset_name = config.dataset_name
        self.split = config.split

        # Build tokenizer from config if provided
        self.tokenizer = build_tokenizer_from_config(config.tokenizer_config)

        # Use provided dataset or load from config based on dataset_source
        if config.dataset_source == DatasetSource.DIRECT:
            # Direct mode: must provide dataset
            if dataset is None:
                raise ValueError(
                    "Config specifies dataset_source=DatasetSource.DIRECT but no "
                    "dataset was provided. Either pass a dataset to "
                    "DatasetAnalyzer.__init__() or "
                    "set dataset_source=DatasetSource.CONFIG.value."
                )
            self.dataset = dataset
            # Use the provided dataset name if config doesn't have one
            if not self.dataset_name:
                self.dataset_name = getattr(dataset, "dataset_name", "Custom Dataset")
            # Handle iterable datasets that don't support len()
            if isinstance(dataset, BaseIterableDataset):
                logger.info(f"Using provided streaming dataset '{self.dataset_name}'")
            else:
                logger.info(
                    f"Using provided dataset '{self.dataset_name}' with "
                    f"{len(dataset)} conversations"
                )
        elif config.dataset_source == DatasetSource.CONFIG:
            # Config mode: load dataset from config parameters
            if dataset is not None:
                raise ValueError(
                    f"Dataset provided but config.dataset_source is "
                    f"'{config.dataset_source.value}'. When using "
                    f"DatasetSource.CONFIG, do not pass a dataset to the "
                    f"constructor. Set dataset_source=DatasetSource.DIRECT "
                    f"if you want to use the provided dataset."
                )
            # Load dataset with the tokenizer
            self.dataset = load_dataset_from_config(config, self.tokenizer)
            logger.info(f"Loaded dataset from config: {self.dataset_name}")
        else:
            raise ValueError(f"Invalid dataset_source: {config.dataset_source}")

        self.sample_analyzers = self._initialize_sample_analyzers()

        # Initialize dataframe analyzer with sample analyzers
        self.dataframe_analyzer = DataFrameAnalyzer(self.sample_analyzers)

        # Initialize analysis results as None
        self._analysis_results: Optional[DatasetAnalysisResult] = None
        self._merged_df: Optional[pd.DataFrame] = None
        self._message_df: Optional[pd.DataFrame] = None
        self._conversation_df: Optional[pd.DataFrame] = None
        self._analysis_summary: Optional[dict[str, Any]] = None

        # Decimal precision for rounding metrics
        self._decimal_precision = 2

    def _get_schema_for_dataset(self) -> dict:
        """Get column schema configuration based on dataset type.

        Detects the appropriate schema based on the dataset class inheritance.
        Based on analysis of all 60 Oumi datasets:
        - 37 datasets (SFT/Vision-SFT/GRPO) convert to conversation format → use 'oumi'
        - 23 datasets (pretraining/DPO/KTO) maintain original structure → use specific

        Returns:
            Dictionary mapping column names to their configuration.
        """
        # Detect dataset type based on the dataset class
        dataset_type = self._detect_dataset_type()
        try:
            return get_schema_for_format(dataset_type)
        except ValueError:
            # Fallback to conversation schema for unknown types
            logger.warning(
                f"Unknown dataset type '{dataset_type}', using conversation schema"
            )
            return get_schema_for_format("oumi")

    def _detect_dataset_type(self) -> str:
        """Detect the dataset type based on the dataset class and configuration.

        Returns:
            String indicating the dataset type for schema selection.
        """
        if self.dataset is None:
            # No dataset provided, use config format or default to conversation
            return getattr(self.config, "dataset_format", None) or "oumi"

        # Check dataset class inheritance hierarchy for accurate detection
        dataset_class_bases = [
            base.__name__ for base in self.dataset.__class__.__mro__
        ]

        # Datasets that convert to conversation format during loading
        if any(
            base in dataset_class_bases
            for base in [
                "BaseSftDataset",
                "VisionLanguageSftDataset",
                "BaseExperimentalGrpoDataset",
            ]
        ):
            return "oumi"  # All convert to conversation format
        # Datasets that maintain original structure
        elif "BasePretrainingDataset" in dataset_class_bases:
            return "pretraining"
        elif any(
            base in dataset_class_bases
            for base in ["BaseDpoDataset", "VisionLanguageDpoDataset"]
        ):
            return "dpo"
        elif "BaseExperimentalKtoDataset" in dataset_class_bases:
            return "kto"
        else:
            # Check if we have explicit format from config
            config_format = getattr(self.config, "dataset_format", None)
            if config_format in [
                "alpaca",
                "prompt_response",
                "dpo",
                "pretraining",
                "kto",
            ]:
                return config_format
            else:
                # Default to conversation format for unknown SFT-like datasets
                return "oumi"

    def _initialize_sample_analyzers(self) -> dict[str, Any]:
        """Initialize sample analyzer plugins from configuration.

        Returns:
            Dictionary mapping analyzer IDs to analyzer instances
        """
        sample_analyzers = {}
        for analyzer_params in self.config.analyzers:
            try:
                # Get the analyzer class from the registry
                analyzer_class = REGISTRY.get_sample_analyzer(analyzer_params.id)
                if analyzer_class is None:
                    raise ValueError(
                        f"Sample analyzer '{analyzer_params.id}' not found in registry"
                    )

                # Prepare parameters for analyzer constructor
                analyzer_kwargs = dict(analyzer_params.params)
                if self.tokenizer is not None:
                    analyzer_kwargs["tokenizer"] = self.tokenizer

                # Create analyzer instance with keyword arguments
                sample_analyzer = analyzer_class(**analyzer_kwargs)
                sample_analyzers[analyzer_params.id] = sample_analyzer
                logger.info(f"Initialized sample analyzer: {analyzer_params.id}")
            except Exception as e:
                logger.error(
                    f"Failed to initialize sample analyzer {analyzer_params.id}: {e}"
                )
                logger.error(f"Analyzer configuration: {analyzer_params}")
        return sample_analyzers

    def analyze_dataset(self) -> None:
        """Analyze the dataset and store results internally.

        This method performs both message-level and conversation-level analysis
        using the configured sample analyzers. Each analyzer processes entire
        conversations and returns metrics for both individual messages and
        conversations as a whole.

        Results are stored internally and can be accessed via the query() method.

        Raises:
            ValueError: If no analyzers are configured for analysis.
        """
        if not self.sample_analyzers:
            raise ValueError(
                "No analyzers configured for analysis. Please add at least one "
                "analyzer to the configuration before calling analyze_dataset()."
            )

        logger.info(f"Starting analysis of dataset: {self.dataset_name}")
        logger.info(
            f"Using {len(self.sample_analyzers)} sample analyzers: "
            f"{list(self.sample_analyzers.keys())}"
        )

        # Handle iterable datasets differently to avoid downloading everything
        if isinstance(self.dataset, BaseIterableDataset):
            # For iterable datasets, we can't get the total length without iterating
            # So we'll use the sample_count directly and iterate only what we need
            conversations_to_analyze = (
                self.config.sample_count or 1000
            )  # Default limit for streaming
            total_conversations = None  # Unknown for iterable datasets
            logger.info(
                f"Analyzing up to {conversations_to_analyze} conversations from "
                f"streaming dataset"
            )
        else:
            # For map datasets, we can get the total length
            total_conversations = len(self.dataset)
            conversations_to_analyze = min(
                total_conversations, self.config.sample_count or total_conversations
            )
            logger.info(
                f"Analyzing {conversations_to_analyze} of {total_conversations} "
                f"conversations"
            )

        dataframe_list, total_items, items_to_analyze = self._prepare_dataframe_list(
            conversations_to_analyze
        )

        analysis_result = self.dataframe_analyzer.analyze_dataframe_list(
            input_data_list=dataframe_list,
            merge_on=["conversation_index", "conversation_id"],
        )

        self._merged_df = analysis_result.merged_df
        self._message_df = analysis_result.messages_df
        self._conversation_df = analysis_result.conversations_df

        self._analysis_results = DatasetAnalysisResult(
            dataset_name=self.dataset_name or "",
            total_conversations=total_conversations or conversations_to_analyze,
            conversations_analyzed=conversations_to_analyze,
        )

        # Generate and store the analysis summary after metrics are computed
        self._analysis_summary = self._generate_analysis_summary()
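
    # Typical post-analysis access pattern (illustrative; each property below raises
    # RuntimeError until analyze_dataset() has been called):
    #
    #   analyzer.analyze_dataset()
    #   analyzer.message_df       # per-message metrics
    #   analyzer.conversation_df  # per-conversation metrics
    #   analyzer.analysis_df      # merged view of message and conversation metrics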

    def _prepare_dataframe_list(
        self, max_items: Optional[int] = None
    ) -> tuple[list[DataFrameWithSchema], int, int]:
        """Prepare DataFrameWithSchema list from input source with optional limiting.

        Args:
            max_items: Maximum number of items to analyze (None for no limit)

        Returns:
            Tuple of (dataframe_list, total_items, items_to_analyze)
        """
        if self.dataset is not None:
            # Conversation dataset input - convert to DataFrames
            if isinstance(self.dataset, BaseIterableDataset):
                # For iterable datasets, we can't get the total length
                total_items = max_items or 1000  # Use max_items or default
                items_to_analyze = total_items
                logger.info(
                    f"Converting streaming dataset with up to {items_to_analyze} items"
                )
            else:
                # For map datasets, we can get the total length
                total_items = len(self.dataset)
                logger.info(
                    f"Converting conversation dataset with {total_items} items"
                )

                # Determine how many items to analyze
                items_to_analyze = total_items
                if max_items is not None:
                    items_to_analyze = min(total_items, max_items)
                    if items_to_analyze < total_items:
                        logger.info(
                            f"Limiting analysis to first {max_items} "
                            f"items (dataset has {total_items} total)"
                        )

            # Use utility function to convert dataset to DataFrames
            conversations_df, messages_df = convert_dataset_to_dataframes(
                dataset=self.dataset,
                items_to_analyze=items_to_analyze,
                dataset_name=self.dataset_name or "Unknown Dataset",
            )

            schema = self._get_schema_for_dataset()
            dataframe_list = [
                DataFrameWithSchema(conversations_df, schema, "conversations"),
                DataFrameWithSchema(messages_df, schema, "messages"),
            ]
            return dataframe_list, total_items, items_to_analyze
        else:
            raise ValueError("Either dataframes or dataset must be provided")

    @property
    def analysis_results(self) -> Optional[DatasetAnalysisResult]:
        """Get the analysis results if available.

        Returns:
            DatasetAnalysisResult if analysis has been run, None otherwise
        """
        return self._analysis_results

    def query(self, query_expression: str) -> pd.DataFrame:
        """Query the analysis results using pandas query syntax.

        Args:
            query_expression: Pandas query expression (e.g., "char_count > 10")

        Returns:
            DataFrame containing rows that match the query expression

        Raises:
            RuntimeError: If analysis has not been run yet.
        """
        # Check if analysis has been run
        if self._merged_df is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to query the analysis results."
            )

        # Apply the query filter
        try:
            filtered_df = self._merged_df.query(query_expression)
            logger.info(f"Query '{query_expression}' returned {len(filtered_df)} rows")
        except Exception as e:
            logger.error(f"Query failed: {e}")
            raise ValueError(f"Invalid query expression: {query_expression}") from e

        return filtered_df
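
    # Illustrative query sketch (assumed analyzer metric column; actual column names
    # depend on the analyzers configured in AnalyzeConfig and follow the
    # text_content_{analyzer}_{metric} naming described in the summary helpers
    # further below):
    #
    #   long_messages = analyzer.query(
    #       "text_content_length_analyzer_char_count > 500"
    #   )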

    @property
    def analysis_df(self) -> Union[pd.DataFrame, None]:
        """Get the merged analysis DataFrame with message and conversation metrics.

        Returns:
            DataFrame with columns prefixed by message_ and conversation_ for
            each analyzer

        Raises:
            RuntimeError: If analysis has not been run yet.
        """
        if self._merged_df is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to access the analysis DataFrame."
            )
        return self._merged_df

    @property
    def message_df(self) -> Union[pd.DataFrame, None]:
        """Get the message-level analysis DataFrame.

        Returns:
            DataFrame with message-level metrics prefixed by message_

        Raises:
            RuntimeError: If analysis has not been run yet.
        """
        if self._message_df is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to access the message DataFrame."
            )
        return self._message_df

    @property
    def conversation_df(self) -> Union[pd.DataFrame, None]:
        """Get the conversation-level analysis DataFrame.

        Returns:
            DataFrame with conversation-level metrics prefixed by conversation_

        Raises:
            RuntimeError: If analysis has not been run yet.
        """
        if self._conversation_df is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to access the conversation DataFrame."
            )
        return self._conversation_df

    def query_conversations(
        self,
        query_expression: str,
    ) -> pd.DataFrame:
        """Query conversation-level analysis results using pandas query expression.

        Args:
            query_expression: Pandas query expression to filter conversation
                analysis results

        Returns:
            DataFrame with filtered conversation analysis results

        Raises:
            RuntimeError: If analysis has not been run yet.

        Examples:
            # Filter for long conversations
            long_conversations = analyzer.query_conversations(
                "length_token_count > 1000"
            )
        """
        # Check if analysis has been run
        if self._conversation_df is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to query conversation results."
            )

        # Apply the query filter
        try:
            filtered_df = self._conversation_df.query(query_expression)
            logger.info(f"Query '{query_expression}' returned {len(filtered_df)} rows")
        except Exception as e:
            logger.error(f"Query failed: {e}")
            raise ValueError(
                f"Invalid query expression '{query_expression}': {e}"
            ) from e

        return filtered_df

    def filter(
        self,
        query_expression: str,
    ) -> Union[BaseMapDataset, BaseIterableDataset]:
        """Filter the original dataset based on analysis results.

        This method uses analysis results to filter the original dataset, returning
        a new dataset object containing only the conversations that match the query.

        Args:
            query_expression: Pandas query expression to filter analysis results

        Returns:
            A new dataset object containing only the filtered conversations

        Raises:
            RuntimeError: If analysis has not been run yet.

        Examples:
            # Filter for conversations with short messages
            short_dataset = analyzer.filter("length_word_count < 10")

            # Filter for conversations with assistant messages
            assistant_dataset = analyzer.filter("role == 'assistant'")

            # Filter for conversations with long user messages
            long_user_dataset = analyzer.filter(
                "role == 'user' and length_word_count > 100"
            )
        """
        # Get filtered analysis results
        filtered_df = self.query(query_expression)

        # Get unique conversation indices from filtered results
        conversation_indices = filtered_df.conversation_index.unique().tolist()

        # Create a new dataset with only the filtered conversations
        filtered_dataset = self._create_filtered_dataset(conversation_indices)

        # Get total dataset size, handling iterable datasets
        if isinstance(self.dataset, BaseIterableDataset):
            total_size = "unknown (streaming)"
        else:
            total_size = str(len(self.dataset))

        logger.info(
            f"Filtered dataset: {len(conversation_indices)} conversations "
            f"out of {total_size} total"
        )

        return filtered_dataset

    def _create_filtered_dataset(
        self, conversation_indices: list[int]
    ) -> Union[BaseMapDataset, BaseIterableDataset]:
        """Create a new dataset containing only the specified conversations.

        Args:
            conversation_indices: List of conversation indices to include

        Returns:
            A new dataset object with the same format as the original
        """
        # Deep copy the original dataset to preserve all attributes and methods
        filtered_dataset = copy.deepcopy(self.dataset)

        # Filter the DataFrame to only include the specified conversations
        # Note: This only works for map datasets, not iterable datasets
        if isinstance(self.dataset, BaseIterableDataset):
            # For iterable datasets, we can't filter by index
            # Return the original dataset as filtering is not supported
            return filtered_dataset

        original_df = self.dataset.data
        filtered_dataset._data = original_df.iloc[conversation_indices].copy()

        # Update the dataset name to indicate it's filtered
        filtered_dataset.dataset_name = f"{self.dataset.dataset_name}_filtered"

        return filtered_dataset

    def _generate_analysis_summary(self) -> dict[str, Any]:
        """Generate a comprehensive summary of dataset analysis results.

        This method aggregates metrics from all analyzers to provide insights
        useful for assessing datasets. It computes statistics like averages,
        standard deviations, min/max values, and efficiency metrics.

        Returns:
            Dictionary containing comprehensive dataset analysis summary with:
            - Dataset overview statistics
            - Message-level aggregated metrics
            - Conversation-level aggregated metrics
        """
        # Check if we have data to analyze
        if self._merged_df is None or self._merged_df.empty:
            return {"error": "No analysis data available"}

        # TODO: Refactor summary methods to be dataset agnostic
        # Currently these methods assume conversation dataset structure with
        # messages/conversations. They should be generalized to work with any
        # dataset type and column structure.
        summary = {
            "dataset_overview": self._get_dataset_overview(),
            "message_level_summary": self._get_message_level_summary(),
            "conversation_level_summary": self._get_conversation_level_summary(),
            "conversation_turns": self._get_conversation_turns_summary(),
        }

        return summary

    @property
    def analysis_summary(self) -> dict[str, Any]:
        """Get the comprehensive analysis summary.

        Returns:
            Dictionary containing comprehensive dataset analysis summary

        Raises:
            RuntimeError: If analysis has not been run yet.
        """
        if self._analysis_summary is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to generate the analysis summary."
            )
        return self._analysis_summary

    def _get_dataset_overview(self) -> dict[str, Any]:
        """Get basic dataset overview statistics."""
        if self._analysis_results is None:
            return {}

        return {
            "dataset_name": self._analysis_results.dataset_name,
            "total_conversations": self._analysis_results.total_conversations,
            "conversations_analyzed": self._analysis_results.conversations_analyzed,
            "dataset_coverage_percentage": round(
                100.0
                * self._analysis_results.conversations_analyzed
                / self._analysis_results.total_conversations
                if self._analysis_results.total_conversations > 0
                else 0,
                self._decimal_precision,
            ),
            "total_messages": len(self._message_df)
            if self._message_df is not None
            else 0,
            "analyzers_used": list(self.sample_analyzers.keys()),
        }

    def _get_message_level_summary(self) -> dict[str, Any]:
        """Get aggregated message-level metrics across all analyzers."""
        if self._message_df is None or self._message_df.empty:
            return {}

        # Get all analyzer columns (columns that are not base message columns)
        base_columns = {
            "conversation_index",
            "conversation_id",
            "message_index",
            "message_id",
            "role",
            "text_content",
        }
        analyzer_columns = [
            col
            for col in self._message_df.columns
            if col not in base_columns
            and pd.api.types.is_numeric_dtype(self._message_df[col])
        ]

        summary = {}
        for col in analyzer_columns:
            # Extract analyzer name and metric from column
            # Format: text_content_{analyzer}_{metric}
            # Example: text_content_length_analyzer_char_count
            parts = col.split("_")
            if len(parts) >= 5:  # text_content_analyzer_metric_type
                if parts[0] == "text" and parts[1] == "content":
                    # The analyzer name and metric are in the remaining parts
                    # For "text_content_length_analyzer_char_count":
                    # parts[2:] = ["length", "analyzer", "char", "count"]
                    # We need to find where the analyzer name ends and metric begins
                    # Look for known metric suffixes to split correctly
                    remaining_parts = parts[2:]
                    metric_suffixes = [
                        "char_count",
                        "word_count",
                        "sentence_count",
                        "token_count",
                    ]

                    analyzer_name = None
                    metric_name = None

                    # Try to find a metric suffix
                    for i in range(
                        1, len(remaining_parts)
                    ):  # Start from 1 to ensure analyzer_name is not empty
                        potential_metric = "_".join(remaining_parts[i:])
                        if any(
                            potential_metric.endswith(suffix)
                            for suffix in metric_suffixes
                        ):
                            analyzer_name = "_".join(remaining_parts[:i])
                            metric_name = f"text_content_{potential_metric}"
                            break

                    # Fallback: assume last two parts are metric
                    if analyzer_name is None:
                        if len(remaining_parts) >= 2:
                            analyzer_name = "_".join(remaining_parts[:-2])
                            metric_name = (
                                f"text_content_{remaining_parts[-2]}_"
                                f"{remaining_parts[-1]}"
                            )

                    if analyzer_name and metric_name:
                        if analyzer_name not in summary:
                            summary[analyzer_name] = {}

                        # Compute statistics for numeric columns
                        values = cast(pd.Series, self._message_df[col].dropna())
                        if len(values) > 0:
                            summary[analyzer_name][metric_name] = compute_statistics(
                                values, self._decimal_precision
                            )

        return summary

    def _get_conversation_level_summary(self) -> dict[str, Any]:
        """Get aggregated conversation-level metrics across all analyzers."""
        if self._conversation_df is None or self._conversation_df.empty:
            return {}

        # Get all analyzer columns (columns that are not base conversation columns)
        base_columns = {
            "conversation_index",
            "conversation_id",
            "num_messages",
        }
        analyzer_columns = [
            col
            for col in self._conversation_df.columns
            if col not in base_columns
            and pd.api.types.is_numeric_dtype(self._conversation_df[col])
        ]

        summary = {}
        for col in analyzer_columns:
            # Use the same parsing logic as message level summary
            # Format: text_content_{analyzer}_{metric}
            # (for conversation-level aggregated metrics)
            parts = col.split("_")
            if len(parts) >= 5:  # text_content_analyzer_metric_type
                if parts[0] == "text" and parts[1] == "content":
                    remaining_parts = parts[2:]
                    metric_suffixes = [
                        "char_count",
                        "word_count",
                        "sentence_count",
                        "token_count",
                    ]

                    analyzer_name = None
                    metric_name = None

                    # Try to find a metric suffix
                    for i in range(
                        1, len(remaining_parts)
                    ):  # Start from 1 to ensure analyzer_name is not empty
                        potential_metric = "_".join(remaining_parts[i:])
                        if any(
                            potential_metric.endswith(suffix)
                            for suffix in metric_suffixes
                        ):
                            analyzer_name = "_".join(remaining_parts[:i])
                            metric_name = f"text_content_{potential_metric}"
                            break

                    # Fallback: assume last two parts are metric
                    if analyzer_name is None:
                        if len(remaining_parts) >= 2:
                            analyzer_name = "_".join(remaining_parts[:-2])
                            metric_name = (
                                f"text_content_{remaining_parts[-2]}_"
                                f"{remaining_parts[-1]}"
                            )

                    if analyzer_name and metric_name:
                        if analyzer_name not in summary:
                            summary[analyzer_name] = {}

                        # Compute statistics for numeric columns
                        values = cast(pd.Series, self._conversation_df[col].dropna())
                        if len(values) > 0:
                            summary[analyzer_name][metric_name] = compute_statistics(
                                values, self._decimal_precision
                            )

        return summary

    def _get_conversation_turns_summary(self) -> dict[str, Any]:
        """Get conversation turn statistics summary.

        Returns:
            Dictionary containing conversation turn statistics
        """
        if self._message_df is None or self._message_df.empty:
            return {}

        # groupby().size() always returns a Series, but we cast it because
        # the type checker can't infer this
        turns_per_conversation = cast(
            pd.Series, self._message_df.groupby("conversation_id").size()
        )

        return compute_statistics(turns_per_conversation, self._decimal_precision)
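
# End-to-end usage sketch (illustrative; AnalyzeConfig fields and analyzer ids
# depend on your configuration -- see oumi.core.configs and the analyzer registry):
#
#   config = AnalyzeConfig(...)          # dataset, split, analyzers, tokenizer, ...
#   analyzer = DatasetAnalyzer(config)   # loads the dataset (DatasetSource.CONFIG)
#   analyzer.analyze_dataset()           # runs all configured sample analyzers
#   summary = analyzer.analysis_summary  # aggregated per-analyzer statistics
#   subset = analyzer.filter("role == 'assistant'")  # dataset of matching conversations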