# Source code for oumi.judges.rule_based_judge

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing_extensions import override

from oumi.core.configs.judge_config import JudgeConfig
from oumi.core.registry import REGISTRY, RegistryType
from oumi.judges.base_judge import BaseJudge, JudgeOutput, JudgeOutputField

# Field key for the single output field produced by RuleBasedJudge; it is used
# as the key in each JudgeOutput's field_values and field_scores.
JUDGMENT_KEY: str = "judgment"


class RuleBasedJudge(BaseJudge):
    """A Rule Based Judge for evaluating outputs based on a configuration.

    Rather than calling an inference engine, this judge looks up a rule class in
    the registry (``RegistryType.RULE``) by ``rule_judge_params.rule_type`` and
    applies it to each input, producing a judgment and a score.
    """

    def __init__(self, judge_config: JudgeConfig | str):
        """Initialize the RuleBasedJudge.

        Args:
            judge_config: JudgeConfig object or path to a judge configuration
                file. Must contain rule_judge_params for rule-based evaluation.

        Raises:
            ValueError: If rule_judge_params is not provided in the config.
        """
        if isinstance(judge_config, str):
            judge_config = JudgeConfig.from_path(judge_config)

        if judge_config.rule_judge_params is None:
            raise ValueError(
                "rule_judge_params must be provided for RuleBasedJudge. "
                "Please add rule_judge_params to your JudgeConfig."
            )

        self._judge_params = judge_config.judge_params
        self._rule_params = judge_config.rule_judge_params

        output_fields = self._create_output_fields()

        # No system instruction, few-shot examples, or inference engine: the
        # judgment is produced entirely by the registered rule.
        super().__init__(
            prompt_template=self._judge_params.prompt_template,
            prompt_template_placeholders=set(self._rule_params.input_fields),
            system_instruction=None,
            example_field_values=[],
            response_format=self._rule_params.response_format,
            output_fields=output_fields,
            inference_engine=None,
        )

    def _create_output_fields(self) -> list[JudgeOutputField]:
        """Create output fields based on rule parameters.

        Returns:
            A single-element list describing the ``JUDGMENT_KEY`` field, whose
            type and scores come from ``rule_judge_params``.
        """
        return [
            JudgeOutputField(
                field_key=JUDGMENT_KEY,
                field_type=self._rule_params.judgment_type,
                field_scores=self._rule_params.judgment_scores,
            )
        ]

    @override
    def judge(self, inputs: list[dict[str, str]]) -> list[JudgeOutput]:
        """Apply the configured rule to every input.

        Args:
            inputs: One dict of string fields per example. Presumably keyed by
                ``rule_judge_params.input_fields``; it is checked by
                ``validate_dataset`` before any rule is applied.

        Returns:
            One JudgeOutput per input, in input order, carrying the rule's
            judgment under ``JUDGMENT_KEY`` in both field values and scores.

        Raises:
            ValueError: If ``inputs`` is non-empty and the configured rule type
                is not found in the registry.
        """
        self.validate_dataset(inputs)
        if not inputs:
            # Nothing to judge; skip the registry lookup entirely.
            return []

        # Resolve and build the rule once per call rather than once per input.
        # The rule is constructed without arguments and receives its config
        # through apply(), so a single instance serves every input here.
        rule = self._create_rule()
        rule_config = self._rule_params.rule_config

        results = []
        for input_data in inputs:
            judgment, score = rule.apply(input_data, rule_config)
            results.append(self._build_output(judgment, score))
        return results

    def _create_rule(self):
        """Look up the configured rule class in the registry and instantiate it.

        Raises:
            ValueError: If the rule type is not found in the registry.
        """
        rule_type = self._rule_params.rule_type
        rule_class = REGISTRY.get(rule_type, RegistryType.RULE)
        if rule_class is None:
            raise ValueError(f"Rule type {rule_type} not found in registry")
        return rule_class()

    def _build_output(self, judgment: bool, score: float) -> JudgeOutput:
        """Wrap a single rule result in a JudgeOutput."""
        return JudgeOutput(
            raw_output=f"{JUDGMENT_KEY}: {judgment}",
            field_values={JUDGMENT_KEY: judgment},
            field_scores={JUDGMENT_KEY: score},
            response_format=self._rule_params.response_format,
            output_fields=self.output_fields,
        )

    def _apply_rule(self, input_data: dict[str, str]) -> tuple[bool, float]:
        """Pulls specified rule from registry and applies it on the input data.

        Kept for callers that judge a single input; ``judge`` resolves the rule
        once and reuses it across all inputs instead.

        Raises:
            ValueError: If the rule type is not found in the registry.
        """
        return self._create_rule().apply(input_data, self._rule_params.rule_config)