feat: custom judge criteria and multi-dimensional scoring
Add configurable judge rubrics and multi-dimensional scoring with
weighted aggregation. New config fields: judge_criteria (free text)
and judge_dimensions (list of {name, weight, description}). CLI
--judge-criteria flag provides quick overrides. The judge adapter
computes weighted aggregate scores and enriches feedback with
per-dimension breakdowns.
Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -11,6 +11,15 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
|||||||
CONFIG_VERSION = 1
|
CONFIG_VERSION = 1
|
||||||
|
|
||||||
_ERROR_STRATEGY_VALUES = {"skip", "retry", "abort"}
|
_ERROR_STRATEGY_VALUES = {"skip", "retry", "abort"}
|
||||||
|
_EVAL_METRIC_VALUES = {"exact", "bleu", "rouge_l", "cosine", "llm_judge"}
|
||||||
|
|
||||||
|
|
||||||
|
class JudgeDimension(BaseModel):
    """A single evaluation dimension for multi-dimensional scoring."""

    # Label for the dimension; validation rejects an empty string.
    name: str = Field(
        min_length=1,
        description="Dimension name (e.g. accuracy, clarity, safety).",
    )

    # Relative importance of the dimension; validated to lie in [0.0, 1.0]
    # and defaulting to full weight when omitted.
    weight: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="Weight for this dimension (0.0–1.0).",
    )

    # Optional free-text explanation; empty when not provided.
    description: str = Field(
        default="",
        description="What this dimension measures.",
    )
|
||||||
|
|
||||||
|
|
||||||
class OptimizationConfig(BaseModel):
|
class OptimizationConfig(BaseModel):
|
||||||
@@ -67,6 +76,30 @@ class OptimizationConfig(BaseModel):
|
|||||||
minibatch_size: int = Field(default=5, ge=1, description="Inputs per evaluation minibatch.")
|
minibatch_size: int = Field(default=5, ge=1, description="Inputs per evaluation minibatch.")
|
||||||
perfect_score: float = Field(default=1.0, ge=0.0, le=1.0)
|
perfect_score: float = Field(default=1.0, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
# --- Population-based evolution ---
|
||||||
|
population_size: int = Field(
|
||||||
|
default=1,
|
||||||
|
ge=1,
|
||||||
|
description="Number of candidates in the evolution population. 1 = single-candidate hill climbing (backward compat).",
|
||||||
|
)
|
||||||
|
crossover_rate: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Probability of applying crossover vs reflective mutation.",
|
||||||
|
)
|
||||||
|
mutation_rate: float = Field(
|
||||||
|
default=0.3,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Probability of applying a mutation operator after crossover or proposal.",
|
||||||
|
)
|
||||||
|
diversity_penalty: float = Field(
|
||||||
|
default=0.1,
|
||||||
|
ge=0.0,
|
||||||
|
description="Penalty weight for similarity to existing population members.",
|
||||||
|
)
|
||||||
|
|
||||||
# --- Reproducibility ---
|
# --- Reproducibility ---
|
||||||
seed: int = Field(default=42, ge=0)
|
seed: int = Field(default=42, ge=0)
|
||||||
|
|
||||||
@@ -79,10 +112,72 @@ class OptimizationConfig(BaseModel):
|
|||||||
circuit_breaker_threshold: int = Field(default=5, ge=1, description="Consecutive failures before circuit opens.")
|
circuit_breaker_threshold: int = Field(default=5, ge=1, description="Consecutive failures before circuit opens.")
|
||||||
error_strategy: str = Field(default="retry", description="Error handling strategy: skip | retry | abort.")
|
error_strategy: str = Field(default="retry", description="Error handling strategy: skip | retry | abort.")
|
||||||
|
|
||||||
|
# --- Logging & observability ---
|
||||||
|
debug: bool = Field(default=False, description="Enable DEBUG-level logging.")
|
||||||
|
log_format: str = Field(default="text", description="Log output format: text | json.")
|
||||||
|
log_file: str | None = Field(default=None, description="Optional file path for log output.")
|
||||||
|
|
||||||
|
# --- Checkpoint & resume ---
|
||||||
|
checkpoint_dir: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Directory for checkpoint files. Set to enable checkpointing.",
|
||||||
|
)
|
||||||
|
checkpoint_interval: int = Field(
|
||||||
|
default=5,
|
||||||
|
ge=1,
|
||||||
|
description="Save a checkpoint every N iterations (and on every accepted improvement).",
|
||||||
|
)
|
||||||
|
resume: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Resume from the latest checkpoint in checkpoint_dir.",
|
||||||
|
)
|
||||||
|
|
||||||
# --- Output ---
|
# --- Output ---
|
||||||
output_path: str = Field(default="output.yaml", min_length=1)
|
output_path: str = Field(default="output.yaml", min_length=1)
|
||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
|
|
||||||
|
# --- Hold-out validation ---
|
||||||
|
validation_split: float = Field(
|
||||||
|
default=0.3,
|
||||||
|
ge=0.0,
|
||||||
|
lt=1.0,
|
||||||
|
description="Fraction of synthetic pool reserved for validation (0 = disabled).",
|
||||||
|
)
|
||||||
|
early_stop_patience: int = Field(
|
||||||
|
default=5,
|
||||||
|
ge=1,
|
||||||
|
description="Stop if validation score degrades for this many consecutive iterations.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Judge criteria & multi-dimensional scoring ---
|
||||||
|
judge_criteria: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Custom judge rubric or evaluation criteria override (free text).",
|
||||||
|
)
|
||||||
|
judge_dimensions: list[JudgeDimension] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Multi-dimensional scoring dimensions with configurable weights.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Ground-truth evaluation ---
|
||||||
|
eval_dataset_path: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
min_length=1,
|
||||||
|
description="Path to a CSV/JSON dataset with 'input' and 'expected_output' columns.",
|
||||||
|
)
|
||||||
|
eval_metric: str = Field(
|
||||||
|
default="bleu",
|
||||||
|
description="Similarity metric for ground-truth eval: exact | bleu | rouge_l | cosine | llm_judge.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("log_format")
|
||||||
|
@classmethod
|
||||||
|
def _validate_log_format(cls, v: str) -> str:
|
||||||
|
allowed = {"text", "json"}
|
||||||
|
if v not in allowed:
|
||||||
|
raise ValueError(f"log_format must be one of {sorted(allowed)}, got '{v}'")
|
||||||
|
return v
|
||||||
|
|
||||||
@field_validator("error_strategy")
|
@field_validator("error_strategy")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_error_strategy(cls, v: str) -> str:
|
def _validate_error_strategy(cls, v: str) -> str:
|
||||||
@@ -92,6 +187,15 @@ class OptimizationConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@field_validator("eval_metric")
|
||||||
|
@classmethod
|
||||||
|
def _validate_eval_metric(cls, v: str) -> str:
|
||||||
|
if v not in _EVAL_METRIC_VALUES:
|
||||||
|
raise ValueError(
|
||||||
|
f"eval_metric must be one of {sorted(_EVAL_METRIC_VALUES)}, got '{v}'"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _migrate_config(cls, data: Any) -> Any:
|
def _migrate_config(cls, data: Any) -> Any:
|
||||||
@@ -118,3 +222,7 @@ class OptimizationResult:
|
|||||||
final_score: float
|
final_score: float
|
||||||
improvement: float
|
improvement: float
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
# Hold-out validation metrics (populated when validation_split > 0)
|
||||||
|
final_validation_score: float | None = None
|
||||||
|
best_validation_score: float | None = None
|
||||||
|
early_stopped: bool = False
|
||||||
|
|||||||
428
src/prometheus/cli/commands/optimize.py
Normal file
428
src/prometheus/cli/commands/optimize.py
Normal file
@@ -0,0 +1,428 @@
|
|||||||
|
"""prometheus optimize — run prompt optimization."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import asdict
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
import typer
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||||
|
from prometheus.application.dto import JudgeDimension, OptimizationConfig, OptimizationResult
|
||||||
|
from prometheus.domain.entities import EvalResult, Prompt
|
||||||
|
from prometheus.application.evaluator import PromptEvaluator
|
||||||
|
from prometheus.application.ground_truth_evaluator import GroundTruthEvaluator
|
||||||
|
from prometheus.application.use_cases import OptimizePromptUseCase
|
||||||
|
from prometheus.cli.logging_setup import configure_logging
|
||||||
|
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
|
||||||
|
from prometheus.infrastructure.checkpoint import JsonCheckpointPersistence
|
||||||
|
from prometheus.infrastructure.file_io import YamlPersistence
|
||||||
|
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||||
|
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
||||||
|
from prometheus.infrastructure.crossover_adapter import DSPyCrossoverAdapter
|
||||||
|
from prometheus.infrastructure.mutation_adapter import DSPyMutationAdapter
|
||||||
|
from prometheus.infrastructure.crossover_adapter import DSPyCrossoverAdapter
|
||||||
|
from prometheus.infrastructure.mutation_adapter import DSPyMutationAdapter
|
||||||
|
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
||||||
|
from prometheus.infrastructure.similarity import create_similarity_adapter
|
||||||
|
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
def register(app: typer.Typer) -> None:
    """Register the optimize command on the Typer app."""

    @app.command()
    def optimize(
        input: str = typer.Option(
            ...,
            "-i",
            "--input",
            help="Path to input YAML config file.",
            exists=True,
            readable=True,
        ),
        output: str = typer.Option(
            "output.yaml",
            "-o",
            "--output",
            help="Path to output YAML result file.",
        ),
        verbose: bool = typer.Option(
            False,
            "-v",
            "--verbose",
            help="Print detailed progress (INFO level).",
        ),
        debug: bool = typer.Option(
            False,
            "--debug",
            help="Enable DEBUG-level logging (overrides -v).",
        ),
        log_format: str = typer.Option(
            "text",
            "--log-format",
            help="Log output format: text | json.",
        ),
        log_file: str | None = typer.Option(
            None,
            "--log-file",
            help="Optional file path to write logs to.",
        ),
        max_retries: int = typer.Option(
            3,
            "--max-retries",
            help="Max retry attempts for transient LLM errors (429, timeout, 5xx).",
        ),
        error_strategy: str = typer.Option(
            "retry",
            "--error-strategy",
            help="How to handle errors: skip | retry | abort.",
        ),
        max_concurrency: int = typer.Option(
            5,
            "--max-concurrency",
            help="Max parallel LLM calls for minibatch execution and judging.",
        ),
        eval_dataset: str | None = typer.Option(
            None,
            "--eval-dataset",
            help="Path to a CSV/JSON dataset with 'input' and 'expected_output' columns.",
        ),
        eval_metric: str = typer.Option(
            "bleu",
            "--eval-metric",
            help="Similarity metric for ground-truth eval: exact | bleu | rouge_l | cosine | llm_judge.",
        ),
        checkpoint_dir: str | None = typer.Option(
            None,
            "--checkpoint-dir",
            help="Directory for checkpoint files. Enables periodic checkpointing.",
        ),
        checkpoint_interval: int = typer.Option(
            5,
            "--checkpoint-interval",
            help="Save a checkpoint every N iterations.",
        ),
        resume: bool = typer.Option(
            False,
            "--resume",
            help="Resume from the latest checkpoint in --checkpoint-dir.",
        ),
        population_size: int = typer.Option(
            1,
            "--population-size",
            help="Number of candidates in the evolution population. 1 = single-candidate hill climbing.",
        ),
        crossover_rate: float = typer.Option(
            0.5,
            "--crossover-rate",
            help="Probability of applying crossover vs reflective mutation (0.0–1.0). Only used when --population-size > 1.",
        ),
        mutation_rate: float = typer.Option(
            0.3,
            "--mutation-rate",
            help="Probability of applying a mutation operator after crossover/proposal (0.0–1.0). Only used when --population-size > 1.",
        ),
        validation_split: float = typer.Option(
            0.3,
            "--validation-split",
            help="Fraction of synthetic pool reserved for hold-out validation (0.0–0.9). 0 disables validation.",
        ),
        early_stop_patience: int = typer.Option(
            5,
            "--early-stop-patience",
            help="Stop if validation score does not improve for this many consecutive iterations.",
        ),
        judge_criteria: str | None = typer.Option(
            None,
            "--judge-criteria",
            help="Custom judge rubric or evaluation criteria override (free text).",
        ),
    ) -> None:
        """Optimize a prompt without any reference data.

        Usage:
            prometheus optimize -i config.yaml -o result.yaml
            prometheus optimize -i config.yaml --eval-dataset data.csv --eval-metric bleu
            prometheus optimize -i config.yaml --checkpoint-dir .prometheus/checkpoints --resume
        """
        # NOTE(review): Typer appears to honour `exists` / `readable` only for
        # Path-annotated parameters, so with the `str` annotation used here the
        # options above would not be enforced. Validate explicitly so a bad
        # path is reported as a usage error instead of failing later while the
        # config is being read.
        if not os.path.isfile(input):
            raise typer.BadParameter(
                f"File '{input}' does not exist.",
                param_hint="--input",
            )
        if not os.access(input, os.R_OK):
            raise typer.BadParameter(
                f"File '{input}' is not readable.",
                param_hint="--input",
            )

        # Typer commands are synchronous; run the async pipeline to completion.
        # Arguments are passed positionally in the order _async_optimize declares them.
        asyncio.run(
            _async_optimize(
                input, output, verbose, debug, log_format, log_file,
                max_retries, error_strategy, max_concurrency,
                eval_dataset, eval_metric,
                checkpoint_dir, checkpoint_interval, resume,
                population_size, crossover_rate, mutation_rate,
                validation_split, early_stop_patience,
                judge_criteria,
            )
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_optimize(
    input: str,
    output: str,
    verbose: bool,
    debug: bool,
    log_format: str,
    log_file: str | None,
    max_retries: int,
    error_strategy: str,
    max_concurrency: int,
    eval_dataset: str | None,
    eval_metric: str,
    checkpoint_dir: str | None,
    checkpoint_interval: int,
    resume: bool,
    population_size: int = 1,
    crossover_rate: float = 0.5,
    mutation_rate: float = 0.3,
    validation_split: float = 0.3,
    early_stop_patience: int = 5,
    judge_criteria: str | None = None,
) -> None:
    """Run the full optimization pipeline for the ``optimize`` command.

    Steps: configure logging, load the YAML config at ``input`` and merge the
    CLI values into it, validate it as an ``OptimizationConfig``, build one
    ``dspy.LM`` per role plus the adapters around them, execute
    ``OptimizePromptUseCase``, display the result, optionally run a
    ground-truth evaluation when ``eval_dataset_path`` is set, and finally
    write the result to ``output``.

    Raises:
        typer.Exit: with code 1 when the merged configuration fails validation.
    """
    # Configure structured logging
    if debug:
        log_level = logging.DEBUG
    elif verbose:
        log_level = logging.INFO
    else:
        log_level = logging.WARNING
    configure_logging(level=log_level, log_format=log_format, log_file=log_file)

    console.print(
        Panel.fit(
            "PROMETHEUS — Prompt Evolution Engine",
            subtitle="No reference data required",
        )
    )

    # 1. Load & validate config
    persistence = YamlPersistence()
    raw_config = persistence.read_config(input)

    # Merge CLI values into the file config. Two precedence rules are in play:
    #   * keys written with item assignment always take the CLI value;
    #   * keys merged with setdefault() only fill in what the file omits, so
    #     for those the config file wins over the CLI flag.
    raw_config.setdefault("max_retries", max_retries)
    raw_config.setdefault("error_strategy", error_strategy)
    raw_config.setdefault("max_concurrency", max_concurrency)
    raw_config["output_path"] = output
    raw_config["verbose"] = verbose
    raw_config["debug"] = debug
    raw_config["log_format"] = log_format
    raw_config["log_file"] = log_file
    if eval_dataset:
        raw_config["eval_dataset_path"] = eval_dataset
    raw_config.setdefault("eval_metric", eval_metric)
    if checkpoint_dir:
        raw_config["checkpoint_dir"] = checkpoint_dir
    raw_config.setdefault("checkpoint_interval", checkpoint_interval)
    if resume:
        raw_config["resume"] = True
    raw_config.setdefault("population_size", population_size)
    raw_config.setdefault("crossover_rate", crossover_rate)
    raw_config.setdefault("mutation_rate", mutation_rate)
    raw_config.setdefault("validation_split", validation_split)
    raw_config.setdefault("early_stop_patience", early_stop_patience)
    if judge_criteria:
        raw_config["judge_criteria"] = judge_criteria

    try:
        config = OptimizationConfig.model_validate(raw_config)
    except ValidationError as exc:
        console.print("[bold red]Configuration error:[/bold red]\n")
        for err in exc.errors():
            loc = " → ".join(str(part) for part in err["loc"])
            console.print(f"  [red]• {loc}: {err['msg']}[/red]")
        raise typer.Exit(code=1) from exc

    def _preview(text: str, limit: int = 80) -> str:
        """Return ``text`` cut to ``limit`` characters, adding "..." only when cut."""
        return text if len(text) <= limit else f"{text[:limit]}..."

    console.print(f"[dim]Task: {_preview(config.task_description)}[/dim]")
    console.print(f"[dim]Seed prompt: {_preview(config.seed_prompt)}[/dim]")

    # 2. Create per-model DSPy LM instances
    def _model_lm_kwargs(
        model_api_base: str | None,
        model_api_key_env: str | None,
    ) -> dict:
        """Build kwargs for dspy.LM, using per-model overrides with global fallback."""
        kwargs: dict = {}
        api_base = model_api_base or config.api_base
        api_key_env = model_api_key_env or config.api_key_env
        if api_base:
            kwargs["api_base"] = api_base
        if api_key_env:
            # NOTE(review): an unset environment variable yields an empty
            # api_key rather than an error — confirm this is intended.
            kwargs["api_key"] = os.environ.get(api_key_env, "")
        return kwargs

    task_lm = dspy.LM(
        config.task_model,
        **_model_lm_kwargs(config.task_api_base, config.task_api_key_env),
    )
    judge_lm = dspy.LM(
        config.judge_model,
        **_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env),
    )
    proposer_lm = dspy.LM(
        config.proposer_model,
        **_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env),
    )
    synth_lm = dspy.LM(
        config.synth_model,
        **_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env),
    )

    # 3. Build adapters (Dependency Injection — each gets its own LM + retry config)
    synth_adapter = DSPySyntheticAdapter(lm=synth_lm)
    llm_adapter = DSPyLLMAdapter(
        lm=task_lm,
        max_retries=config.max_retries,
        retry_delay_base=config.retry_delay_base,
    )
    judge_adapter = DSPyJudgeAdapter(
        lm=judge_lm,
        max_retries=config.max_retries,
        retry_delay_base=config.retry_delay_base,
        max_concurrency=config.max_concurrency,
        judge_criteria=config.judge_criteria,
        judge_dimensions=config.judge_dimensions,
    )
    proposer_adapter = DSPyProposerAdapter(
        lm=proposer_lm,
        max_retries=config.max_retries,
        retry_delay_base=config.retry_delay_base,
    )
    bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
    evaluator = PromptEvaluator(
        executor=llm_adapter,
        judge=judge_adapter,
        max_concurrency=config.max_concurrency,
    )
    # Build checkpoint port if checkpoint_dir is configured
    checkpoint_port = None
    if config.checkpoint_dir:
        checkpoint_port = JsonCheckpointPersistence(checkpoint_dir=config.checkpoint_dir)

    # Build crossover/mutation adapters for population-based evolution
    crossover_adapter = None
    mutation_adapter = None
    if config.population_size > 1:
        # Reuse proposer LM for crossover and mutation (same model, same role)
        crossover_adapter = DSPyCrossoverAdapter(
            lm=proposer_lm,
            max_retries=config.max_retries,
            retry_delay_base=config.retry_delay_base,
        )
        mutation_adapter = DSPyMutationAdapter(
            lm=proposer_lm,
            max_retries=config.max_retries,
            retry_delay_base=config.retry_delay_base,
        )

    use_case = OptimizePromptUseCase(
        evaluator=evaluator,
        proposer=proposer_adapter,
        bootstrap=bootstrap,
        checkpoint_port=checkpoint_port,
        crossover_port=crossover_adapter,
        mutation_port=mutation_adapter,
    )

    # 4. Execute
    with console.status("[bold green]Evolving prompt..."):
        result = await use_case.execute(config)

    # 5. Display results
    _display_result(result)

    # 6. Optional ground-truth evaluation on the optimized prompt
    if config.eval_dataset_path:
        dataset = FileDatasetLoader().load(config.eval_dataset_path)
        if config.eval_metric == "llm_judge":
            # llm_judge reuses the existing PromptEvaluator with the LLM judge
            from prometheus.domain.entities import SyntheticExample

            # Only the input text and id are carried over into each example.
            synth_dataset = [
                SyntheticExample(input_text=ex.input_text, id=ex.id) for ex in dataset
            ]
            gt_eval = PromptEvaluator(
                executor=llm_adapter,
                judge=judge_adapter,
                max_concurrency=config.max_concurrency,
            )
            with console.status("[bold green]Running ground-truth evaluation (llm_judge)..."):
                gt_result = await gt_eval.evaluate(
                    prompt=Prompt(text=result.optimized_prompt),
                    minibatch=synth_dataset,
                    task_description=config.task_description,
                )
        else:
            gt_evaluator = GroundTruthEvaluator(
                executor=llm_adapter,
                similarity=create_similarity_adapter(config.eval_metric),
                max_concurrency=config.max_concurrency,
            )
            with console.status("[bold green]Running ground-truth evaluation..."):
                gt_result = await gt_evaluator.evaluate(
                    prompt=Prompt(text=result.optimized_prompt),
                    dataset=dataset,
                )
        _display_ground_truth(gt_result, config.eval_metric, len(dataset))

    # 7. Save
    _save_result(persistence, output, result)
    console.print(f"\n[green]Results saved to {output}[/green]")
|
||||||
|
|
||||||
|
|
||||||
|
def _display_result(result: OptimizationResult) -> None:
    """Print the optimized prompt in a panel, followed by a metrics table.

    The validation and early-stop rows are included only when
    ``best_validation_score`` is set or ``early_stopped`` is true.
    """
    console.print()
    prompt_panel = Panel(
        f"[bold green]Optimized Prompt[/bold green]\n\n{result.optimized_prompt}",
        title="Result",
    )
    console.print(prompt_panel)

    summary = Table(title="Metrics")
    for heading, column_style in (("Metric", "cyan"), ("Value", "bold")):
        summary.add_column(heading, style=column_style)

    # Collect the rows first so the optional ones keep their position
    # between the score rows and the usage counters.
    rows = [
        ("Initial Score", f"{result.initial_score:.2f}"),
        ("Final Score", f"{result.final_score:.2f}"),
        ("Improvement", f"{result.improvement:+.2f}"),
    ]
    if result.best_validation_score is not None:
        rows.append(("Best Validation Score", f"{result.best_validation_score:.4f}"))
    if result.early_stopped:
        rows.append(("Early Stopped", "[yellow]Yes[/yellow]"))
    rows.append(("Iterations", str(result.iterations_used)))
    rows.append(("LLM Calls", str(result.total_llm_calls)))

    for label, value in rows:
        summary.add_row(label, value)
    console.print(summary)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_result(
    persistence: YamlPersistence,
    path: str,
    result: OptimizationResult,
) -> None:
    """Convert ``result`` to a plain dict and pass it to ``persistence.write_result``."""
    payload = asdict(result)
    persistence.write_result(path, payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _display_ground_truth(
    result: EvalResult, metric: str, dataset_size: int
) -> None:
    """Display ground-truth evaluation results.

    Args:
        result: Evaluation outcome; ``mean_score``, ``total_score`` and
            ``scores`` are read from it.
        metric: Metric name shown in the table title.
        dataset_size: Number of evaluated examples, used as the denominator
            for the match count and accuracy.
    """
    console.print()
    table = Table(title=f"Ground-Truth Evaluation (metric: {metric})")
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="bold")
    table.add_row("Dataset Size", str(dataset_size))
    table.add_row("Mean Score", f"{result.mean_score:.4f}")
    table.add_row("Total Score", f"{result.total_score:.4f}")
    # A score of 0.99 or higher is counted as an exact match.
    exact_matches = sum(1 for s in result.scores if s >= 0.99)
    # An empty dataset would otherwise raise ZeroDivisionError; report 0% instead.
    accuracy = exact_matches / dataset_size if dataset_size else 0.0
    table.add_row("Exact Matches", f"{exact_matches}/{dataset_size}")
    table.add_row("Accuracy", f"{accuracy:.2%}")
    console.print(table)
|
||||||
@@ -19,3 +19,26 @@ def should_accept(
|
|||||||
def normalize_score(raw: float, min_val: float = 0.0, max_val: float = 1.0) -> float:
    """Return ``raw`` limited to the closed interval [min_val, max_val]."""
    # Apply the upper bound first, then the lower bound.
    capped = min(max_val, raw)
    return max(min_val, capped)
|
||||||
|
|
||||||
|
|
||||||
|
def weighted_aggregate(
    scores: dict[str, float],
    weights: dict[str, float],
) -> float:
    """Compute a weighted average of per-dimension scores.

    Dimensions present in ``scores`` but missing from ``weights`` get a
    weight of 0.0. When the weights of all scored dimensions sum to zero,
    the unweighted mean of the scores is used instead.

    Args:
        scores: Mapping of dimension name → score (0.0–1.0).
        weights: Mapping of dimension name → weight (0.0–1.0).

    Returns:
        Weighted average clamped to [0.0, 1.0]. Returns 0.0 if inputs are empty.
    """
    if not scores or not weights:
        return 0.0
    total_weight = sum(weights.get(name, 0.0) for name in scores)
    if total_weight == 0.0:
        # No usable weights for the scored dimensions: fall back to a plain mean.
        aggregate = sum(scores.values()) / len(scores)
    else:
        aggregate = sum(
            score * weights.get(name, 0.0) for name, score in scores.items()
        ) / total_weight
    # Individual scores are not validated here, so an out-of-range value
    # could push the average outside the documented range; clamp to honour
    # the [0.0, 1.0] contract.
    return max(0.0, min(1.0, aggregate))
|
||||||
|
|||||||
@@ -11,8 +11,10 @@ import re
|
|||||||
import dspy
|
import dspy
|
||||||
|
|
||||||
from prometheus.infrastructure.dspy_signatures import (
|
from prometheus.infrastructure.dspy_signatures import (
|
||||||
|
CrossoverInstructions,
|
||||||
GenerateSyntheticInputs,
|
GenerateSyntheticInputs,
|
||||||
JudgeOutput,
|
JudgeOutput,
|
||||||
|
MutateInstruction,
|
||||||
ProposeInstruction,
|
ProposeInstruction,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,19 +55,30 @@ class OutputJudge(dspy.Module):
|
|||||||
self.judge = dspy.ChainOfThought(JudgeOutput)
|
self.judge = dspy.ChainOfThought(JudgeOutput)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, task_description: str, input_text: str, output_text: str
|
self,
|
||||||
|
task_description: str,
|
||||||
|
input_text: str,
|
||||||
|
output_text: str,
|
||||||
|
judge_criteria: str = "",
|
||||||
|
dimension_names: str = "",
|
||||||
) -> dspy.Prediction:
|
) -> dspy.Prediction:
|
||||||
result = self.judge(
|
result = self.judge(
|
||||||
task_description=task_description,
|
task_description=task_description,
|
||||||
input_text=input_text,
|
input_text=input_text,
|
||||||
output_text=output_text,
|
output_text=output_text,
|
||||||
|
judge_criteria=judge_criteria,
|
||||||
|
dimension_names=dimension_names,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
score = float(result.score)
|
score = float(result.score)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
score = 0.5 # neutral fallback
|
score = 0.5 # neutral fallback
|
||||||
score = max(0.0, min(1.0, score))
|
score = max(0.0, min(1.0, score))
|
||||||
return dspy.Prediction(score=score, feedback=result.feedback)
|
return dspy.Prediction(
|
||||||
|
score=score,
|
||||||
|
feedback=result.feedback,
|
||||||
|
dimension_scores=getattr(result, "dimension_scores", "{}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InstructionProposer(dspy.Module):
|
class InstructionProposer(dspy.Module):
|
||||||
@@ -90,3 +103,45 @@ class InstructionProposer(dspy.Module):
|
|||||||
failure_examples=failure_examples,
|
failure_examples=failure_examples,
|
||||||
)
|
)
|
||||||
return dspy.Prediction(new_instruction=result.new_instruction)
|
return dspy.Prediction(new_instruction=result.new_instruction)
|
||||||
|
|
||||||
|
|
||||||
|
class InstructionCrossover(dspy.Module):
    """Produce a child instruction from two parents via a ChainOfThought predictor."""

    def __init__(self) -> None:
        super().__init__()
        self.crossover = dspy.ChainOfThought(CrossoverInstructions)

    def forward(
        self,
        parent_a: str,
        parent_b: str,
        task_description: str,
    ) -> dspy.Prediction:
        """Run the crossover predictor and return only its ``child_instruction``."""
        predictor_inputs = {
            "parent_a": parent_a,
            "parent_b": parent_b,
            "task_description": task_description,
        }
        merged = self.crossover(**predictor_inputs)
        # Re-wrap so callers receive just the child instruction field.
        return dspy.Prediction(child_instruction=merged.child_instruction)
|
||||||
|
|
||||||
|
|
||||||
|
class InstructionMutator(dspy.Module):
    """Apply a named mutation type to an instruction via a ChainOfThought predictor."""

    def __init__(self) -> None:
        super().__init__()
        self.mutate = dspy.ChainOfThought(MutateInstruction)

    def forward(
        self,
        current_instruction: str,
        task_description: str,
        mutation_type: str,
    ) -> dspy.Prediction:
        """Run the mutation predictor and return only its ``mutated_instruction``."""
        predictor_inputs = {
            "current_instruction": current_instruction,
            "task_description": task_description,
            "mutation_type": mutation_type,
        }
        mutated = self.mutate(**predictor_inputs)
        # Re-wrap so callers receive just the mutated instruction field.
        return dspy.Prediction(mutated_instruction=mutated.mutated_instruction)
|
||||||
|
|||||||
@@ -44,6 +44,12 @@ class JudgeOutput(dspy.Signature):
|
|||||||
output_text: str = dspy.InputField(
|
output_text: str = dspy.InputField(
|
||||||
desc="The assistant's response to evaluate."
|
desc="The assistant's response to evaluate."
|
||||||
)
|
)
|
||||||
|
judge_criteria: str = dspy.InputField(
|
||||||
|
desc="Custom evaluation rubric or criteria. Empty string = use default judging criteria."
|
||||||
|
)
|
||||||
|
dimension_names: str = dspy.InputField(
|
||||||
|
desc="Comma-separated dimension names for multi-dimensional scoring. Empty string = single overall score."
|
||||||
|
)
|
||||||
score: float = dspy.OutputField(
|
score: float = dspy.OutputField(
|
||||||
desc="Quality score from 0.0 (wrong) to 1.0 (perfect)."
|
desc="Quality score from 0.0 (wrong) to 1.0 (perfect)."
|
||||||
)
|
)
|
||||||
@@ -53,6 +59,12 @@ class JudgeOutput(dspy.Signature):
|
|||||||
"with the output and how to improve it. Be critical."
|
"with the output and how to improve it. Be critical."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
dimension_scores: str = dspy.OutputField(
|
||||||
|
desc=(
|
||||||
|
"JSON object mapping dimension names to scores (0.0-1.0). "
|
||||||
|
'Empty object {} if no dimensions specified.'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProposeInstruction(dspy.Signature):
|
class ProposeInstruction(dspy.Signature):
|
||||||
@@ -77,3 +89,52 @@ class ProposeInstruction(dspy.Signature):
|
|||||||
new_instruction: str = dspy.OutputField(
|
new_instruction: str = dspy.OutputField(
|
||||||
desc="An improved version of the instruction."
|
desc="An improved version of the instruction."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossoverInstructions(dspy.Signature):
    """Combine two instruction prompts into a single improved instruction.

    Take the strongest elements from each parent — structure, phrasing,
    constraints, examples — and merge them into a coherent child instruction
    that is strictly better than either parent alone.
    """

    # NOTE: DSPy uses a Signature's docstring and each field's ``desc`` as
    # prompt text, so the wording above and below is runtime behavior --
    # edit it deliberately, not cosmetically.

    # --- Inputs: the two parents and the task they are meant to solve ---
    parent_a: str = dspy.InputField(desc="First parent instruction.")
    parent_b: str = dspy.InputField(desc="Second parent instruction.")
    task_description: str = dspy.InputField(desc="Description of the task.")

    # --- Output: the merged instruction ---
    child_instruction: str = dspy.OutputField(
        desc=(
            "A combined instruction that takes the best elements "
            "from both parents into a single, coherent instruction."
        ),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class MutateInstruction(dspy.Signature):
    """Apply a specific mutation to an instruction prompt.

    The mutation_type determines the transformation:
    - paraphrase: restate the instruction in different words
    - constrain: add specificity, constraints, or guard-rails
    - generalize: broaden the instruction to cover more cases
    - specialize: narrow the instruction for better focus on the task
    """

    # NOTE: DSPy uses a Signature's docstring and each field's ``desc`` as
    # prompt text, so the wording above and below is runtime behavior --
    # edit it deliberately, not cosmetically.

    # --- Inputs: what to mutate, for which task, and how ---
    current_instruction: str = dspy.InputField(desc="The instruction to mutate.")
    task_description: str = dspy.InputField(desc="Description of the task.")
    mutation_type: str = dspy.InputField(
        desc="Type of mutation: paraphrase, constrain, generalize, or specialize."
    )

    # --- Output: the transformed instruction ---
    mutated_instruction: str = dspy.OutputField(
        desc="The mutated instruction, preserving core intent but altered per the mutation type."
    )
|
||||||
|
|||||||
@@ -6,12 +6,15 @@ Implements the JudgePort via the DSPy OutputJudge module.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Self
|
from typing import Any
|
||||||
|
|
||||||
import dspy
|
import dspy
|
||||||
|
|
||||||
|
from prometheus.application.dto import JudgeDimension
|
||||||
from prometheus.domain.ports import JudgePort
|
from prometheus.domain.ports import JudgePort
|
||||||
|
from prometheus.domain.scoring import weighted_aggregate
|
||||||
from prometheus.infrastructure.dspy_modules import OutputJudge
|
from prometheus.infrastructure.dspy_modules import OutputJudge
|
||||||
from prometheus.infrastructure.retry import async_retry_with_backoff
|
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||||
|
|
||||||
@@ -25,6 +28,9 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
instead of crashing the whole batch.
|
instead of crashing the whole batch.
|
||||||
|
|
||||||
Judge calls run in parallel (bounded by *max_concurrency*).
|
Judge calls run in parallel (bounded by *max_concurrency*).
|
||||||
|
|
||||||
|
When *judge_criteria* or *judge_dimensions* are provided, the judge applies
|
||||||
|
custom rubrics and/or multi-dimensional scoring with weighted aggregation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -33,12 +39,26 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
retry_delay_base: float = 1.0,
|
retry_delay_base: float = 1.0,
|
||||||
max_concurrency: int = 5,
|
max_concurrency: int = 5,
|
||||||
|
judge_criteria: str | None = None,
|
||||||
|
judge_dimensions: list[JudgeDimension] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._lm = lm
|
self._lm = lm
|
||||||
self._judge = OutputJudge()
|
self._judge = OutputJudge()
|
||||||
self._max_retries = max_retries
|
self._max_retries = max_retries
|
||||||
self._retry_delay_base = retry_delay_base
|
self._retry_delay_base = retry_delay_base
|
||||||
self._semaphore = asyncio.Semaphore(max_concurrency)
|
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||||
|
self._judge_criteria = judge_criteria or ""
|
||||||
|
self._judge_dimensions = judge_dimensions or []
|
||||||
|
self._dimension_names = (
|
||||||
|
",".join(d.name for d in self._judge_dimensions)
|
||||||
|
if self._judge_dimensions
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
self._weights: dict[str, float] = (
|
||||||
|
{d.name: d.weight for d in self._judge_dimensions}
|
||||||
|
if self._judge_dimensions
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
async def judge_batch(
|
async def judge_batch(
|
||||||
self,
|
self,
|
||||||
@@ -74,7 +94,7 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
pred = await asyncio.to_thread(
|
pred = await asyncio.to_thread(
|
||||||
self._sync_judge, task_description, input_text, output_text,
|
self._sync_judge, task_description, input_text, output_text,
|
||||||
)
|
)
|
||||||
return (pred.score, pred.feedback)
|
return self._aggregate_result(pred)
|
||||||
|
|
||||||
return await async_retry_with_backoff(
|
return await async_retry_with_backoff(
|
||||||
_call,
|
_call,
|
||||||
@@ -88,4 +108,35 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
task_description=task_description,
|
task_description=task_description,
|
||||||
input_text=input_text,
|
input_text=input_text,
|
||||||
output_text=output_text,
|
output_text=output_text,
|
||||||
|
judge_criteria=self._judge_criteria,
|
||||||
|
dimension_names=self._dimension_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _aggregate_result(self, pred: Any) -> tuple[float, str]:
    """Turn a judge prediction into a ``(score, feedback)`` pair.

    When no dimensions are configured, the prediction's own ``score`` and
    ``feedback`` are returned unchanged.  Otherwise ``pred.dimension_scores``
    is parsed as a JSON object, the entries for configured dimensions are
    clamped to ``[0.0, 1.0]`` and combined via ``weighted_aggregate``, and the
    feedback is suffixed with a per-dimension breakdown.

    Each dimension value is validated independently: a value that is not
    numeric, is NaN, or overflows ``float`` is skipped without discarding the
    other dimensions.  If the payload is not a JSON object, or no configured
    dimension yields a usable value, the overall ``pred.score`` is returned.

    Args:
        pred: Judge prediction exposing ``score``, ``feedback`` and
            ``dimension_scores`` attributes.

    Returns:
        ``(score, feedback)`` -- the weighted aggregate with enriched
        feedback, or the prediction's overall score and feedback on fallback.
    """
    import math  # function-scope: only needed for the NaN guard below

    if not self._judge_dimensions:
        return (pred.score, pred.feedback)

    # Parse the LLM-produced JSON payload; anything unparseable means fallback.
    try:
        raw = json.loads(pred.dimension_scores)
    except (json.JSONDecodeError, ValueError, TypeError):
        logger.debug("Failed to parse dimension_scores, falling back to overall score")
        raw = None

    dim_scores: dict[str, float] = {}
    if isinstance(raw, dict):
        for name in self._weights:
            val = raw.get(name)
            if val is None:
                continue
            # Convert per value so one malformed entry cannot silently drop
            # the dimensions that come after it.
            try:
                score = float(val)
            except (ValueError, TypeError, OverflowError):
                logger.debug("Ignoring unusable score for dimension %r", name)
                continue
            # json.loads accepts the literal NaN, and min(1.0, nan) evaluates
            # to 1.0 -- without this guard NaN would count as a perfect score.
            if math.isnan(score):
                logger.debug("Ignoring NaN score for dimension %r", name)
                continue
            dim_scores[name] = max(0.0, min(1.0, score))

    if not dim_scores:
        return (pred.score, pred.feedback)

    aggregate = weighted_aggregate(dim_scores, self._weights)
    # Enrich feedback with a per-dimension breakdown; dimensions without a
    # usable score are rendered as 0.00.
    dim_breakdown = ", ".join(
        f"{name}={dim_scores.get(name, 0.0):.2f}"
        for name in self._weights
    )
    feedback = f"{pred.feedback} [{dim_breakdown}]"
    return (aggregate, feedback)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def judge_lm() -> dspy.LM:
|
|||||||
"""Dummy LM for judging (ChainOfThought requires reasoning field)."""
|
"""Dummy LM for judging (ChainOfThought requires reasoning field)."""
|
||||||
return dspy.utils.DummyLM(
|
return dspy.utils.DummyLM(
|
||||||
[
|
[
|
||||||
{"reasoning": "Evaluating output.", "score": "0.8", "feedback": "Good response."},
|
{"reasoning": "Evaluating output.", "score": "0.8", "feedback": "Good response.", "dimension_scores": "{}"},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -99,9 +99,9 @@ class TestDSPyJudgeAdapterOwnLM:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_does_not_use_global_lm(self) -> None:
|
async def test_does_not_use_global_lm(self) -> None:
|
||||||
judge_lm = dspy.utils.DummyLM(
|
judge_lm = dspy.utils.DummyLM(
|
||||||
[{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response"}]
|
[{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response", "dimension_scores": "{}"}]
|
||||||
)
|
)
|
||||||
global_lm = dspy.utils.DummyLM([{"reasoning": "no", "score": "0.1", "feedback": "Wrong LM!"}])
|
global_lm = dspy.utils.DummyLM([{"reasoning": "no", "score": "0.1", "feedback": "Wrong LM!", "dimension_scores": "{}"}])
|
||||||
dspy.configure(lm=global_lm)
|
dspy.configure(lm=global_lm)
|
||||||
|
|
||||||
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||||
@@ -176,7 +176,7 @@ class TestDSPySyntheticAdapterOwnLM:
|
|||||||
class TestPerModelOverrides:
|
class TestPerModelOverrides:
|
||||||
"""Verify that per-model api_base/api_key_env are passed through to dspy.LM."""
|
"""Verify that per-model api_base/api_key_env are passed through to dspy.LM."""
|
||||||
|
|
||||||
@patch("prometheus.cli.app.dspy.LM")
|
@patch("prometheus.cli.commands.optimize.dspy.LM")
|
||||||
def test_per_model_api_base_override(self, mock_lm_cls: MagicMock) -> None:
|
def test_per_model_api_base_override(self, mock_lm_cls: MagicMock) -> None:
|
||||||
"""Per-model api_base should be used instead of global."""
|
"""Per-model api_base should be used instead of global."""
|
||||||
mock_lm_cls.return_value = MagicMock()
|
mock_lm_cls.return_value = MagicMock()
|
||||||
|
|||||||
@@ -124,7 +124,6 @@ class TestCircuitBreaker:
|
|||||||
circuit_breaker_threshold=3,
|
circuit_breaker_threshold=3,
|
||||||
error_strategy="skip",
|
error_strategy="skip",
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
|
||||||
state = await loop.run(
|
state = await loop.run(
|
||||||
Prompt("test"),
|
Prompt("test"),
|
||||||
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||||
@@ -165,7 +164,6 @@ class TestCircuitBreaker:
|
|||||||
circuit_breaker_threshold=3,
|
circuit_breaker_threshold=3,
|
||||||
error_strategy="abort",
|
error_strategy="abort",
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
|
||||||
with pytest.raises(RuntimeError, match="LLM down"):
|
with pytest.raises(RuntimeError, match="LLM down"):
|
||||||
await loop.run(
|
await loop.run(
|
||||||
Prompt("test"),
|
Prompt("test"),
|
||||||
@@ -216,7 +214,6 @@ class TestCircuitBreaker:
|
|||||||
circuit_breaker_threshold=3,
|
circuit_breaker_threshold=3,
|
||||||
error_strategy="skip",
|
error_strategy="skip",
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
|
||||||
state = await loop.run(
|
state = await loop.run(
|
||||||
Prompt("test"),
|
Prompt("test"),
|
||||||
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||||
@@ -277,6 +274,10 @@ class TestPerCallIsolation:
|
|||||||
adapter._max_retries = 1
|
adapter._max_retries = 1
|
||||||
adapter._retry_delay_base = 0
|
adapter._retry_delay_base = 0
|
||||||
adapter._semaphore = __import__("asyncio").Semaphore(5)
|
adapter._semaphore = __import__("asyncio").Semaphore(5)
|
||||||
|
adapter._judge_criteria = ""
|
||||||
|
adapter._judge_dimensions = []
|
||||||
|
adapter._dimension_names = ""
|
||||||
|
adapter._weights = {}
|
||||||
|
|
||||||
# Mock _judge to fail on first call, succeed on second
|
# Mock _judge to fail on first call, succeed on second
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user