# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Any, Optional, Union
import pandas as pd
from tqdm import tqdm
from oumi.builders.models import build_tokenizer
from oumi.core.configs.analyze_config import AnalyzeConfig
from oumi.core.configs.params.model_params import ModelParams
from oumi.core.datasets.base_iterable_dataset import BaseIterableDataset
from oumi.core.datasets.base_map_dataset import BaseMapDataset
from oumi.core.registry.registry import REGISTRY
from oumi.core.types.conversation import Conversation
logger = logging.getLogger(__name__)
def build_tokenizer_from_config(tokenizer_config: Optional[dict[str, Any]]):
"""Build a tokenizer from configuration dictionary.
Args:
tokenizer_config: Dictionary containing tokenizer configuration
Returns:
Built tokenizer or None if config is None
Raises:
ValueError: If required fields are missing from tokenizer_config
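
    Example (illustrative sketch; assumes the ``openai-community/gpt2``
    tokenizer is available from the HuggingFace Hub or a local cache)::

        tokenizer = build_tokenizer_from_config(
            {"model_name": "openai-community/gpt2"}  # assumption: downloadable
        )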
"""
if not tokenizer_config:
return None
if "model_name" not in tokenizer_config:
raise ValueError("tokenizer_config must contain 'model_name' field")
model_params = ModelParams(
model_name=tokenizer_config["model_name"],
tokenizer_kwargs=tokenizer_config.get("tokenizer_kwargs", {}),
trust_remote_code=tokenizer_config.get("trust_remote_code", False),
)
tokenizer = build_tokenizer(model_params)
logger.info(f"Built tokenizer for model: {model_params.model_name}")
return tokenizer
def load_dataset_from_config(
config: AnalyzeConfig, tokenizer: Optional[Any] = None
) -> Union[BaseMapDataset, BaseIterableDataset]:
"""Load dataset based on configuration.
This function loads datasets directly from the registry for analysis purposes.
If a tokenizer is provided, it will be passed to the dataset constructor.
For custom datasets, it supports loading from local files using
TextSftJsonLinesDataset for text data and VLJsonlinesDataset for
vision-language data.
Args:
config: Configuration object containing dataset parameters
tokenizer: Optional tokenizer to use with the dataset
Returns:
Loaded dataset
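
    Example (sketch; assumes ``tatsu-lab/alpaca`` is registered in the
    REGISTRY and downloadable)::

        config = AnalyzeConfig(dataset_name="tatsu-lab/alpaca", split="train")
        dataset = load_dataset_from_config(config)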
"""
dataset_name = config.dataset_name
split = config.split
subset = config.subset
dataset_path = config.dataset_path
dataset_format = config.dataset_format
if not dataset_name and not dataset_path:
raise ValueError("Either dataset_name or dataset_path must be provided")
# Handle custom dataset loading from local files
if dataset_path:
return _load_custom_dataset_from_path(
dataset_path, dataset_format, tokenizer, config
)
# Handle registered dataset loading
try:
# Load dataset from the REGISTRY
if dataset_name is None:
raise ValueError("dataset_name cannot be None for registered datasets")
dataset_class = REGISTRY.get_dataset(dataset_name, subset=subset)
if dataset_class is not None:
# Check if this is an iterable dataset that supports streaming
import inspect
# Ensure dataset_class is actually a class before using issubclass
is_iterable_dataset = inspect.isclass(dataset_class) and issubclass(
dataset_class, BaseIterableDataset
)
# For iterable datasets, force streaming mode to avoid downloading all
if is_iterable_dataset:
logger.info(
f"Using streaming mode for iterable dataset: {dataset_name}"
)
# Don't modify split for iterable datasets - streaming handles limiting
else:
# For map datasets, modify split to include slicing if sample_count set
if config.sample_count is not None and config.sample_count > 0:
# Use a larger slice (10x sample_count) to ensure enough data
# after any filtering that might happen in the dataset
slice_size = config.sample_count * 10
if "[" not in split: # Only add slicing if not already present
split = f"{split}[:{slice_size}]"
# Prepare dataset constructor arguments
dataset_kwargs = {
"dataset_name": dataset_name,
"dataset_path": None,
"split": split,
"subset": subset,
"trust_remote_code": config.trust_remote_code,
}
# Force streaming for iterable datasets
if is_iterable_dataset:
dataset_kwargs["stream"] = True
# Add tokenizer if provided
if tokenizer is not None:
dataset_kwargs["tokenizer"] = tokenizer
# Add processor parameters for vision-language datasets
if config.processor_name:
dataset_kwargs["processor_name"] = config.processor_name
dataset_kwargs["processor_kwargs"] = config.processor_kwargs
dataset_kwargs["trust_remote_code"] = config.trust_remote_code
# Add required parameters for pretraining datasets
if is_iterable_dataset:
# Import here to avoid circular imports
from oumi.core.datasets.base_pretraining_dataset import (
BasePretrainingDataset,
)
if inspect.isclass(dataset_class) and issubclass(
dataset_class, BasePretrainingDataset
):
# Pretraining datasets require seq_length and tokenizer
if "seq_length" not in dataset_kwargs:
dataset_kwargs["seq_length"] = 64 # Default sequence length
if tokenizer is None:
# Create a default tokenizer if none provided
model_params = ModelParams(
model_name="openai-community/gpt2",
tokenizer_kwargs={"pad_token": "<|endoftext|>"},
)
dataset_kwargs["tokenizer"] = build_tokenizer(model_params)
# Load registered dataset with parameters
# Check if dataset_class is callable (class or mock for testing)
if not (inspect.isclass(dataset_class) or callable(dataset_class)):
raise TypeError(
f"Expected class or callable, got {type(dataset_class)} for "
f"dataset {dataset_name}"
)
dataset = dataset_class(**dataset_kwargs)
# Ensure we return a supported dataset type
if isinstance(dataset, (BaseMapDataset, BaseIterableDataset)):
return dataset
else:
raise NotImplementedError(
f"Dataset type {type(dataset)} is not supported for analysis. "
"Please use a dataset that inherits from BaseMapDataset or "
"BaseIterableDataset."
)
else:
# TODO: Implement HuggingFace Hub loading
raise NotImplementedError(
f"Dataset '{dataset_name}' is not registered in the REGISTRY. "
"Loading from HuggingFace Hub is not yet implemented."
)
except Exception as e:
logger.error(f"Failed to load dataset {dataset_name}: {e}")
raise
def _load_custom_dataset_from_path(
dataset_path: str,
dataset_format: Optional[str],
tokenizer: Optional[Any],
config: AnalyzeConfig,
) -> BaseMapDataset:
"""Load a custom dataset from a local file path.
Args:
dataset_path: Path to the dataset file
dataset_format: Format of the dataset ('oumi' or 'alpaca') - required for
custom datasets
tokenizer: Optional tokenizer to use with the dataset
config: Configuration object containing additional parameters
Returns:
Loaded dataset (TextSftJsonLinesDataset or VLJsonlinesDataset)
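
    Example (illustrative; assumes an ``AnalyzeConfig`` named ``config`` with
    ``is_multimodal=False`` and a hypothetical local file ``data/train.jsonl``
    in the oumi conversation format)::

        dataset = _load_custom_dataset_from_path(
            "data/train.jsonl",  # hypothetical path
            dataset_format="oumi",
            tokenizer=None,
            config=config,
        )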
"""
# Import here to avoid circular imports
from oumi.datasets.sft.sft_jsonlines import TextSftJsonLinesDataset
from oumi.datasets.vision_language.vision_jsonlines import VLJsonlinesDataset
path = Path(dataset_path)
if not path.exists():
raise FileNotFoundError(f"Dataset file not found: {dataset_path}")
if not path.is_file():
raise ValueError(
f"Dataset path must be a file, not a directory: {dataset_path}"
)
# Multimodal handling is explicit via config.is_multimodal
if config.is_multimodal is True:
# Note: processor_name requirement is already validated in AnalyzeConfig
dataset_kwargs = {
"dataset_path": str(path),
"tokenizer": tokenizer,
"processor_name": config.processor_name,
"processor_kwargs": config.processor_kwargs,
"trust_remote_code": config.trust_remote_code,
}
dataset_kwargs = {k: v for k, v in dataset_kwargs.items() if v is not None}
dataset = VLJsonlinesDataset(**dataset_kwargs)
logger.info(f"Loaded vision-language dataset from: {dataset_path}")
return dataset
elif config.is_multimodal is False:
# If explicitly forced to text, load as text-only
dataset_kwargs = {
"dataset_path": str(path),
"format": dataset_format,
}
if tokenizer is not None:
dataset_kwargs["tokenizer"] = tokenizer
dataset_kwargs = {k: v for k, v in dataset_kwargs.items() if v is not None}
dataset = TextSftJsonLinesDataset(**dataset_kwargs)
logger.info(f"Loaded text dataset from: {dataset_path}")
return dataset
else:
# This should never happen due to config validation
# is_multimodal=None case is already caught by AnalyzeConfig.__post_init__
raise ValueError("Invalid vision-language configuration")
def compute_statistics(series: pd.Series, decimal_precision: int = 2) -> dict[str, Any]:
"""Compute statistics for a pandas Series.
This utility function handles edge cases like empty series or single-element
series, ensuring that standard deviation is 0.0 for single values instead
of NaN.
Args:
series: Pandas Series containing numeric values
decimal_precision: Number of decimal places for rounding
Returns:
Dictionary with computed statistics (count, mean, std, min, max, median)
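
    Example::

        stats = compute_statistics(pd.Series([1, 2, 3]))
        # {"count": 3, "mean": 2.0, "std": 1.0, "min": 1, "max": 3, "median": 2.0}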
"""
if series.empty:
return {
"count": 0,
"mean": 0.0,
"std": 0.0,
"min": 0,
"max": 0,
"median": 0.0,
}
if len(series) == 1:
single_value = round(float(series.iloc[0]), decimal_precision)
return {
"count": 1,
"mean": single_value,
"std": 0.0, # Standard deviation is 0 for single value
"min": single_value,
"max": single_value,
"median": single_value,
}
return {
"count": len(series),
"mean": round(series.mean(), decimal_precision),
"std": round(series.std(), decimal_precision),
"min": round(series.min(), decimal_precision),
"max": round(series.max(), decimal_precision),
"median": round(series.median(), decimal_precision),
}
def conversation_to_dataframes(
conversation: Conversation, conversation_id: str, conversation_idx: int
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Convert a single conversation to separate conversation and message DataFrames.
This creates two DataFrames: one for conversation-level data and one for
message-level data, suitable for comprehensive dataset analysis.
Args:
conversation: The conversation object to convert
conversation_id: ID of the conversation
conversation_idx: Index of the conversation
Returns:
Tuple of (conversation_df, message_df)
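
    Example (sketch; ``Message`` and ``Role`` come from
    ``oumi.core.types.conversation``)::

        conversation = Conversation(
            messages=[Message(role=Role.USER, content="Hello!")]
        )
        conv_df, msg_df = conversation_to_dataframes(conversation, "conv_0", 0)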
"""
# Create conversation-level data
conversation_data = {
"conversation_index": conversation_idx,
"conversation_id": conversation_id,
"num_messages": len(conversation.messages),
}
conversation_df = pd.DataFrame([conversation_data])
# Create message-level data
messages_data = []
for msg_idx, message in enumerate(conversation.messages):
text_content = (
message.content
if isinstance(message.content, str)
else message.compute_flattened_text_content()
)
messages_data.append(
{
"conversation_index": conversation_idx,
"conversation_id": conversation_id,
"message_index": msg_idx,
"message_id": message.id or f"msg_{msg_idx}",
"role": message.role.value,
"text_content": text_content,
}
)
message_df = pd.DataFrame(messages_data)
return conversation_df, message_df
def convert_dataset_to_dataframes(
    dataset: Union[BaseMapDataset, BaseIterableDataset],
items_to_analyze: int,
dataset_name: str = "Dataset",
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Convert a dataset to conversations and messages DataFrames.
This method handles different dataset types:
- SFT/Vision-SFT/GRPO datasets: Convert via conversation() method
- DPO/KTO/Pretraining datasets: Convert raw data structure
- Iterable datasets: Stream and limit to items_to_analyze
Args:
dataset: The dataset to process (BaseMapDataset or BaseIterableDataset)
items_to_analyze: Number of items to analyze
dataset_name: Name of the dataset for progress display
Returns:
Tuple of (conversations_df, messages_df) ready for analysis
Raises:
ValueError: If dataset is not provided
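
    Example (sketch; ``dataset`` as returned by ``load_dataset_from_config``)::

        conversations_df, messages_df = convert_dataset_to_dataframes(
            dataset, items_to_analyze=100, dataset_name="alpaca"
        )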
"""
if dataset is None:
raise ValueError("Dataset must be provided for conversation processing")
# Handle iterable datasets (streaming datasets like C4)
if isinstance(dataset, BaseIterableDataset):
return _convert_iterable_dataset_to_dataframes(
dataset, items_to_analyze, dataset_name
)
# Check if dataset has conversation() method (SFT/Vision-SFT/GRPO datasets)
elif hasattr(dataset, "conversation") and callable(
getattr(dataset, "conversation")
):
return _convert_conversation_dataset_to_dataframes(
dataset, items_to_analyze, dataset_name
)
else:
# For non-conversation datasets (DPO, KTO, pretraining), convert raw data
return _convert_raw_dataset_to_dataframes(
dataset, items_to_analyze, dataset_name
)
def _convert_iterable_dataset_to_dataframes(
dataset: BaseIterableDataset, items_to_analyze: int, dataset_name: str
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Convert iterable datasets (streaming) to DataFrames.
    This handles datasets like C4, Pile, etc. that are streamed and do not
    support random access by index. We iterate through the dataset and stop
    after ``items_to_analyze`` items.
"""
# For now, create empty conversations DataFrame since these datasets don't have
# conversation structure
conversations_df = pd.DataFrame()
# Stream raw data and limit to items_to_analyze
raw_data_list = []
try:
# Use tqdm to show progress while streaming
dataset_iter = iter(dataset)
for idx in tqdm(
range(items_to_analyze),
desc=f"Streaming {dataset_name} data",
unit="item",
):
try:
# Get next item from the iterator
raw_item = next(dataset_iter)
# Convert to dict if it's not already
if hasattr(raw_item, "to_dict") and callable(
getattr(raw_item, "to_dict")
):
raw_item = raw_item.to_dict() # type: ignore
elif not isinstance(raw_item, dict):
# For pretraining datasets, the item might be a tensor or list
# Convert to a simple dict structure
raw_item = {"input_ids": raw_item, "item_index": idx}
# Add index information
raw_item["item_index"] = idx
raw_data_list.append(raw_item)
except StopIteration:
# Dataset ended before we reached items_to_analyze
logger.info(
f"Dataset ended after {idx} items (requested {items_to_analyze})"
)
break
except Exception as e:
logger.warning(f"Failed to process item {idx} from {dataset_name}: {e}")
continue
except Exception as e:
logger.error(f"Failed to iterate over dataset {dataset_name}: {e}")
# Return empty DataFrames if we can't iterate
return pd.DataFrame(), pd.DataFrame()
# Create a DataFrame from raw data for analysis
if raw_data_list:
messages_df = pd.DataFrame(raw_data_list)
# Add required columns for analysis compatibility
messages_df["conversation_index"] = messages_df["item_index"]
messages_df["message_index"] = (
0 # Single message per item for non-conversation data
)
messages_df["message_id"] = messages_df["item_index"].apply(
lambda x: f"msg_{x}"
)
else:
messages_df = pd.DataFrame()
return conversations_df, messages_df
def _convert_conversation_dataset_to_dataframes(
dataset, items_to_analyze: int, dataset_name: str
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Convert datasets with conversation() method to DataFrames."""
conversation_df_list = []
message_df_list = []
for conversation_idx in tqdm(
range(items_to_analyze),
desc=f"Converting {dataset_name} to DataFrames",
unit="item",
):
conversation = dataset.conversation(conversation_idx)
conversation_id = conversation.conversation_id or str(conversation_idx)
conversation_df, message_df = conversation_to_dataframes(
conversation, conversation_id, conversation_idx
)
# Collect all DataFrames for concatenation
if not conversation_df.empty:
conversation_df_list.append(conversation_df)
if not message_df.empty:
message_df_list.append(message_df)
# Create complete DataFrames by concatenating all individual DataFrames
conversations_df = (
pd.concat(conversation_df_list, ignore_index=True)
if conversation_df_list
else pd.DataFrame()
)
messages_df = (
pd.concat(message_df_list, ignore_index=True)
if message_df_list
else pd.DataFrame()
)
return conversations_df, messages_df
def _convert_raw_dataset_to_dataframes(
dataset, items_to_analyze: int, dataset_name: str
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Convert datasets without conversation() method to DataFrames.
    This handles DPO, KTO, and pretraining datasets that maintain their
    original raw data structure rather than a conversation structure.
"""
# For now, create empty conversations DataFrame since these datasets don't have
# conversation structure
conversations_df = pd.DataFrame()
# Get raw data structure for analysis
raw_data_list = []
for idx in tqdm(
range(items_to_analyze),
desc=f"Converting {dataset_name} raw data",
unit="item",
):
try:
# Get raw data from the dataset
if hasattr(dataset, "raw"):
raw_item = dataset.raw(idx)
else:
raw_item = dataset[idx]
# Convert to dict if it's a pandas Series
if hasattr(raw_item, "to_dict"):
raw_item = raw_item.to_dict()
# Add index information
raw_item["item_index"] = idx
raw_data_list.append(raw_item)
except Exception as e:
logger.warning(f"Failed to process item {idx} from {dataset_name}: {e}")
continue
# Create a DataFrame from raw data for analysis
if raw_data_list:
messages_df = pd.DataFrame(raw_data_list)
# Add required columns for analysis compatibility
messages_df["conversation_index"] = messages_df["item_index"]
messages_df["message_index"] = (
0 # Single message per item for non-conversation data
)
messages_df["message_id"] = messages_df["item_index"].apply(
lambda x: f"msg_{x}"
)
else:
messages_df = pd.DataFrame()
return conversations_df, messages_df
def get_conversation_schema() -> dict:
"""Get column configuration for conversation format (oumi format).
Returns:
Dictionary mapping column names to their configuration.
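
    Example (illustrative)::

        schema = get_conversation_schema()
        schema["role"]["content_type"]  # ContentType.CATEGORICAL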
"""
from oumi.core.analyze.column_types import ColumnType, ContentType
return {
# Conversation DataFrame columns
"conversation_index": {
"type": ColumnType.INT,
"content_type": ContentType.METADATA,
"description": "Conversation index in dataset",
},
"conversation_id": {
"type": ColumnType.STRING,
"content_type": ContentType.METADATA,
"description": "Conversation identifier",
},
"num_messages": {
"type": ColumnType.INT,
"content_type": ContentType.NUMERIC,
"description": "Number of messages in conversation",
},
# Message DataFrame columns
"message_index": {
"type": ColumnType.INT,
"content_type": ContentType.METADATA,
"description": "Message index within conversation",
},
"message_id": {
"type": ColumnType.STRING,
"content_type": ContentType.METADATA,
"description": "Message identifier",
},
"role": {
"type": ColumnType.STRING,
"content_type": ContentType.CATEGORICAL,
"description": "Message role (user/assistant/system)",
},
"text_content": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Message text content",
},
# Additional fields that might be present
"timestamp": {
"type": ColumnType.TIMESTAMP,
"content_type": ContentType.METADATA,
"description": "Message timestamp",
},
"processing_time": {
"type": ColumnType.FLOAT,
"content_type": ContentType.NUMERIC,
"description": "AI processing time in seconds",
},
"model": {
"type": ColumnType.STRING,
"content_type": ContentType.METADATA,
"description": "Model used for generation",
},
"temperature": {
"type": ColumnType.FLOAT,
"content_type": ContentType.METADATA,
"description": "Sampling temperature",
},
"max_tokens": {
"type": ColumnType.INT,
"content_type": ContentType.METADATA,
"description": "Maximum tokens to generate",
},
}
def get_alpaca_schema() -> dict:
"""Get column configuration for alpaca format (instruction format).
Returns:
Dictionary mapping column names to their configuration.
"""
from oumi.core.analyze.column_types import ColumnType, ContentType
return {
"instruction": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Instruction or prompt text",
},
"input": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Input context or additional information",
},
"output": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Expected output or response",
},
# Common additional fields
"text": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Combined text content",
},
}
def get_prompt_response_schema() -> dict:
"""Get column configuration for prompt/response format.
Returns:
Dictionary mapping column names to their configuration.
"""
from oumi.core.analyze.column_types import ColumnType, ContentType
return {
"prompt": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Input prompt or question",
},
"response": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Model response or answer",
},
# Common variations
"instruction": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Instruction or prompt text",
},
"output": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Expected output or response",
},
}
def get_dpo_schema() -> dict:
"""Get column configuration for DPO (preference tuning) format.
Returns:
Dictionary mapping column names to their configuration.
"""
from oumi.core.analyze.column_types import ColumnType, ContentType
return {
"prompt": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Input prompt for preference comparison",
},
"chosen": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Preferred response",
},
"rejected": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Rejected response",
},
# Metadata fields
"score_chosen": {
"type": ColumnType.FLOAT,
"content_type": ContentType.NUMERIC,
"description": "Score for chosen response",
},
"score_rejected": {
"type": ColumnType.FLOAT,
"content_type": ContentType.NUMERIC,
"description": "Score for rejected response",
},
}
def get_pretraining_schema() -> dict:
"""Get column configuration for pretraining format.
Returns:
Dictionary mapping column names to their configuration.
"""
from oumi.core.analyze.column_types import ColumnType, ContentType
return {
"text": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Raw text content for pretraining",
},
"content": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Text content (alternative field name)",
},
# Metadata fields common in pretraining datasets
"url": {
"type": ColumnType.STRING,
"content_type": ContentType.METADATA,
"description": "Source URL",
},
"timestamp": {
"type": ColumnType.TIMESTAMP,
"content_type": ContentType.METADATA,
"description": "Content timestamp",
},
"source": {
"type": ColumnType.STRING,
"content_type": ContentType.CATEGORICAL,
"description": "Data source identifier",
},
"language": {
"type": ColumnType.STRING,
"content_type": ContentType.CATEGORICAL,
"description": "Content language",
},
"quality_score": {
"type": ColumnType.FLOAT,
"content_type": ContentType.NUMERIC,
"description": "Content quality score",
},
"length": {
"type": ColumnType.INT,
"content_type": ContentType.NUMERIC,
"description": "Content length in characters",
},
}
def get_kto_schema() -> dict:
"""Get column configuration for KTO (binary feedback) format.
Returns:
Dictionary mapping column names to their configuration.
"""
from oumi.core.analyze.column_types import ColumnType, ContentType
return {
"prompt": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Input prompt",
},
"completion": {
"type": ColumnType.STRING,
"content_type": ContentType.TEXT,
"description": "Model completion/response",
},
"label": {
"type": ColumnType.BOOL,
"content_type": ContentType.CATEGORICAL,
"description": "Binary feedback label (True=desirable, False=undesirable)",
},
# Additional fields
"score": {
"type": ColumnType.FLOAT,
"content_type": ContentType.NUMERIC,
"description": "Numeric score for the completion",
},
}