Source code for oumi.core.synthesis.attribute_synthesizer

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass, field

from oumi.builders.inference_engines import build_inference_engine
from oumi.core.configs.inference_config import InferenceConfig
from oumi.core.configs.inference_engine_type import InferenceEngineType
from oumi.core.configs.params.synthesis_params import (
    GeneralSynthesisParams,
    GeneratedAttribute,
    GeneratedAttributePostprocessingParams,
    TextMessage,
)
from oumi.core.inference.base_inference_engine import BatchResult
from oumi.core.synthesis.attribute_formatter import AttributeFormatter
from oumi.core.types.conversation import Conversation, Message
from oumi.inference.remote_inference_engine import BatchInfo
from oumi.utils.logging import logger


[docs] @dataclass class SynthBatchResult: """Result from partial batch synthesis, separating successes from failures.""" successful: list[tuple[int, dict[str, str]]] """List of (original_index, attribute_dict) for successfully synthesized items.""" failed_indices: list[int] """Indices of items that failed synthesis.""" error_messages: dict[int, str] = field(default_factory=dict) """Mapping of failed index to error message.""" @property def has_failures(self) -> bool: """Return True if any items in the batch failed synthesis.""" return len(self.failed_indices) > 0
[docs] class AttributeSynthesizer: """Synthesizes values for a generated attribute based on the given samples. Args: params: The parameters for the attribute synthesizer. inference_config: The configuration for the inference engine. """ def __init__( self, params: GeneralSynthesisParams, inference_config: InferenceConfig, ): """Initialize the synthesizer.""" self._params = params self._formatter = AttributeFormatter(params) self._inference_engine = build_inference_engine( engine_type=inference_config.engine or InferenceEngineType.NATIVE, model_params=inference_config.model, remote_params=inference_config.remote_params, ) self._inference_config = inference_config self._total_input_tokens: int = 0 self._total_output_tokens: int = 0 self._total_cached_tokens: int = 0
[docs] def synthesize( self, samples: list[dict], generated_attribute: GeneratedAttribute, ) -> list[dict[str, str]]: """Synthesize a value for the generated attribute. Order will be identical to the order of the samples. Args: samples: The samples to synthesize values for. generated_attribute: The generated attribute to synthesize a value for. Returns: A list of dictionaries, one for each sample, with the generated attribute value added to the dictionary. """ inference_conversations: list[Conversation] = [] for sample in samples: inference_conversations.append( self._format_instructions( sample, generated_attribute.instruction_messages, ) ) inference_results = self._inference_engine.infer( inference_conversations, inference_config=self._inference_config, ) self._accumulate_token_usage(inference_results) return self._process_inference_results(inference_results, generated_attribute)
[docs] def synthesize_batch( self, samples: list[dict], generated_attribute: GeneratedAttribute, ) -> str: """Submit a batch inference job for attribute synthesis. Args: samples: The samples to synthesize values for. generated_attribute: The generated attribute to synthesize a value for. Returns: The batch ID that can be used with get_batch_status and get_batch_results. Raises: NotImplementedError: If the inference engine does not support batch. """ if not hasattr(self._inference_engine, "infer_batch"): raise NotImplementedError( f"Inference engine {type(self._inference_engine).__name__} does not " "support batch inference. Use synthesize() instead." ) conversations = self._build_batch_conversations(samples, generated_attribute) batch_id = self._inference_engine.infer_batch( # type: ignore[attr-defined] conversations, inference_config=self._inference_config, ) return batch_id
[docs] def get_batch_status(self, batch_id: str) -> BatchInfo: """Get the status of a batch inference job. Args: batch_id: The batch ID returned from synthesize_batch(). Returns: BatchInfo containing the job status and progress information. Raises: NotImplementedError: If the inference engine does not support batch. """ if not hasattr(self._inference_engine, "get_batch_status"): raise NotImplementedError( f"Inference engine {type(self._inference_engine).__name__} does not " "support batch inference." ) return self._inference_engine.get_batch_status(batch_id) # type: ignore[attr-defined]
[docs] def get_batch_results( self, batch_id: str, samples: list[dict], generated_attribute: GeneratedAttribute, ) -> list[dict[str, str]]: """Get results from a completed batch inference job. Args: batch_id: The batch ID returned from synthesize_batch(). samples: The original samples that were submitted for synthesis. generated_attribute: The generated attribute configuration. Returns: A list of dictionaries, one for each sample, with the generated attribute value added to the dictionary (same format as synthesize()). Raises: NotImplementedError: If the inference engine does not support batch. RuntimeError: If any items failed synthesis. """ result = self.get_batch_results_partial(batch_id, samples, generated_attribute) if result.has_failures: first_idx = result.failed_indices[0] raise RuntimeError( f"Synthesis batch {batch_id} failed for " f"{len(result.failed_indices)} items. " f"First error (index {first_idx}): " f"{result.error_messages.get(first_idx, 'unknown')}" ) return [output for _, output in sorted(result.successful)]
[docs] def get_batch_results_partial( self, batch_id: str, samples: list[dict], generated_attribute: GeneratedAttribute, ) -> SynthBatchResult: """Get partial results from a completed batch inference job. This method returns successful results alongside failure information. Parse failures from _process_inference_results are also captured. Args: batch_id: The batch ID returned from synthesize_batch(). samples: The original samples that were submitted for synthesis. generated_attribute: The generated attribute configuration. Returns: SynthBatchResult with successful outputs and failure info. Raises: NotImplementedError: If the inference engine does not support batch. """ conversations = self._build_batch_conversations(samples, generated_attribute) logger.info( f"Retrieving partial synthesis results for batch {batch_id} " f"({len(conversations)} conversations)" ) batch_result: BatchResult = self._inference_engine.get_batch_results_partial( batch_id, conversations ) successful_outputs: list[tuple[int, dict[str, str]]] = [] failed_indices: list[int] = list(batch_result.failed_indices) error_messages: dict[int, str] = dict(batch_result.error_messages) parse_failures = 0 for idx, conv in batch_result.successful: try: processed = self._process_inference_results([conv], generated_attribute) successful_outputs.append((idx, processed[0])) self._accumulate_token_usage([conv]) except Exception as e: parse_failures += 1 failed_indices.append(idx) error_messages[idx] = f"Failed to process synthesis output: {e}" logger.warning( f"Batch {batch_id} request {idx}: " f"failed to process synthesis output: {e}" ) logger.info( f"Batch {batch_id} synthesis results: " f"{len(successful_outputs)} processed successfully, " f"{len(batch_result.failed_indices)} inference failures, " f"{parse_failures} parse failures" ) return SynthBatchResult( successful=successful_outputs, failed_indices=failed_indices, error_messages=error_messages, )
def _build_batch_conversations( self, samples: list[dict], generated_attribute: GeneratedAttribute, ) -> list[Conversation]: """Build inference conversations from samples for batch operations.""" return [ self._format_instructions(sample, generated_attribute.instruction_messages) for sample in samples ] @property def total_input_tokens(self) -> int: """Total input/prompt tokens accumulated across all synthesize() calls.""" return self._total_input_tokens @property def total_output_tokens(self) -> int: """Total output/completion tokens accumulated across all synthesize() calls.""" return self._total_output_tokens @property def total_cached_tokens(self) -> int: """Total cached tokens accumulated across all synthesize() calls.""" return self._total_cached_tokens def _accumulate_token_usage(self, inference_results: list[Conversation]) -> None: """Accumulate token usage from inference response metadata.""" for result in inference_results: usage = result.metadata.get("usage", {}) self._total_input_tokens += usage.get("prompt_tokens", 0) self._total_output_tokens += usage.get("completion_tokens", 0) self._total_cached_tokens += usage.get("cached_tokens", 0) def _extract_response( self, inference_conversations: list[Conversation], ) -> list[str]: """Get the inference results from the inference conversations. If the inference result is not a string, an empty string will be returned. """ return [ inference_result.messages[-1].content if isinstance(inference_result.messages[-1].content, str) else "" for inference_result in inference_conversations ] def _format_instructions( self, sample: dict, instruction_messages: list[TextMessage], ) -> Conversation: """Format the instructions for the sample.""" new_messages = [] for turn in instruction_messages: if not isinstance(turn.content, str): new_messages.append(turn) continue formatted_content = self._formatter.format( sample, turn.content, missing_values_allowed=False, ) new_message = Message( role=turn.role, content=formatted_content, ) new_messages.append(new_message) return Conversation(messages=new_messages) def _process_inference_results( self, inference_results: list[Conversation], generated_attribute: GeneratedAttribute, ) -> list[dict[str, str]]: """Extract and postprocess inference results. Args: inference_results: The inference results from the inference engine. generated_attribute: The generated attribute configuration. Returns: A list of dictionaries with the processed attribute values. """ original_responses = self._extract_response(inference_results) if not generated_attribute.postprocessing_params: return [ {generated_attribute.id: response} for response in original_responses ] keep_original = ( generated_attribute.postprocessing_params.keep_original_text_attribute ) if keep_original: records = [ {generated_attribute.id: response} for response in original_responses ] else: records = [{} for _ in original_responses] for i, original_response in enumerate(original_responses): new_id = generated_attribute.postprocessing_params.id new_response = original_response try: new_response = self._postprocess_sample( original_response, generated_attribute.postprocessing_params ) except ValueError as e: logger.warning( f"Error postprocessing inference result: {e}. Leaving as-is and " "skipping." ) finally: records[i][new_id] = new_response return records def _postprocess_sample( self, response: str, postprocessing_params: GeneratedAttributePostprocessingParams, ) -> str: """Postprocess the response, removing extraneous text. Order of operations: 1. If regex is provided, use the first match. 2. Cut off everything before the first occurrence of the prefix and after the last occurrence of the suffix. 3. Strip whitespace. 4. Add prefix and suffix to what remains. Args: response: The response to postprocess. postprocessing_params: The postprocessing parameters. Returns: The postprocessed response. """ if postprocessing_params.regex: match = re.search(postprocessing_params.regex, response) if match: response = match.group(0) # Cut off prefix and suffix if postprocessing_params.cut_prefix: prefix_loc = response.find(postprocessing_params.cut_prefix) if prefix_loc != -1: response = response[ prefix_loc + len(postprocessing_params.cut_prefix) : ] if postprocessing_params.cut_suffix: suffix_loc = response.rfind(postprocessing_params.cut_suffix) if suffix_loc != -1: response = response[:suffix_loc] # Strip whitespace if postprocessing_params.strip_whitespace: response = response.strip() # Add prefix and suffix if postprocessing_params.added_prefix: response = postprocessing_params.added_prefix + response if postprocessing_params.added_suffix: response = response + postprocessing_params.added_suffix return response