feat: Pydantic config validation with clear CLI error messages

Convert OptimizationConfig from dataclass to Pydantic BaseModel with
field validators for ranges, types, and enum values. Missing/invalid
fields now produce actionable CLI errors instead of cryptic KeyErrors.

- Range validators: max_iterations>=1, minibatch_size>=1, seed>=0, etc.
- Enum validator: error_strategy must be skip|retry|abort
- Config migration hook via config_version field
- CLI catches ValidationError and prints per-field error messages
- Remove unused AppSettings class (Bug #7)
- 30 unit tests covering all validation edge cases

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
FullStackDev
2026-03-29 13:25:44 +00:00
parent c92ca4a2b8
commit 336774a164
4 changed files with 404 additions and 74 deletions

View File

@@ -4,20 +4,48 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from pydantic import BaseModel, Field, field_validator, model_validator
@dataclass
class OptimizationConfig:
"""Complete configuration for a PROMETHEUS run."""
# --- Prompt --- # Current config schema version.
seed_prompt: str CONFIG_VERSION = 1
task_description: str
_ERROR_STRATEGY_VALUES = {"skip", "retry", "abort"}
class OptimizationConfig(BaseModel):
"""Complete configuration for a PROMETHEUS run.
Validated with Pydantic so missing or wrong-type fields produce clear,
actionable error messages at the CLI boundary instead of cryptic failures
deep in the pipeline.
"""
model_config = {"extra": "forbid"}
# --- Schema version (for migration support) ---
config_version: int = Field(
default=CONFIG_VERSION,
description="Config schema version. Set automatically; used for migration.",
)
# --- Prompt (required) ---
seed_prompt: str = Field(
...,
min_length=1,
description="The initial prompt to optimize.",
)
task_description: str = Field(
...,
min_length=1,
description="Description of the task the prompt should accomplish.",
)
# --- Models --- # --- Models ---
task_model: str = "openai/gpt-4o-mini" task_model: str = Field(default="openai/gpt-4o-mini", min_length=1)
judge_model: str = "openai/gpt-4o" judge_model: str = Field(default="openai/gpt-4o", min_length=1)
proposer_model: str = "openai/gpt-4o" proposer_model: str = Field(default="openai/gpt-4o", min_length=1)
synth_model: str = "openai/gpt-4o" synth_model: str = Field(default="openai/gpt-4o", min_length=1)
# --- Per-model API overrides (optional, fall back to global api_base/api_key_env) --- # --- Per-model API overrides (optional, fall back to global api_base/api_key_env) ---
task_api_base: str | None = None task_api_base: str | None = None
@@ -29,28 +57,54 @@ class OptimizationConfig:
synth_api_base: str | None = None synth_api_base: str | None = None
synth_api_key_env: str | None = None synth_api_key_env: str | None = None
# --- Global API settings (optional) ---
api_base: str | None = None
api_key_env: str | None = None
# --- Evolution parameters --- # --- Evolution parameters ---
max_iterations: int = 30 max_iterations: int = Field(default=30, ge=1, description="Maximum evolution iterations.")
n_synthetic_inputs: int = 20 n_synthetic_inputs: int = Field(default=20, ge=1, description="Number of synthetic inputs to generate.")
minibatch_size: int = 5 minibatch_size: int = Field(default=5, ge=1, description="Inputs per evaluation minibatch.")
perfect_score: float = 1.0 perfect_score: float = Field(default=1.0, ge=0.0, le=1.0)
# --- Reproducibility --- # --- Reproducibility ---
seed: int = 42 seed: int = Field(default=42, ge=0)
# --- Concurrency --- # --- Concurrency ---
max_concurrency: int = 5 max_concurrency: int = Field(default=5, ge=1, description="Max parallel LLM calls.")
# --- Error handling --- # --- Error handling ---
max_retries: int = 3 max_retries: int = Field(default=3, ge=0, description="Max retry attempts for transient errors.")
retry_delay_base: float = 1.0 retry_delay_base: float = Field(default=1.0, gt=0, description="Base delay in seconds for retry backoff.")
circuit_breaker_threshold: int = 5 circuit_breaker_threshold: int = Field(default=5, ge=1, description="Consecutive failures before circuit opens.")
error_strategy: str = "retry" # skip | retry | abort error_strategy: str = Field(default="retry", description="Error handling strategy: skip | retry | abort.")
# --- Output --- # --- Output ---
output_path: str = "output.yaml" output_path: str = Field(default="output.yaml", min_length=1)
verbose: bool = False verbose: bool = False
@field_validator("error_strategy")
@classmethod
def _validate_error_strategy(cls, v: str) -> str:
    """Accept ``v`` only when it is a member of ``_ERROR_STRATEGY_VALUES``.

    Returns the value unchanged on success. Otherwise raises ``ValueError``
    with a message listing the allowed values (sorted) and the rejected one.
    """
    if v in _ERROR_STRATEGY_VALUES:
        return v
    allowed = sorted(_ERROR_STRATEGY_VALUES)
    raise ValueError(f"error_strategy must be one of {allowed}, got '{v}'")
@model_validator(mode="before")
@classmethod
def _migrate_config(cls, data: Any) -> Any:
    """Apply migration transforms for older config versions.

    Runs before field validation. Input that is not a dict is passed through
    untouched. Dict input is shallow-copied, so the mapping handed in by the
    caller is never modified, and the copy is stamped with the current
    ``CONFIG_VERSION``.
    """
    if not isinstance(data, dict):
        return data

    # Work on a copy: stamping the version in place would silently rewrite
    # the dict the caller passed in.
    migrated = dict(data)

    # Future migrations go here, keyed on the version found in the input, e.g.:
    #   version = migrated.get("config_version", CONFIG_VERSION)
    #   if version < 2:
    #       migrated = _migrate_v1_to_v2(migrated)

    # Always stamp current version.
    migrated["config_version"] = CONFIG_VERSION
    return migrated
@dataclass @dataclass
class OptimizationResult: class OptimizationResult:

View File

@@ -12,6 +12,7 @@ from dataclasses import asdict
import dspy import dspy
import typer import typer
from pydantic import ValidationError
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
from rich.table import Table from rich.table import Table
@@ -102,75 +103,58 @@ async def _async_optimize(
) )
) )
# 1. Load config # 1. Load & validate config
persistence = YamlPersistence() persistence = YamlPersistence()
raw_config = persistence.read_config(input) raw_config = persistence.read_config(input)
# Merge CLI values into the raw config.
# - max_retries / error_strategy / max_concurrency: a value present in the
#   config file wins; setdefault only fills keys the file left out.
# - output_path / verbose: always taken from the CLI.
raw_config.setdefault("max_retries", max_retries)
raw_config.setdefault("error_strategy", error_strategy)
raw_config.setdefault("max_concurrency", max_concurrency)
raw_config["output_path"] = output
raw_config["verbose"] = verbose

try:
    config = OptimizationConfig.model_validate(raw_config)
except ValidationError as exc:
    console.print("[bold red]Configuration error:[/bold red]\n")
    for err in exc.errors():
        # err["loc"] is a tuple of path parts; join with "." so multi-part
        # locations stay readable instead of running together. Fall back to
        # a fixed label when the tuple is empty.
        loc = ".".join(str(part) for part in err["loc"]) or "<config>"
        console.print(f" [red]• {loc}: {err['msg']}[/red]")
    raise typer.Exit(code=1) from exc
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
# 2. Create per-model DSPy LM instances
def _model_lm_kwargs( def _model_lm_kwargs(
model_api_base: str | None, model_api_base: str | None,
model_api_key_env: str | None, model_api_key_env: str | None,
global_api_base: str | None,
global_api_key_env: str | None,
) -> dict: ) -> dict:
"""Build kwargs for dspy.LM, using per-model overrides with global fallback.""" """Build kwargs for dspy.LM, using per-model overrides with global fallback."""
kwargs: dict = {} kwargs: dict = {}
api_base = model_api_base or global_api_base api_base = model_api_base or config.api_base
api_key_env = model_api_key_env or global_api_key_env api_key_env = model_api_key_env or config.api_key_env
if api_base: if api_base:
kwargs["api_base"] = api_base kwargs["api_base"] = api_base
if api_key_env: if api_key_env:
kwargs["api_key"] = os.environ.get(api_key_env, "") kwargs["api_key"] = os.environ.get(api_key_env, "")
return kwargs return kwargs
global_api_base = raw_config.get("api_base")
global_api_key_env = raw_config.get("api_key_env")
config = OptimizationConfig(
seed_prompt=raw_config["seed_prompt"],
task_description=raw_config["task_description"],
task_model=raw_config.get("task_model", "openai/gpt-4o-mini"),
judge_model=raw_config.get("judge_model", "openai/gpt-4o"),
proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"),
synth_model=raw_config.get("synth_model", "openai/gpt-4o"),
task_api_base=raw_config.get("task_api_base"),
task_api_key_env=raw_config.get("task_api_key_env"),
judge_api_base=raw_config.get("judge_api_base"),
judge_api_key_env=raw_config.get("judge_api_key_env"),
proposer_api_base=raw_config.get("proposer_api_base"),
proposer_api_key_env=raw_config.get("proposer_api_key_env"),
synth_api_base=raw_config.get("synth_api_base"),
synth_api_key_env=raw_config.get("synth_api_key_env"),
max_iterations=raw_config.get("max_iterations", 30),
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
minibatch_size=raw_config.get("minibatch_size", 5),
seed=raw_config.get("seed", 42),
max_retries=raw_config.get("max_retries", max_retries),
retry_delay_base=raw_config.get("retry_delay_base", 1.0),
circuit_breaker_threshold=raw_config.get("circuit_breaker_threshold", 5),
error_strategy=raw_config.get("error_strategy", error_strategy),
max_concurrency=raw_config.get("max_concurrency", max_concurrency),
output_path=output,
verbose=verbose,
)
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
# 2. Create per-model DSPy LM instances
task_lm = dspy.LM( task_lm = dspy.LM(
config.task_model, config.task_model,
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env, global_api_base, global_api_key_env), **_model_lm_kwargs(config.task_api_base, config.task_api_key_env),
) )
judge_lm = dspy.LM( judge_lm = dspy.LM(
config.judge_model, config.judge_model,
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env, global_api_base, global_api_key_env), **_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env),
) )
proposer_lm = dspy.LM( proposer_lm = dspy.LM(
config.proposer_model, config.proposer_model,
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env, global_api_base, global_api_key_env), **_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env),
) )
synth_lm = dspy.LM( synth_lm = dspy.LM(
config.synth_model, config.synth_model,
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env), **_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env),
) )
# 3. Build adapters (Dependency Injection — each gets its own LM + retry config) # 3. Build adapters (Dependency Injection — each gets its own LM + retry config)

View File

@@ -1,12 +1,2 @@
"""Application settings.""" """Application settings."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
@dataclass
class AppSettings:
"""Non-sensitive settings, hardcoded for the MVP."""
app_name: str = "prometheus"
version: str = "0.1.0"

302
tests/unit/test_config.py Normal file
View File

@@ -0,0 +1,302 @@
"""Unit tests for config and config loading via CLI.
Tests config validation scenarios: missing fields, wrong types, defaults.
"""
from __future__ import annotations
from pathlib import Path
import pytest
import yaml
from pydantic import ValidationError
from prometheus.application.dto import OptimizationConfig
from prometheus.infrastructure.file_io import YamlPersistence
class TestOptimizationConfig:
    """Defaults, overrides, and dict serialisation of OptimizationConfig."""

    def test_default_values(self) -> None:
        """Supplying only the required fields leaves every other field at its default."""
        cfg = OptimizationConfig(
            seed_prompt="test prompt",
            task_description="test task",
        )
        expected_defaults = {
            "task_model": "openai/gpt-4o-mini",
            "judge_model": "openai/gpt-4o",
            "proposer_model": "openai/gpt-4o",
            "synth_model": "openai/gpt-4o",
            "max_iterations": 30,
            "n_synthetic_inputs": 20,
            "minibatch_size": 5,
            "perfect_score": 1.0,
            "seed": 42,
            "output_path": "output.yaml",
            "error_strategy": "retry",
            "max_retries": 3,
            "max_concurrency": 5,
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(cfg, field_name) == expected, field_name
        # Identity check kept separate: verbose must be the bool False itself.
        assert cfg.verbose is False

    def test_custom_values(self) -> None:
        """Explicitly passed values replace the defaults."""
        cfg = OptimizationConfig(
            seed_prompt="custom prompt",
            task_description="custom task",
            max_iterations=100,
            minibatch_size=10,
            seed=123,
            verbose=True,
        )
        assert (cfg.max_iterations, cfg.minibatch_size, cfg.seed) == (100, 10, 123)
        assert cfg.verbose is True

    def test_roundtrip_to_dict(self) -> None:
        """model_dump() exposes the config fields and nothing named 'history'."""
        dumped = OptimizationConfig(
            seed_prompt="test",
            task_description="task",
        ).model_dump()
        assert dumped["seed_prompt"] == "test"
        assert dumped["task_description"] == "task"
        assert "history" not in dumped  # OptimizationResult has history, not config

    def test_config_version_defaults(self) -> None:
        """config_version is 1 when the caller does not provide it."""
        cfg = OptimizationConfig(seed_prompt="a", task_description="b")
        assert cfg.config_version == 1

    def test_config_version_stamps_current(self) -> None:
        """An older config_version passed in is replaced by the current one."""
        cfg = OptimizationConfig(
            seed_prompt="a",
            task_description="b",
            config_version=0,
        )
        # Migration always stamps current version
        assert cfg.config_version == 1
class TestConfigLoading:
    """Loads OptimizationConfig from YAML files read through YamlPersistence."""

    @staticmethod
    def _read_back(target: Path, payload: dict) -> object:
        """Dump ``payload`` as YAML to ``target`` and return what YamlPersistence reads."""
        reader = YamlPersistence()
        with open(target, "w") as handle:
            yaml.dump(payload, handle)
        return reader.read_config(str(target))

    def test_minimal_config_loads(self, tmp_path: Path) -> None:
        """A file with only the two required keys validates and gets defaults."""
        raw = self._read_back(
            tmp_path / "config.yaml",
            {
                "seed_prompt": "You are helpful.",
                "task_description": "Answer questions.",
            },
        )
        loaded = OptimizationConfig.model_validate(raw)
        assert loaded.seed_prompt == "You are helpful."
        assert loaded.task_description == "Answer questions."
        assert loaded.max_iterations == 30  # default

    def test_full_config_loads(self, tmp_path: Path) -> None:
        """Values written to the file are the values on the validated config."""
        raw = self._read_back(
            tmp_path / "config.yaml",
            {
                "seed_prompt": "You are helpful.",
                "task_description": "Answer questions.",
                "task_model": "openai/gpt-4o",
                "judge_model": "openai/gpt-4o-mini",
                "max_iterations": 50,
                "n_synthetic_inputs": 30,
                "minibatch_size": 8,
                "seed": 99,
                "verbose": True,
            },
        )
        loaded = OptimizationConfig.model_validate(raw)
        assert loaded.task_model == "openai/gpt-4o"
        assert loaded.max_iterations == 50
        assert loaded.verbose is True

    def test_missing_seed_prompt_raises_validation_error(self, tmp_path: Path) -> None:
        """Omitting seed_prompt yields a ValidationError that names the field."""
        raw = self._read_back(
            tmp_path / "config.yaml",
            {"task_description": "Answer questions."},
        )
        with pytest.raises(ValidationError, match="seed_prompt"):
            OptimizationConfig.model_validate(raw)

    def test_missing_task_description_raises_validation_error(self, tmp_path: Path) -> None:
        """Omitting task_description yields a ValidationError that names the field."""
        raw = self._read_back(
            tmp_path / "config.yaml",
            {"seed_prompt": "You are helpful."},
        )
        with pytest.raises(ValidationError, match="task_description"):
            OptimizationConfig.model_validate(raw)

    def test_empty_yaml_raises_on_required_fields(self, tmp_path: Path) -> None:
        """A file containing only an empty mapping fails validation."""
        reader = YamlPersistence()
        empty_file = tmp_path / "empty.yaml"
        # Written verbatim rather than via yaml.dump so the file is exactly "{}".
        empty_file.write_text("{}", encoding="utf-8")
        raw = reader.read_config(str(empty_file))
        with pytest.raises(ValidationError):
            OptimizationConfig.model_validate(raw)

    def test_partial_config_uses_defaults(self, tmp_path: Path) -> None:
        """Keys present in the file are honoured; absent ones fall back to defaults."""
        raw = self._read_back(
            tmp_path / "partial.yaml",
            {
                "seed_prompt": "test",
                "task_description": "task",
                "max_iterations": 10,
            },
        )
        loaded = OptimizationConfig.model_validate(raw)
        assert loaded.max_iterations == 10
        assert loaded.n_synthetic_inputs == 20  # default
        assert loaded.minibatch_size == 5  # default
class TestConfigValidation:
    """Inputs that OptimizationConfig validation must reject or accept."""

    # The two required fields with valid values; each test overrides or adds to them.
    _REQUIRED = {"seed_prompt": "a", "task_description": "b"}

    @classmethod
    def _build(cls, **overrides: object) -> OptimizationConfig:
        """Construct a config from the required fields plus ``overrides``."""
        return OptimizationConfig(**{**cls._REQUIRED, **overrides})

    @classmethod
    def _assert_rejected(cls, pattern: str, **overrides: object) -> None:
        """Assert that building with ``overrides`` raises ValidationError matching ``pattern``."""
        with pytest.raises(ValidationError, match=pattern):
            cls._build(**overrides)

    def test_wrong_type_max_iterations(self) -> None:
        self._assert_rejected("max_iterations", max_iterations="not_a_number")

    def test_wrong_type_seed(self) -> None:
        self._assert_rejected("seed", seed="not_a_number")

    def test_negative_max_iterations(self) -> None:
        # The value used is 0, which this test expects to fail the ">= 1" check.
        self._assert_rejected("greater than or equal to 1", max_iterations=0)

    def test_negative_seed(self) -> None:
        self._assert_rejected("greater than or equal to 0", seed=-1)

    def test_zero_minibatch_size(self) -> None:
        self._assert_rejected("greater than or equal to 1", minibatch_size=0)

    def test_zero_max_concurrency(self) -> None:
        self._assert_rejected("greater than or equal to 1", max_concurrency=0)

    def test_negative_max_retries(self) -> None:
        self._assert_rejected("greater than or equal to 0", max_retries=-1)

    def test_zero_retry_delay_base(self) -> None:
        self._assert_rejected("greater than 0", retry_delay_base=0)

    def test_negative_retry_delay_base(self) -> None:
        self._assert_rejected("greater than 0", retry_delay_base=-0.5)

    def test_zero_circuit_breaker_threshold(self) -> None:
        self._assert_rejected(
            "greater than or equal to 1", circuit_breaker_threshold=0
        )

    def test_perfect_score_above_one(self) -> None:
        self._assert_rejected("less than or equal to 1", perfect_score=1.5)

    def test_perfect_score_negative(self) -> None:
        self._assert_rejected("greater than or equal to 0", perfect_score=-0.1)

    def test_invalid_error_strategy(self) -> None:
        self._assert_rejected(
            "error_strategy must be one of", error_strategy="invalid"
        )

    def test_valid_error_strategies(self) -> None:
        """Each of the three documented strategies is accepted unchanged."""
        for strategy in ("skip", "retry", "abort"):
            accepted = self._build(error_strategy=strategy)
            assert accepted.error_strategy == strategy

    def test_empty_seed_prompt(self) -> None:
        self._assert_rejected("seed_prompt", seed_prompt="")

    def test_empty_task_description(self) -> None:
        self._assert_rejected("task_description", task_description="")

    def test_empty_model_string(self) -> None:
        self._assert_rejected("task_model", task_model="")

    def test_extra_fields_rejected(self) -> None:
        self._assert_rejected("extra", nonexistent_field="value")

    def test_boundary_values_accepted(self) -> None:
        """The smallest values this test passes for each field all validate."""
        edge = self._build(
            max_iterations=1,
            n_synthetic_inputs=1,
            minibatch_size=1,
            max_concurrency=1,
            max_retries=0,
            retry_delay_base=0.001,
            circuit_breaker_threshold=1,
            perfect_score=0.0,
            seed=0,
        )
        assert edge.max_iterations == 1
        assert edge.perfect_score == 0.0