# Source code for oumi.core.synthesis.conversation_synthesizer

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

from oumi.builders.inference_engines import build_inference_engine
from oumi.core.configs.inference_config import InferenceConfig
from oumi.core.configs.inference_engine_type import InferenceEngineType
from oumi.core.configs.params.synthesis_params import (
    GeneralSynthesisParams,
    MultiTurnAttribute,
)
from oumi.core.synthesis.attribute_formatter import AttributeFormatter
from oumi.core.types.conversation import Conversation, Message, Role
from oumi.utils.logging import logger
from oumi.utils.str_utils import extract_json


class ConversationSynthesizer:
    """Synthesizes a conversation.

    Args:
        params: The parameters for the conversation synthesizer.
        inference_config: The configuration for the inference engine.
    """

    def __init__(
        self,
        params: GeneralSynthesisParams,
        inference_config: InferenceConfig,
    ):
        """Initialize the synthesizer."""
        self._params = params
        self._formatter = AttributeFormatter(params)
        # Fall back to the native engine when no engine type is configured.
        self._inference_engine = build_inference_engine(
            engine_type=inference_config.engine or InferenceEngineType.NATIVE,
            model_params=inference_config.model,
            remote_params=inference_config.remote_params,
        )
        self._inference_config = inference_config
        # Conversations alternate USER -> ASSISTANT, repeating as needed.
        self._default_turn_order = [Role.USER, Role.ASSISTANT]

    def _validate_roles(self, multiturn_attribute: MultiTurnAttribute) -> None:
        """Validate that required roles have corresponding personas.

        Args:
            multiturn_attribute: The multi-turn attribute to validate.

        Raises:
            ValueError: If a required role is missing from
                role_instruction_messages.
        """
        available_roles = set(multiturn_attribute.role_instruction_messages.keys())
        for role in self._default_turn_order:
            if role not in available_roles:
                # Sort the available role names so the error message is
                # deterministic (iterating a set is not).
                raise ValueError(
                    f"Role '{role.value}' is missing from "
                    f"role_instruction_messages. Available roles: "
                    f"{sorted(r.value for r in available_roles)}"
                )
[docs] def synthesize( self, samples: list[dict], multiturn_attributes: MultiTurnAttribute, ) -> list[dict[str, dict | str] | None]: """Synthesize a multi-turn conversation. Order will be identical to the order of the samples. Args: samples: The samples to synthesize values for. multiturn_attributes: The multi-turn attribute defining conversation rules. Returns: A list aligned to the input samples. Each entry is either: - a dictionary containing the conversation and plan, or - None when the synthesized conversation is filtered out. """ if not samples: return [] self._validate_roles(multiturn_attributes) logger.info( f"Synthesizing {len(samples)} conversations for " f"attribute '{multiturn_attributes.id}'" ) samples = self._plan_samples(samples, multiturn_attributes) conversations = self._synthesize_all_samples(samples, multiturn_attributes) records: list[dict[str, dict | str] | None] = [] plan_key = f"{multiturn_attributes.id}_plan" filtered_count = 0 for sample, conversation in zip(samples, conversations): if self._has_empty_messages(conversation): filtered_count += 1 records.append(None) continue record: dict[str, dict | str] = { multiturn_attributes.id: conversation.to_dict(), plan_key: sample["conversation_plan"], } records.append(record) if filtered_count > 0: logger.warning( f"Filtered out {filtered_count} conversation(s) with empty messages " f"out of {len(conversations)} total" ) return records
    def _plan_samples(
        self,
        samples: list[dict],
        multiturn_attributes: MultiTurnAttribute,
        max_retries: int = 2,
    ) -> list[dict]:
        """Plan the conversation samples with retry logic for failed parses.

        Args:
            samples: The conversation samples to plan.
            multiturn_attributes: The multi-turn attribute defining
                conversation rules.
            max_retries: Maximum number of retry attempts for failed plan
                parsing.

        Returns:
            A list of sample dicts augmented with runtime fields (target_turns,
            conversation_plan, parsed_turn_plans).
        """
        turn_order = self._default_turn_order
        augmented_samples: list[dict] = []
        for sample in samples:
            target_turns = self._select_target_turns(multiturn_attributes, turn_order)
            # Pre-populate the runtime fields with empty defaults; they are
            # overwritten below once a plan parses successfully.
            augmented_sample = {
                **sample,
                "target_turns": target_turns,
                "conversation_plan": "",
                "parsed_turn_plans": [""] * target_turns,
            }
            augmented_samples.append(augmented_sample)
            logger.debug(f"Planning conversation with {target_turns} turns")

        # Retry loop: only the samples whose plans failed to parse are
        # re-submitted on subsequent attempts.
        indices_to_process = list(range(len(augmented_samples)))
        for attempt in range(max_retries + 1):
            if not indices_to_process:
                break
            planner_conversations = [
                self._create_planner_prompt(
                    multiturn_attributes,
                    augmented_samples[i],
                )
                for i in indices_to_process
            ]
            plans = self._generate_plan(planner_conversations)

            failed_indices: list[int] = []
            for idx, plan in zip(indices_to_process, plans):
                augmented_sample = augmented_samples[idx]
                target_turns = augmented_sample["target_turns"]
                parsed = self._parse_plan(plan, target_turns)
                if parsed is not None:
                    augmented_sample["conversation_plan"] = plan
                    augmented_sample["parsed_turn_plans"] = parsed
                else:
                    failed_indices.append(idx)
                    if attempt < max_retries:
                        logger.warning(
                            f"Plan parsing failed for sample {idx}, "
                            f"retrying ({attempt + 1}/{max_retries})"
                        )
            indices_to_process = failed_indices

        if indices_to_process:
            # Samples that never produced a parsable plan keep their empty
            # defaults and proceed without per-turn instructions.
            logger.warning(
                f"Failed to parse plans for {len(indices_to_process)} samples "
                f"after {max_retries + 1} attempts, proceeding without plan"
            )
        return augmented_samples

    def _parse_plan(self, plan: str, target_turns: int) -> list[str] | None:
        """Parse a JSON-formatted conversation plan.

        Extracts turn instructions from JSON array. Expects format:
        [{"turn": 1, "instruction": "..."}, ...]

        Args:
            plan: The full plan text from the planner.
            target_turns: Expected number of turns.

        Returns:
            List of instruction strings (one per turn), or None if parsing
            failed.
        """
        if not plan:
            return None

        turns = extract_json(plan, expected_type=list)
        if turns is None:
            # Some models emit a single JSON object instead of an array;
            # treat that as a one-element plan.
            single = extract_json(plan, expected_type=dict)
            if single is not None:
                turns = [single]
            else:
                return None

        result = [""] * target_turns
        for turn in turns:
            if not isinstance(turn, dict):
                continue
            turn_num = turn.get("turn")
            instruction = turn.get("instruction", "")
            if isinstance(turn_num, str):
                # Tolerate "turn" values emitted as strings, e.g. "2".
                try:
                    turn_num = int(turn_num)
                except ValueError:
                    continue
            # Only accept 1-based turn numbers within the expected range;
            # out-of-range or malformed entries are silently skipped.
            if isinstance(turn_num, int) and 1 <= turn_num <= target_turns:
                result[turn_num - 1] = str(instruction).strip()
        return result

    def _extract_response(
        self,
        inference_conversations: list[Conversation],
    ) -> list[str]:
        """Get the inference results from the inference conversations.

        If the inference result is not a string or the conversation is empty,
        an empty string will be returned. Strips whitespace to avoid API errors
        with trailing whitespace.
        """
        results = []
        for inference_result in inference_conversations:
            if not inference_result.messages:
                results.append("")
                continue
            # The model's reply is the last message in the conversation.
            content = inference_result.messages[-1].content
            if isinstance(content, str):
                results.append(content.strip())
            else:
                results.append("")
        return results

    def _has_empty_messages(self, conversation: Conversation) -> bool:
        """Check if any non-system message in a conversation has empty content.

        System messages (e.g., output_system_prompt) are excluded from this
        check since they are generated by the synthesizer itself, not by
        inference.

        Args:
            conversation: The conversation to check.

        Returns:
            True if any non-system message has empty string content.
        """
        for message in conversation.messages:
            if message.role == Role.SYSTEM:
                continue
            # Non-string content (e.g. multimodal payloads) also counts as
            # "empty" for filtering purposes here.
            if not isinstance(message.content, str) or not message.content.strip():
                return True
        return False

    def _format_persona(self, sample: dict, persona: str, role: Role) -> Message:
        """Format the persona for the sample.

        Args:
            sample: The sample dict containing all attributes.
            persona: The persona string to format.
            role: The role for this persona.

        Returns:
            A Message with the formatted persona as a SYSTEM message.
        """
        # NOTE(review): `role` is not used in the body — the persona is always
        # emitted as a SYSTEM message regardless of role; confirm intent.
        formatted_content = self._formatter.format(
            sample,
            persona,
            missing_values_allowed=False,
        )
        return Message(
            role=Role.SYSTEM,
            content=formatted_content,
        )

    def _build_role_context(
        self, sample: dict, multiturn_attribute: MultiTurnAttribute
    ) -> str:
        """Build formatted role context for the planner.

        Formats the persona strings for each role. The returned string has
        curly braces escaped ({{ and }}) so it can be safely embedded in
        another template without causing format errors.
        """
        parts = []
        for role, persona in multiturn_attribute.role_instruction_messages.items():
            formatted = self._formatter.format(
                sample, persona, missing_values_allowed=False
            )
            parts.append(f"[{role.value.upper()}]\n{formatted}")
        result = "\n\n".join(parts)
        # Escape braces so a later str.format-style pass does not treat
        # persona text as placeholders.
        return result.replace("{", "{{").replace("}", "}}")

    def _build_turn_order_str(self, turn_order: list[Role], target_turns: int) -> str:
        """Build a string showing which role speaks at each turn.

        Args:
            turn_order: The role sequence that repeats.
            target_turns: Total number of turns.

        Returns:
            A string like "Turn 1: USER, Turn 2: ASSISTANT, Turn 3: USER, ..."
        """
        parts = []
        for i in range(target_turns):
            # The role sequence repeats cyclically over the turns.
            role = turn_order[i % len(turn_order)]
            parts.append(f"Turn {i + 1}: {role.value.upper()}")
        return ", ".join(parts)

    def _create_planner_prompt(
        self, multiturn_attribute: MultiTurnAttribute, sample: dict
    ) -> Conversation:
        """Create the planner prompt template with role context and turn order.

        Returns a Conversation with a one-shot example for consistent
        formatting. The prompt instructs the model to output JSON wrapped in
        code fences.
        """
        role_context = self._build_role_context(sample, multiturn_attribute)
        turn_order = self._default_turn_order
        target_turns = sample["target_turns"]
        turn_order_str = self._build_turn_order_str(turn_order, target_turns)

        system_prompt = (
            "You are a conversation planner. Create conversation outlines "
            "that flow logically from start to finish.\n\n"
            "IMPORTANT: Output your plan as a JSON array wrapped in ```json code "
            "fences. Each element must have: turn (number) and instruction (string).\n"
            "Your instructions MUST be specific to the role context provided. "
            "Each turn's instruction should reflect what that specific role "
            "would do at that point in the conversation."
        )

        # One-shot example request/response pair anchors the model's output
        # format.
        example_request = (
            "Plan a 4-turn conversation.\n"
            "Turn order: Turn 1: USER, Turn 2: ASSISTANT, Turn 3: USER, "
            "Turn 4: ASSISTANT\n\n"
            "Role context:\n"
            "[USER]\n"
            "You are a customer who has an issue with a recent order.\n\n"
            "[ASSISTANT]\n"
            "You are a helpful support agent who resolves customer issues.\n\n"
            "Additional instructions: Focus on resolving the order issue "
            "efficiently while maintaining a polite and helpful tone."
        )

        example_response = """```json
[
    {"turn": 1, "instruction": "Greet support and explain the issue with the order"},
    {"turn": 2, "instruction": "Acknowledge the issue and ask for order details"},
    {"turn": 3, "instruction": "Provide order number and describe the problem further"},
    {"turn": 4, "instruction": "Confirm the issue and offer a resolution"}
]
```"""

        base_prompt = (
            f"Plan a {target_turns}-turn conversation.\n"
            f"Turn order: {turn_order_str}\n\n"
            "Guidelines:\n"
            "- Each turn should build on the previous turn.\n"
            f"- Pace the conversation naturally for {target_turns} turns.\n"
            "- Focus on what happens, not exact wording.\n"
            "- Instructions MUST be specific to the roles and context provided below.\n"
        )
        if role_context:
            base_prompt += f"\nRole context:\n{role_context}\n"
        if multiturn_attribute.conversation_planner:
            # Optional, user-supplied planner instructions are formatted with
            # the sample's attributes before being appended.
            formatted_planner = self._formatter.format(
                sample,
                multiturn_attribute.conversation_planner,
                missing_values_allowed=False,
            )
            base_prompt += f"\nAdditional instructions: {formatted_planner}\n"
        base_prompt += (
            "\nOutput ONLY the JSON array wrapped in ```json code fences. "
            "No other text."
        )
        return Conversation(
            messages=[
                Message(role=Role.SYSTEM, content=system_prompt),
                Message(role=Role.USER, content=example_request),
                Message(role=Role.ASSISTANT, content=example_response),
                Message(role=Role.USER, content=base_prompt),
            ],
        )

    def _generate_plan(self, planners: list[Conversation]) -> list[str]:
        """Generate plans for how the conversations should proceed.

        Args:
            planners: The planner conversation templates (already formatted).

        Returns:
            A list of plan strings, one per sample.
        """
        inference_results = self._inference_engine.infer(
            planners,
            inference_config=self._inference_config,
        )
        return self._extract_response(inference_results)

    def _synthesize_all_samples(
        self,
        samples: list[dict],
        multiturn_attribute: MultiTurnAttribute,
    ) -> list[Conversation]:
        """Synthesize multi-turn conversations for all samples with batched inference.

        Args:
            samples: List of sample dicts with runtime fields (target_turns,
                conversation_plan).
            multiturn_attribute: The multi-turn attribute defining conversation
                rules.

        Returns:
            List of Conversation objects, one per sample.
        """
        if not samples:
            return []

        # One running message history per sample, grown turn by turn so all
        # samples' turns can be batched into a single infer() call.
        histories: list[list[Message]] = [[] for _ in samples]
        max_turns = max(sample["target_turns"] for sample in samples)

        for turn_idx in range(max_turns):
            current_turn = turn_idx + 1
            prompts: list[Conversation] = []
            sample_indices: list[int] = []
            roles_for_turn: list[Role] = []
            for i, sample in enumerate(samples):
                # Skip samples whose conversation is already complete.
                if turn_idx >= sample["target_turns"]:
                    continue
                turn_order = self._default_turn_order
                role = turn_order[turn_idx % len(turn_order)]
                roles_for_turn.append(role)

                prompt_messages: list[Message] = []
                sample_with_turn = {**sample, "current_turn": current_turn}
                persona = multiturn_attribute.role_instruction_messages[role]
                formatted_persona = self._format_persona(
                    sample_with_turn, persona, role
                )
                prompt_messages.append(formatted_persona)
                prompt_messages.extend(histories[i])

                target_turns = sample["target_turns"]
                parsed_turn_plans = sample.get("parsed_turn_plans", [])
                turn_instruction = ""
                if turn_idx < len(parsed_turn_plans):
                    turn_instruction = parsed_turn_plans[turn_idx]
                turn_info = (
                    f"You are generating turn {current_turn} of {target_turns} "
                    f"as the {role.value.upper()}.\n\n"
                )
                if turn_instruction:
                    turn_info += f"For this turn: {turn_instruction}\n\n"
                turn_info += (
                    "Generate ONLY your response for this turn. Stay in character."
                )
                prompt_messages.append(Message(role=Role.USER, content=turn_info))
                prompts.append(Conversation(messages=prompt_messages))
                sample_indices.append(i)

            if not prompts:
                break
            inference_results = self._inference_engine.infer(
                prompts,
                inference_config=self._inference_config,
            )
            generated_texts = self._extract_response(inference_results)
            # Guard against engines that drop or duplicate results; histories
            # would otherwise be silently misaligned across samples.
            if len(generated_texts) != len(prompts):
                raise RuntimeError(
                    f"Inference engine returned {len(generated_texts)} results "
                    f"but {len(prompts)} prompts were submitted. "
                    f"This may indicate an inference engine error."
                )
            for idx, generated_text, role in zip(
                sample_indices, generated_texts, roles_for_turn
            ):
                histories[idx].append(Message(role=role, content=generated_text))

        # Assemble the final conversations, prepending the optional output
        # system prompt when configured.
        conversations: list[Conversation] = []
        for sample, history in zip(samples, histories):
            output_messages: list[Message] = []
            output_message = self._format_output_system_message(
                sample, multiturn_attribute.output_system_prompt
            )
            if output_message:
                output_messages.append(output_message)
            output_messages.extend(history)
            conversations.append(Conversation(messages=output_messages))
        return conversations

    def _select_target_turns(
        self, multiturn_attribute: MultiTurnAttribute, turn_order: list[Role]
    ) -> int:
        """Pick a turn count, preferring one that ends on an ASSISTANT turn.

        A random count is drawn uniformly from [min_turns, max_turns]; if the
        final speaker would not be the assistant, the count is nudged upward
        (preferred) or downward within bounds so the conversation ends on an
        assistant turn. If no such count exists, the original draw is kept.
        """
        min_turns = multiturn_attribute.min_turns
        max_turns = multiturn_attribute.max_turns
        target_turns = random.randint(min_turns, max_turns)
        if Role.ASSISTANT not in turn_order:
            return target_turns

        def role_at(turn_count: int) -> Role:
            # Role that speaks on the final (1-based) turn of `turn_count`.
            return turn_order[(turn_count - 1) % len(turn_order)]

        if role_at(target_turns) == Role.ASSISTANT:
            return target_turns
        # Prefer extending to the next assistant-final count...
        for turn_count in range(target_turns + 1, max_turns + 1):
            if role_at(turn_count) == Role.ASSISTANT:
                return turn_count
        # ...otherwise shrink toward min_turns.
        for turn_count in range(target_turns - 1, min_turns - 1, -1):
            if role_at(turn_count) == Role.ASSISTANT:
                return turn_count
        return target_turns

    def _format_output_system_message(
        self,
        sample: dict,
        system_message: str | None,
    ) -> Message | None:
        """Format the optional output system prompt into a SYSTEM message.

        Returns:
            The formatted SYSTEM Message, or None when no system message is
            configured.
        """
        if system_message is None:
            return None
        formatted_content = self._formatter.format(
            sample,
            system_message,
        )
        return Message(role=Role.SYSTEM, content=formatted_content.strip())