Source code for oumi.environments.synthetic_environment

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Synthetic environment backed by LLM-simulated or Python-executed tools.

Stateless mode (``state_params=None``) batches LLM-simulated tool outputs
per tool id, cached by ``(tool_id, args)``. Individual tools may still
opt into Python execution by setting ``executor`` -- LLM simulation is
the fallback for tools without one. Stateful mode (``state_params`` is
set) requires every tool to define ``executor``; the env runs them
sequentially so state mutations thread through the batch.
"""

from __future__ import annotations

import copy
import dataclasses
import importlib
import json
import random
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import jsonschema

from oumi.core.configs.inference_config import InferenceConfig
from oumi.core.configs.params.base_params import BaseParams
from oumi.core.configs.params.environment_params import EnvironmentParams
from oumi.core.configs.params.grounding_params import (
    GroundingFact,
    StateGroundingConfig,
)
from oumi.core.configs.params.guided_decoding_params import GuidedDecodingParams
from oumi.core.configs.params.tool_params import ToolError, ToolParams
from oumi.core.registry import register_environment
from oumi.core.types.conversation import Conversation, Message, Role
from oumi.core.types.tool_call import ToolResult
from oumi.environments.base_environment import BaseEnvironment
from oumi.utils.str_utils import extract_json

if TYPE_CHECKING:
    from oumi.core.inference.base_inference_engine import BaseInferenceEngine


[docs] @dataclass class SyntheticStateParams(BaseParams): """Optional state configuration for a synthetic environment. State grounding for these pools is declared at the env level via ``EnvironmentParams.grounding.state`` — each entry's ``state_path`` must resolve to a ``list[dict]`` in ``initial_state``. """ state_schema: dict[str, Any] | None = None initial_state: dict[str, Any] | None = None
[docs] def __post_init__(self): """Validate state config consistency.""" if self.state_schema is not None and self.initial_state is not None: jsonschema.validate(self.initial_state, self.state_schema)
[docs] @dataclass class SyntheticEnvironmentKwargs(BaseParams): """Type-specific kwargs for SyntheticEnvironment.""" system_prompt: str = "" state_params: SyntheticStateParams | None = None cache_by_input: bool = True
[docs] def __post_init__(self) -> None: """Coerce state_params dict into SyntheticStateParams if needed.""" if isinstance(self.state_params, dict): self.state_params = SyntheticStateParams(**self.state_params)
[docs] def __finalize_and_validate__(self) -> None: """Finalize and validate the kwargs.""" if not self.system_prompt: raise ValueError( "SyntheticEnvironmentKwargs.system_prompt cannot be empty." ) if self.state_params is not None and self.cache_by_input: raise ValueError( "SyntheticEnvironmentKwargs.cache_by_input must be False when " "state_params is provided." )
def _import_executor(dotted: str, tool_id: str) -> Callable[..., Any]: """Resolve a dotted import path to a callable. Raises ValueError on failure.""" module_path, _, attr = dotted.rpartition(".") if not module_path or not attr: raise ValueError( f"Tool '{tool_id}': executor '{dotted}' must be a dotted import " f"path (e.g. 'pkg.module.fn')." ) try: module = importlib.import_module(module_path) except ImportError as e: raise ValueError( f"Tool '{tool_id}': cannot import executor module '{module_path}': {e}" ) from e executor = getattr(module, attr, None) if executor is None: raise ValueError( f"Tool '{tool_id}': module '{module_path}' has no attribute '{attr}'." ) if not callable(executor): raise ValueError( f"Tool '{tool_id}': executor '{dotted}' resolved to a non-callable." ) return executor
[docs] @register_environment("synthetic") class SyntheticEnvironment(BaseEnvironment): """LLM-simulated environment with optional mutable state. See the module docstring for the stateless vs stateful contract. """ def __init__( self, params: EnvironmentParams, kwargs: SyntheticEnvironmentKwargs, ) -> None: """Initialize a SyntheticEnvironment with the given params and kwargs.""" self._params = params self._kwargs = kwargs self._cache: dict[str, ToolResult] = {} self._state: dict[str, Any] | None = ( copy.deepcopy(kwargs.state_params.initial_state) if kwargs.state_params is not None and kwargs.state_params.initial_state is not None else None ) self._state_schema: dict[str, Any] | None = ( kwargs.state_params.state_schema if kwargs.state_params is not None else None ) self._state_grounding: list[StateGroundingConfig] = ( list(params.grounding.state) if params.grounding is not None else [] ) if self._state is None and self._state_grounding: raise ValueError( f"SyntheticEnvironment '{params.id}': grounding.state is " f"configured but the env has no state (state_params with " f"initial_state is required)." ) self._executors: dict[str, Callable[..., Any]] = { tool.id: _import_executor(tool.executor, tool.id) for tool in params.tools if tool.executor } if self._state is not None: missing = [t.id for t in params.tools if not t.executor] if missing: raise ValueError( "SyntheticEnvironment in stateful mode (state_params with " "initial_state set) requires every tool to define an executor; " "LLM-simulated tools cannot mutate state. Missing executor: " f"{missing}" ) if self._state is not None: self._validate_state_grounding() self._engine: BaseInferenceEngine | None = None self._base_inference_config: InferenceConfig | None = None
[docs] def attach_inference( self, engine: BaseInferenceEngine, base_config: InferenceConfig, ) -> None: """Inject the orchestrator's inference engine + base config.""" self._engine = engine self._base_inference_config = base_config
[docs] def requires_isolation(self) -> bool: """Stateful synth envs need per-sample isolation; stateless do not.""" return self._state is not None
[docs] @classmethod def from_params(cls, params: EnvironmentParams) -> SyntheticEnvironment: """Build a SyntheticEnvironment from its params object.""" kwargs = SyntheticEnvironmentKwargs(**(params.env_kwargs or {})) kwargs.finalize_and_validate() return cls(params, kwargs)
@property def current_state(self) -> dict[str, Any] | None: """Return the current in-memory state snapshot.""" if self._state is None: return None return copy.deepcopy(self._state) @staticmethod def _cache_key(tool_id: str, arguments: dict[str, Any]) -> str: """Build a stable cache key from tool id and arguments.""" return f"{tool_id}::{json.dumps(arguments, sort_keys=True)}" def _resolve_cached( self, tool_id: str, arguments: dict[str, Any] ) -> ToolResult | None: """Look up a cached result for the given tool call.""" if not self._kwargs.cache_by_input: return None result = self._cache.get(self._cache_key(tool_id, arguments)) if result is None: return None return ToolResult( output=copy.deepcopy(result.output), updated_state=copy.deepcopy(result.updated_state), ) def _cache_result( self, tool_id: str, arguments: dict[str, Any], result: ToolResult ) -> None: """Store a generated result in the cache.""" if not self._kwargs.cache_by_input: return self._cache[self._cache_key(tool_id, arguments)] = ToolResult( output=copy.deepcopy(result.output), updated_state=copy.deepcopy(result.updated_state), ) def _lookup_tool(self, tool_id: str) -> ToolParams: for tool in self._params.tools: if tool.id == tool_id: return tool raise ValueError( f"Tool '{tool_id}' not found in environment '{self._params.id}'. " f"Available tools: {[tool.id for tool in self._params.tools]}" )
[docs] def step(self, calls: list[tuple[str, dict[str, Any]]]) -> list[ToolResult]: """Execute tool calls. See module docstring for routing rules. Raises: ValueError: If any tool id is unknown. RuntimeError: If an LLM-simulated tool is invoked before ``attach_inference`` was called. ToolError: On simulator parse failure or schema mismatch. """ if not calls: return [] for tool_id, _ in calls: self._lookup_tool(tool_id) stateful = self._state is not None results: list[ToolResult | None] = [None] * len(calls) sim_misses: list[tuple[int, str, dict[str, Any]]] = [] for i, (tool_id, args) in enumerate(calls): if tool_id in self._executors: if stateful: results[i] = self._step_stateful_one(tool_id, args) else: results[i] = self._step_executor_one(tool_id, args) continue cached = self._resolve_cached(tool_id, args) if cached is not None: results[i] = cached else: sim_misses.append((i, tool_id, args)) if sim_misses: if self._engine is None or self._base_inference_config is None: raise RuntimeError( "SyntheticEnvironment.step called before " "attach_inference(). Wire the synthesizer's engine via " "attach_inference(engine, base_config) before invoking " "step()." ) groups: dict[str, list[tuple[int, dict[str, Any]]]] = {} for i, tool_id, args in sim_misses: groups.setdefault(tool_id, []).append((i, args)) for tool_id, group in groups.items(): tool = self._lookup_tool(tool_id) convs = [self._build_call_conv(tool, args) for _, args in group] inferred = self._engine.infer( convs, self._simulator_inference_config(tool) ) if len(inferred) != len(group): raise RuntimeError( f"Simulator returned {len(inferred)} responses for " f"{len(group)} calls to '{tool_id}'." ) for (idx, args), conv in zip(group, inferred): raw = self._extract_text(conv) result = self._parse_and_validate(raw, tool) self._cache_result(tool_id, args, result) results[idx] = result assert all(r is not None for r in results), ( "every call must produce a ToolResult" ) return results # type: ignore[return-value]
def _step_executor_one(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult: """Execute a stateless tool via its executor callable.""" tool = self._lookup_tool(tool_id) tool.validate_arguments(arguments) result = self._executors[tool_id](arguments=arguments) self._validate_executor_output(tool, result) if result.updated_state is not None: raise ToolError( f"Tool '{tool.id}' executor returned updated_state but the " f"environment is stateless." ) return result def _step_stateful_one(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult: """Dispatch a stateful tool and commit ``updated_state`` after validation. ``state_in`` is a deep copy so the executor's reference to ``state`` can't reach back into ``self._state``: executors that mutate ``state`` in place (or hand the same dict back as ``updated_state``) end up touching the copy, and ``self._state`` is only reassigned via the explicit deepcopy of ``result.updated_state`` below. """ assert self._state is not None tool = self._lookup_tool(tool_id) tool.validate_arguments(arguments) state_in = copy.deepcopy(self._state) result = self._executors[tool_id](arguments=arguments, state=state_in) self._validate_executor_output(tool, result) if result.updated_state is not None: if tool.read_only: raise ToolError( f"Tool '{tool_id}' is read_only but executor returned " f"updated_state. Read-only tools must not mutate state." ) if self._state_schema is not None: try: jsonschema.validate(result.updated_state, self._state_schema) except jsonschema.ValidationError as e: raise ToolError( f"Tool '{tool_id}' updated_state failed state_schema " f"validation: {e}" ) from e self._state = copy.deepcopy(result.updated_state) return result def _validate_executor_output(self, tool: ToolParams, result: Any) -> None: """Validate executor return type + ``output_schema`` conformance.""" if not isinstance(result, ToolResult): raise ToolError( f"Tool '{tool.id}' executor must return ToolResult, got " f"{type(result).__name__}." ) if tool.output_schema is not None: try: jsonschema.validate(result.output, tool.output_schema) except jsonschema.ValidationError as e: raise ToolError( f"Tool '{tool.id}' executor output failed schema validation: {e}" ) from e
[docs] def sample_grounding( self, n: int, *, rng: random.Random, tool_ids: set[str] | None = None, ) -> list[GroundingFact]: """Project grounding facts from ``grounding.state`` pools. No-op for stateless envs or envs without ``grounding.state`` entries. ``tool_ids`` is accepted for ``BaseEnvironment`` signature compatibility but ignored — state grounding is pool-scoped, not tool-scoped. ``_validate_state_grounding`` at init guarantees each ``state_path`` resolves to a list in ``self._state``, and ``state_schema`` validation on every commit keeps it that way, so the projection loop trusts the shape. """ del tool_ids if self._state is None or not self._state_grounding: return [] pool: list[GroundingFact] = [] for cfg in self._state_grounding: whitelist = set(cfg.fields) for row in self._state[cfg.state_path]: projected = {k: v for k, v in row.items() if k in whitelist} pool.append(GroundingFact(data=projected)) return rng.sample(pool, min(n, len(pool)))
def _validate_state_grounding(self) -> None: """Validate each ``grounding.state`` entry against current state.""" assert self._state is not None for cfg in self._state_grounding: if cfg.state_path not in self._state: raise ValueError( f"SyntheticEnvironment '{self._params.id}': grounding " f"state_path '{cfg.state_path}' is not present in " f"initial_state. Top-level keys: " f"{sorted(self._state.keys())}." ) rows = self._state[cfg.state_path] if not isinstance(rows, list): raise ValueError( f"SyntheticEnvironment '{self._params.id}': grounding " f"state_path '{cfg.state_path}' must resolve to a list, " f"got {type(rows).__name__}." ) def _build_simulator_system_prompt(self, tool: ToolParams) -> str: """Compose the simulator system prompt: env persona + tool schema.""" return ( f"{self._kwargs.system_prompt}\n\n" f"You are simulating the `{tool.id}` tool. Respond ONLY with a " f"JSON object matching the tool's output schema. Do NOT include " f"explanations, markdown, or surrounding prose.\n\n" f"Tool schema:\n{json.dumps(tool.to_llm_schema(), indent=2)}" ) def _build_call_conv( self, tool: ToolParams, arguments: dict[str, Any] ) -> Conversation: """Build the simulator conversation for one tool call.""" user_payload = json.dumps( {"tool": tool.id, "arguments": arguments}, sort_keys=True ) return Conversation( messages=[ Message( role=Role.SYSTEM, content=self._build_simulator_system_prompt(tool), ), Message(role=Role.USER, content=user_payload), ] ) def _simulator_inference_config(self, tool: ToolParams) -> InferenceConfig: """Overlay guided decoding for the tool's output_schema onto base_config. Tools without ``output_schema`` get the permissive ``{"type": "object"}`` constraint. Mirrors ``ConversationSynthesizer._planner_inference_config``. """ assert self._base_inference_config is not None schema = tool.output_schema or {"type": "object"} sim_gen = dataclasses.replace( self._base_inference_config.generation, guided_decoding=GuidedDecodingParams(json=schema), ) return dataclasses.replace(self._base_inference_config, generation=sim_gen) @staticmethod def _extract_text(conv: Conversation) -> str: """Return the assistant's text response, or ``""`` to trigger a ToolError. Engines that passthrough on partial failure can leave the user payload (itself valid JSON) as ``messages[-1]``; the role guard forces the ToolError path so we don't echo arguments back as a tool result. """ if not conv.messages: return "" last = conv.messages[-1] if last.role != Role.ASSISTANT: return "" content = last.content return content.strip() if isinstance(content, str) else "" @staticmethod def _parse_and_validate(raw: str, tool: ToolParams) -> ToolResult: """Parse simulator output and validate against ``tool.output_schema``.""" if not raw: raise ToolError(f"Simulator returned empty response for '{tool.id}'.") try: parsed = json.loads(raw) except json.JSONDecodeError: extracted = extract_json(raw, expected_type=dict) if extracted is None: raise ToolError( f"Simulator output for '{tool.id}' is not valid JSON: {raw[:200]!r}" ) from None parsed = extracted if not isinstance(parsed, dict): raise ToolError( f"Simulator output for '{tool.id}' must be a JSON object, " f"got {type(parsed).__name__}." ) if tool.output_schema is not None: try: jsonschema.validate(parsed, tool.output_schema) except jsonschema.ValidationError as e: raise ToolError( f"Simulator output for '{tool.id}' failed schema validation: {e}" ) from e return ToolResult(output=parsed)