diff --git a/src/prometheus/application/dto.py b/src/prometheus/application/dto.py index 3292d8d..931a095 100644 --- a/src/prometheus/application/dto.py +++ b/src/prometheus/application/dto.py @@ -4,20 +4,48 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Any +from pydantic import BaseModel, Field, field_validator, model_validator -@dataclass -class OptimizationConfig: - """Complete configuration for a PROMETHEUS run.""" - # --- Prompt --- - seed_prompt: str - task_description: str +# Current config schema version. +CONFIG_VERSION = 1 + +_ERROR_STRATEGY_VALUES = {"skip", "retry", "abort"} + + +class OptimizationConfig(BaseModel): + """Complete configuration for a PROMETHEUS run. + + Validated with Pydantic so missing or wrong-type fields produce clear, + actionable error messages at the CLI boundary instead of cryptic failures + deep in the pipeline. + """ + + model_config = {"extra": "forbid"} + + # --- Schema version (for migration support) --- + config_version: int = Field( + default=CONFIG_VERSION, + description="Config schema version. Set automatically; used for migration.", + ) + + # --- Prompt (required) --- + seed_prompt: str = Field( + ..., + min_length=1, + description="The initial prompt to optimize.", + ) + task_description: str = Field( + ..., + min_length=1, + description="Description of the task the prompt should accomplish.", + ) # --- Models --- - task_model: str = "openai/gpt-4o-mini" - judge_model: str = "openai/gpt-4o" - proposer_model: str = "openai/gpt-4o" - synth_model: str = "openai/gpt-4o" + task_model: str = Field(default="openai/gpt-4o-mini", min_length=1) + judge_model: str = Field(default="openai/gpt-4o", min_length=1) + proposer_model: str = Field(default="openai/gpt-4o", min_length=1) + synth_model: str = Field(default="openai/gpt-4o", min_length=1) # --- Per-model API overrides (optional, fall back to global api_base/api_key_env) --- task_api_base: str | None = None @@ -29,28 +57,54 @@ class OptimizationConfig: synth_api_base: str | None = None synth_api_key_env: str | None = None + # --- Global API settings (optional) --- + api_base: str | None = None + api_key_env: str | None = None + # --- Evolution parameters --- - max_iterations: int = 30 - n_synthetic_inputs: int = 20 - minibatch_size: int = 5 - perfect_score: float = 1.0 + max_iterations: int = Field(default=30, ge=1, description="Maximum evolution iterations.") + n_synthetic_inputs: int = Field(default=20, ge=1, description="Number of synthetic inputs to generate.") + minibatch_size: int = Field(default=5, ge=1, description="Inputs per evaluation minibatch.") + perfect_score: float = Field(default=1.0, ge=0.0, le=1.0) # --- Reproducibility --- - seed: int = 42 + seed: int = Field(default=42, ge=0) # --- Concurrency --- - max_concurrency: int = 5 + max_concurrency: int = Field(default=5, ge=1, description="Max parallel LLM calls.") # --- Error handling --- - max_retries: int = 3 - retry_delay_base: float = 1.0 - circuit_breaker_threshold: int = 5 - error_strategy: str = "retry" # skip | retry | abort + max_retries: int = Field(default=3, ge=0, description="Max retry attempts for transient errors.") + retry_delay_base: float = Field(default=1.0, gt=0, description="Base delay in seconds for retry backoff.") + circuit_breaker_threshold: int = Field(default=5, ge=1, description="Consecutive failures before circuit opens.") + error_strategy: str = Field(default="retry", description="Error handling strategy: skip | retry | abort.") # --- Output --- - output_path: str = "output.yaml" + output_path: str = Field(default="output.yaml", min_length=1) verbose: bool = False + @field_validator("error_strategy") + @classmethod + def _validate_error_strategy(cls, v: str) -> str: + if v not in _ERROR_STRATEGY_VALUES: + raise ValueError( + f"error_strategy must be one of {sorted(_ERROR_STRATEGY_VALUES)}, got '{v}'" + ) + return v + + @model_validator(mode="before") + @classmethod + def _migrate_config(cls, data: Any) -> Any: + """Apply migration transforms for older config versions.""" + if isinstance(data, dict): + version = data.get("config_version", CONFIG_VERSION) + # Future migrations go here, e.g.: + # if version < 2: + # data = _migrate_v1_to_v2(data) + # Always stamp current version. + data["config_version"] = CONFIG_VERSION + return data + @dataclass class OptimizationResult: diff --git a/src/prometheus/cli/app.py b/src/prometheus/cli/app.py index 6da3017..776f70e 100644 --- a/src/prometheus/cli/app.py +++ b/src/prometheus/cli/app.py @@ -12,6 +12,7 @@ from dataclasses import asdict import dspy import typer +from pydantic import ValidationError from rich.console import Console from rich.panel import Panel from rich.table import Table @@ -102,75 +103,58 @@ async def _async_optimize( ) ) - # 1. Load config + # 1. Load & validate config persistence = YamlPersistence() raw_config = persistence.read_config(input) + # CLI flags override config file values + raw_config.setdefault("max_retries", max_retries) + raw_config.setdefault("error_strategy", error_strategy) + raw_config.setdefault("max_concurrency", max_concurrency) + raw_config["output_path"] = output + raw_config["verbose"] = verbose + + try: + config = OptimizationConfig.model_validate(raw_config) + except ValidationError as exc: + console.print("[bold red]Configuration error:[/bold red]\n") + for err in exc.errors(): + loc = " → ".join(str(l) for l in err["loc"]) + console.print(f" [red]• {loc}: {err['msg']}[/red]") + raise typer.Exit(code=1) from exc + console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]") + console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]") + + # 2. Create per-model DSPy LM instances def _model_lm_kwargs( model_api_base: str | None, model_api_key_env: str | None, - global_api_base: str | None, - global_api_key_env: str | None, ) -> dict: """Build kwargs for dspy.LM, using per-model overrides with global fallback.""" kwargs: dict = {} - api_base = model_api_base or global_api_base - api_key_env = model_api_key_env or global_api_key_env + api_base = model_api_base or config.api_base + api_key_env = model_api_key_env or config.api_key_env if api_base: kwargs["api_base"] = api_base if api_key_env: kwargs["api_key"] = os.environ.get(api_key_env, "") return kwargs - global_api_base = raw_config.get("api_base") - global_api_key_env = raw_config.get("api_key_env") - - config = OptimizationConfig( - seed_prompt=raw_config["seed_prompt"], - task_description=raw_config["task_description"], - task_model=raw_config.get("task_model", "openai/gpt-4o-mini"), - judge_model=raw_config.get("judge_model", "openai/gpt-4o"), - proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"), - synth_model=raw_config.get("synth_model", "openai/gpt-4o"), - task_api_base=raw_config.get("task_api_base"), - task_api_key_env=raw_config.get("task_api_key_env"), - judge_api_base=raw_config.get("judge_api_base"), - judge_api_key_env=raw_config.get("judge_api_key_env"), - proposer_api_base=raw_config.get("proposer_api_base"), - proposer_api_key_env=raw_config.get("proposer_api_key_env"), - synth_api_base=raw_config.get("synth_api_base"), - synth_api_key_env=raw_config.get("synth_api_key_env"), - max_iterations=raw_config.get("max_iterations", 30), - n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20), - minibatch_size=raw_config.get("minibatch_size", 5), - seed=raw_config.get("seed", 42), - max_retries=raw_config.get("max_retries", max_retries), - retry_delay_base=raw_config.get("retry_delay_base", 1.0), - circuit_breaker_threshold=raw_config.get("circuit_breaker_threshold", 5), - error_strategy=raw_config.get("error_strategy", error_strategy), - max_concurrency=raw_config.get("max_concurrency", max_concurrency), - output_path=output, - verbose=verbose, - ) - console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]") - console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]") - - # 2. Create per-model DSPy LM instances task_lm = dspy.LM( config.task_model, - **_model_lm_kwargs(config.task_api_base, config.task_api_key_env, global_api_base, global_api_key_env), + **_model_lm_kwargs(config.task_api_base, config.task_api_key_env), ) judge_lm = dspy.LM( config.judge_model, - **_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env, global_api_base, global_api_key_env), + **_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env), ) proposer_lm = dspy.LM( config.proposer_model, - **_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env, global_api_base, global_api_key_env), + **_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env), ) synth_lm = dspy.LM( config.synth_model, - **_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env), + **_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env), ) # 3. Build adapters (Dependency Injection — each gets its own LM + retry config) diff --git a/src/prometheus/config.py b/src/prometheus/config.py index 744fbcc..783e687 100644 --- a/src/prometheus/config.py +++ b/src/prometheus/config.py @@ -1,12 +1,2 @@ """Application settings.""" from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass -class AppSettings: - """Non-sensitive settings, hardcoded for the MVP.""" - - app_name: str = "prometheus" - version: str = "0.1.0" diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..bdf138b --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,302 @@ +"""Unit tests for config and config loading via CLI. + +Tests config validation scenarios: missing fields, wrong types, defaults. +""" +from __future__ import annotations + +from pathlib import Path + +import pytest +import yaml +from pydantic import ValidationError + +from prometheus.application.dto import OptimizationConfig +from prometheus.infrastructure.file_io import YamlPersistence + + +class TestOptimizationConfig: + """Tests for OptimizationConfig Pydantic model defaults.""" + + def test_default_values(self) -> None: + config = OptimizationConfig( + seed_prompt="test prompt", + task_description="test task", + ) + assert config.task_model == "openai/gpt-4o-mini" + assert config.judge_model == "openai/gpt-4o" + assert config.proposer_model == "openai/gpt-4o" + assert config.synth_model == "openai/gpt-4o" + assert config.max_iterations == 30 + assert config.n_synthetic_inputs == 20 + assert config.minibatch_size == 5 + assert config.perfect_score == 1.0 + assert config.seed == 42 + assert config.output_path == "output.yaml" + assert config.verbose is False + assert config.error_strategy == "retry" + assert config.max_retries == 3 + assert config.max_concurrency == 5 + + def test_custom_values(self) -> None: + config = OptimizationConfig( + seed_prompt="custom prompt", + task_description="custom task", + max_iterations=100, + minibatch_size=10, + seed=123, + verbose=True, + ) + assert config.max_iterations == 100 + assert config.minibatch_size == 10 + assert config.seed == 123 + assert config.verbose is True + + def test_roundtrip_to_dict(self) -> None: + config = OptimizationConfig( + seed_prompt="test", + task_description="task", + ) + d = config.model_dump() + assert d["seed_prompt"] == "test" + assert d["task_description"] == "task" + assert "history" not in d # OptimizationResult has history, not config + + def test_config_version_defaults(self) -> None: + config = OptimizationConfig(seed_prompt="a", task_description="b") + assert config.config_version == 1 + + def test_config_version_stamps_current(self) -> None: + config = OptimizationConfig( + seed_prompt="a", task_description="b", config_version=0, + ) + # Migration always stamps current version + assert config.config_version == 1 + + +class TestConfigLoading: + """Tests for loading OptimizationConfig from YAML via YamlPersistence.""" + + def test_minimal_config_loads(self, tmp_path: Path) -> None: + persistence = YamlPersistence() + data = { + "seed_prompt": "You are helpful.", + "task_description": "Answer questions.", + } + config_file = tmp_path / "config.yaml" + with open(config_file, "w") as f: + yaml.dump(data, f) + + raw = persistence.read_config(str(config_file)) + config = OptimizationConfig.model_validate(raw) + + assert config.seed_prompt == "You are helpful." + assert config.task_description == "Answer questions." + assert config.max_iterations == 30 # default + + def test_full_config_loads(self, tmp_path: Path) -> None: + persistence = YamlPersistence() + data = { + "seed_prompt": "You are helpful.", + "task_description": "Answer questions.", + "task_model": "openai/gpt-4o", + "judge_model": "openai/gpt-4o-mini", + "max_iterations": 50, + "n_synthetic_inputs": 30, + "minibatch_size": 8, + "seed": 99, + "verbose": True, + } + config_file = tmp_path / "config.yaml" + with open(config_file, "w") as f: + yaml.dump(data, f) + + raw = persistence.read_config(str(config_file)) + config = OptimizationConfig.model_validate(raw) + + assert config.task_model == "openai/gpt-4o" + assert config.max_iterations == 50 + assert config.verbose is True + + def test_missing_seed_prompt_raises_validation_error(self, tmp_path: Path) -> None: + persistence = YamlPersistence() + data = {"task_description": "Answer questions."} + config_file = tmp_path / "config.yaml" + with open(config_file, "w") as f: + yaml.dump(data, f) + + raw = persistence.read_config(str(config_file)) + + with pytest.raises(ValidationError, match="seed_prompt"): + OptimizationConfig.model_validate(raw) + + def test_missing_task_description_raises_validation_error(self, tmp_path: Path) -> None: + persistence = YamlPersistence() + data = {"seed_prompt": "You are helpful."} + config_file = tmp_path / "config.yaml" + with open(config_file, "w") as f: + yaml.dump(data, f) + + raw = persistence.read_config(str(config_file)) + + with pytest.raises(ValidationError, match="task_description"): + OptimizationConfig.model_validate(raw) + + def test_empty_yaml_raises_on_required_fields(self, tmp_path: Path) -> None: + persistence = YamlPersistence() + config_file = tmp_path / "empty.yaml" + config_file.write_text("{}", encoding="utf-8") + + raw = persistence.read_config(str(config_file)) + + with pytest.raises(ValidationError): + OptimizationConfig.model_validate(raw) + + def test_partial_config_uses_defaults(self, tmp_path: Path) -> None: + persistence = YamlPersistence() + data = { + "seed_prompt": "test", + "task_description": "task", + "max_iterations": 10, + } + config_file = tmp_path / "partial.yaml" + with open(config_file, "w") as f: + yaml.dump(data, f) + + raw = persistence.read_config(str(config_file)) + config = OptimizationConfig.model_validate(raw) + + assert config.max_iterations == 10 + assert config.n_synthetic_inputs == 20 # default + assert config.minibatch_size == 5 # default + + +class TestConfigValidation: + """Tests for Pydantic validation edge cases.""" + + def test_wrong_type_max_iterations(self) -> None: + with pytest.raises(ValidationError, match="max_iterations"): + OptimizationConfig( + seed_prompt="a", + task_description="b", + max_iterations="not_a_number", # type: ignore[arg-type] + ) + + def test_wrong_type_seed(self) -> None: + with pytest.raises(ValidationError, match="seed"): + OptimizationConfig( + seed_prompt="a", + task_description="b", + seed="not_a_number", # type: ignore[arg-type] + ) + + def test_negative_max_iterations(self) -> None: + with pytest.raises(ValidationError, match="greater than or equal to 1"): + OptimizationConfig( + seed_prompt="a", task_description="b", max_iterations=0, + ) + + def test_negative_seed(self) -> None: + with pytest.raises(ValidationError, match="greater than or equal to 0"): + OptimizationConfig( + seed_prompt="a", task_description="b", seed=-1, + ) + + def test_zero_minibatch_size(self) -> None: + with pytest.raises(ValidationError, match="greater than or equal to 1"): + OptimizationConfig( + seed_prompt="a", task_description="b", minibatch_size=0, + ) + + def test_zero_max_concurrency(self) -> None: + with pytest.raises(ValidationError, match="greater than or equal to 1"): + OptimizationConfig( + seed_prompt="a", task_description="b", max_concurrency=0, + ) + + def test_negative_max_retries(self) -> None: + with pytest.raises(ValidationError, match="greater than or equal to 0"): + OptimizationConfig( + seed_prompt="a", task_description="b", max_retries=-1, + ) + + def test_zero_retry_delay_base(self) -> None: + with pytest.raises(ValidationError, match="greater than 0"): + OptimizationConfig( + seed_prompt="a", task_description="b", retry_delay_base=0, + ) + + def test_negative_retry_delay_base(self) -> None: + with pytest.raises(ValidationError, match="greater than 0"): + OptimizationConfig( + seed_prompt="a", task_description="b", retry_delay_base=-0.5, + ) + + def test_zero_circuit_breaker_threshold(self) -> None: + with pytest.raises(ValidationError, match="greater than or equal to 1"): + OptimizationConfig( + seed_prompt="a", task_description="b", circuit_breaker_threshold=0, + ) + + def test_perfect_score_above_one(self) -> None: + with pytest.raises(ValidationError, match="less than or equal to 1"): + OptimizationConfig( + seed_prompt="a", task_description="b", perfect_score=1.5, + ) + + def test_perfect_score_negative(self) -> None: + with pytest.raises(ValidationError, match="greater than or equal to 0"): + OptimizationConfig( + seed_prompt="a", task_description="b", perfect_score=-0.1, + ) + + def test_invalid_error_strategy(self) -> None: + with pytest.raises(ValidationError, match="error_strategy must be one of"): + OptimizationConfig( + seed_prompt="a", task_description="b", error_strategy="invalid", + ) + + def test_valid_error_strategies(self) -> None: + for strategy in ("skip", "retry", "abort"): + config = OptimizationConfig( + seed_prompt="a", task_description="b", error_strategy=strategy, + ) + assert config.error_strategy == strategy + + def test_empty_seed_prompt(self) -> None: + with pytest.raises(ValidationError, match="seed_prompt"): + OptimizationConfig(seed_prompt="", task_description="b") + + def test_empty_task_description(self) -> None: + with pytest.raises(ValidationError, match="task_description"): + OptimizationConfig(seed_prompt="a", task_description="") + + def test_empty_model_string(self) -> None: + with pytest.raises(ValidationError, match="task_model"): + OptimizationConfig( + seed_prompt="a", task_description="b", task_model="", + ) + + def test_extra_fields_rejected(self) -> None: + with pytest.raises(ValidationError, match="extra"): + OptimizationConfig( + seed_prompt="a", + task_description="b", + nonexistent_field="value", # type: ignore[call-arg] + ) + + def test_boundary_values_accepted(self) -> None: + config = OptimizationConfig( + seed_prompt="a", + task_description="b", + max_iterations=1, + n_synthetic_inputs=1, + minibatch_size=1, + max_concurrency=1, + max_retries=0, + retry_delay_base=0.001, + circuit_breaker_threshold=1, + perfect_score=0.0, + seed=0, + ) + assert config.max_iterations == 1 + assert config.perfect_score == 0.0