Source code for oumi.analyze.analyzers.quality

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data quality analyzer implementation."""

import re

from pydantic import BaseModel, Field

from oumi.analyze.base import ConversationAnalyzer
from oumi.core.registry import register_sample_analyzer
from oumi.core.types.conversation import Conversation, Message

__all__ = ["DataQualityMetrics", "DataQualityAnalyzer"]

# Invalid serialization patterns: (regex, display_name)
_INVALID_VALUE_PATTERNS = [
    (re.compile(r"\bNaN\b"), "NaN"),
    (re.compile(r"\bnan\b"), "nan"),
    (re.compile(r"\bnull\b"), "null"),
    (re.compile(r"\bNone\b"), "None"),
    (re.compile(r"\bundefined\b"), "undefined"),
]


[docs] class DataQualityMetrics(BaseModel): """Result model for data quality checks on a conversation. Example: >>> result = DataQualityMetrics( ... has_non_alternating_turns=False, ... has_empty_turns=False, ... empty_turn_count=0, ... has_invalid_values=False, ... invalid_value_patterns=[], ... ) >>> print(result.has_non_alternating_turns) False """ has_non_alternating_turns: bool = Field( description=( "True if non-system messages do NOT strictly alternate between " "user and assistant roles (i.e. consecutive same-role messages exist)" ) ) has_empty_turns: bool = Field( description="True if any message has empty or whitespace-only content" ) empty_turn_count: int = Field( description="Number of messages with empty or whitespace-only content" ) has_invalid_values: bool = Field( description=( "True if any message contains values serialized as strings " "(e.g. 'NaN', 'null', 'None', 'undefined')" ) ) invalid_value_patterns: list[str] = Field( description="List of invalid value patterns found across all messages" )
[docs] @register_sample_analyzer("quality") class DataQualityAnalyzer(ConversationAnalyzer[DataQualityMetrics]): """Analyzer for basic data quality checks on conversations. Checks for three common data quality issues without requiring an LLM: - Non-alternating user/assistant message patterns - Empty or whitespace-only turns - Values serialized as strings (NaN, null, None, undefined) Example: >>> from oumi.analyze.analyzers.quality import DataQualityAnalyzer >>> from oumi.core.types.conversation import Conversation, Message, Role >>> >>> analyzer = DataQualityAnalyzer() >>> conversation = Conversation(messages=[ ... Message(role=Role.USER, content="Hello"), ... Message(role=Role.ASSISTANT, content="Hi there!"), ... ]) >>> result = analyzer.analyze(conversation) >>> print(result.has_non_alternating_turns) False """ _result_model = DataQualityMetrics
[docs] @classmethod def get_config_schema(cls) -> dict: """Get JSON schema for DataQualityAnalyzer configuration.""" return {"properties": {}}
[docs] def analyze(self, conversation: Conversation) -> DataQualityMetrics: """Analyze data quality for a conversation. Args: conversation: The conversation to analyze. Returns: DataQualityMetrics with the quality check results. """ # 1. Non-alternating turns check (ignoring system messages) roles = [m.role.value for m in conversation.messages] non_system = [r for r in roles if r != "system"] has_non_alternating = False for i in range(1, len(non_system)): if non_system[i] == non_system[i - 1]: has_non_alternating = True break # 2. Empty turns check def _text(m: Message) -> str: return DataQualityAnalyzer.get_text_content(m) empty_count = sum(1 for m in conversation.messages if not _text(m).strip()) # 3. Invalid serialized values check patterns_found: set[str] = set() for message in conversation.messages: content = _text(message) for pattern, name in _INVALID_VALUE_PATTERNS: if pattern.search(content): patterns_found.add(name) return DataQualityMetrics( has_non_alternating_turns=has_non_alternating, has_empty_turns=empty_count > 0, empty_turn_count=empty_count, has_invalid_values=len(patterns_found) > 0, invalid_value_patterns=sorted(patterns_found), )