feat: Pydantic config validation with clear CLI error messages
Convert OptimizationConfig from dataclass to Pydantic BaseModel with field validators for ranges, types, and enum values. Missing/invalid fields now produce actionable CLI errors instead of cryptic KeyErrors. - Range validators: max_iterations>=1, minibatch_size>=1, seed>=0, etc. - Enum validator: error_strategy must be skip|retry|abort - Config migration hook via config_version field - CLI catches ValidationError and prints per-field error messages - Remove unused AppSettings class (Bug #7) - 30 unit tests covering all validation edge cases Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -12,6 +12,7 @@ from dataclasses import asdict
|
||||
|
||||
import dspy
|
||||
import typer
|
||||
from pydantic import ValidationError
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
@@ -102,75 +103,58 @@ async def _async_optimize(
|
||||
)
|
||||
)
|
||||
|
||||
# 1. Load config
|
||||
# 1. Load & validate config
|
||||
persistence = YamlPersistence()
|
||||
raw_config = persistence.read_config(input)
|
||||
|
||||
# CLI flags override config file values
|
||||
raw_config.setdefault("max_retries", max_retries)
|
||||
raw_config.setdefault("error_strategy", error_strategy)
|
||||
raw_config.setdefault("max_concurrency", max_concurrency)
|
||||
raw_config["output_path"] = output
|
||||
raw_config["verbose"] = verbose
|
||||
|
||||
try:
|
||||
config = OptimizationConfig.model_validate(raw_config)
|
||||
except ValidationError as exc:
|
||||
console.print("[bold red]Configuration error:[/bold red]\n")
|
||||
for err in exc.errors():
|
||||
loc = " → ".join(str(l) for l in err["loc"])
|
||||
console.print(f" [red]• {loc}: {err['msg']}[/red]")
|
||||
raise typer.Exit(code=1) from exc
|
||||
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
|
||||
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
|
||||
|
||||
# 2. Create per-model DSPy LM instances
|
||||
def _model_lm_kwargs(
|
||||
model_api_base: str | None,
|
||||
model_api_key_env: str | None,
|
||||
global_api_base: str | None,
|
||||
global_api_key_env: str | None,
|
||||
) -> dict:
|
||||
"""Build kwargs for dspy.LM, using per-model overrides with global fallback."""
|
||||
kwargs: dict = {}
|
||||
api_base = model_api_base or global_api_base
|
||||
api_key_env = model_api_key_env or global_api_key_env
|
||||
api_base = model_api_base or config.api_base
|
||||
api_key_env = model_api_key_env or config.api_key_env
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
if api_key_env:
|
||||
kwargs["api_key"] = os.environ.get(api_key_env, "")
|
||||
return kwargs
|
||||
|
||||
global_api_base = raw_config.get("api_base")
|
||||
global_api_key_env = raw_config.get("api_key_env")
|
||||
|
||||
config = OptimizationConfig(
|
||||
seed_prompt=raw_config["seed_prompt"],
|
||||
task_description=raw_config["task_description"],
|
||||
task_model=raw_config.get("task_model", "openai/gpt-4o-mini"),
|
||||
judge_model=raw_config.get("judge_model", "openai/gpt-4o"),
|
||||
proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"),
|
||||
synth_model=raw_config.get("synth_model", "openai/gpt-4o"),
|
||||
task_api_base=raw_config.get("task_api_base"),
|
||||
task_api_key_env=raw_config.get("task_api_key_env"),
|
||||
judge_api_base=raw_config.get("judge_api_base"),
|
||||
judge_api_key_env=raw_config.get("judge_api_key_env"),
|
||||
proposer_api_base=raw_config.get("proposer_api_base"),
|
||||
proposer_api_key_env=raw_config.get("proposer_api_key_env"),
|
||||
synth_api_base=raw_config.get("synth_api_base"),
|
||||
synth_api_key_env=raw_config.get("synth_api_key_env"),
|
||||
max_iterations=raw_config.get("max_iterations", 30),
|
||||
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
|
||||
minibatch_size=raw_config.get("minibatch_size", 5),
|
||||
seed=raw_config.get("seed", 42),
|
||||
max_retries=raw_config.get("max_retries", max_retries),
|
||||
retry_delay_base=raw_config.get("retry_delay_base", 1.0),
|
||||
circuit_breaker_threshold=raw_config.get("circuit_breaker_threshold", 5),
|
||||
error_strategy=raw_config.get("error_strategy", error_strategy),
|
||||
max_concurrency=raw_config.get("max_concurrency", max_concurrency),
|
||||
output_path=output,
|
||||
verbose=verbose,
|
||||
)
|
||||
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
|
||||
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
|
||||
|
||||
# 2. Create per-model DSPy LM instances
|
||||
task_lm = dspy.LM(
|
||||
config.task_model,
|
||||
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env, global_api_base, global_api_key_env),
|
||||
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env),
|
||||
)
|
||||
judge_lm = dspy.LM(
|
||||
config.judge_model,
|
||||
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env, global_api_base, global_api_key_env),
|
||||
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env),
|
||||
)
|
||||
proposer_lm = dspy.LM(
|
||||
config.proposer_model,
|
||||
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env, global_api_base, global_api_key_env),
|
||||
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env),
|
||||
)
|
||||
synth_lm = dspy.LM(
|
||||
config.synth_model,
|
||||
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env),
|
||||
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env),
|
||||
)
|
||||
|
||||
# 3. Build adapters (Dependency Injection — each gets its own LM + retry config)
|
||||
|
||||
Reference in New Issue
Block a user