feat: Pydantic config validation with clear CLI error messages
Convert OptimizationConfig from a dataclass to a Pydantic BaseModel with field validators for ranges, types, and enum values. Missing or invalid fields now produce actionable CLI errors instead of cryptic KeyErrors.
- Range validators: max_iterations>=1, minibatch_size>=1, seed>=0, etc.
- Enum validator: error_strategy must be skip|retry|abort
- Config migration hook via config_version field
- CLI catches ValidationError and prints per-field error messages
- Remove unused AppSettings class (Bug #7)
- 30 unit tests covering all validation edge cases
Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
302
tests/unit/test_config.py
Normal file
302
tests/unit/test_config.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""Unit tests for config and config loading via CLI.
|
||||
|
||||
Tests config validation scenarios: missing fields, wrong types, defaults.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
from prometheus.infrastructure.file_io import YamlPersistence
|
||||
|
||||
|
||||
class TestOptimizationConfig:
    """Checks OptimizationConfig defaults, overrides, dumping, and versioning."""

    def test_default_values(self) -> None:
        """Supplying only the two required fields yields the expected defaults."""
        cfg = OptimizationConfig(
            task_description="test task",
            seed_prompt="test prompt",
        )
        expected_defaults = {
            "task_model": "openai/gpt-4o-mini",
            "judge_model": "openai/gpt-4o",
            "proposer_model": "openai/gpt-4o",
            "synth_model": "openai/gpt-4o",
            "max_iterations": 30,
            "n_synthetic_inputs": 20,
            "minibatch_size": 5,
            "perfect_score": 1.0,
            "seed": 42,
            "output_path": "output.yaml",
            "error_strategy": "retry",
            "max_retries": 3,
            "max_concurrency": 5,
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(cfg, field_name) == expected, field_name
        # Identity check: the default must be the bool False, not merely falsy.
        assert cfg.verbose is False

    def test_custom_values(self) -> None:
        """Explicitly passed values replace the defaults."""
        overrides = {
            "max_iterations": 100,
            "minibatch_size": 10,
            "seed": 123,
            "verbose": True,
        }
        cfg = OptimizationConfig(
            seed_prompt="custom prompt",
            task_description="custom task",
            **overrides,
        )
        assert (cfg.max_iterations, cfg.minibatch_size, cfg.seed) == (100, 10, 123)
        assert cfg.verbose is True

    def test_roundtrip_to_dict(self) -> None:
        """model_dump() exposes the config fields as a plain mapping."""
        dumped = OptimizationConfig(
            seed_prompt="test",
            task_description="task",
        ).model_dump()
        assert dumped["seed_prompt"] == "test"
        assert dumped["task_description"] == "task"
        # "history" belongs to OptimizationResult, not to the config.
        assert "history" not in dumped

    def test_config_version_defaults(self) -> None:
        """config_version is 1 when the caller does not provide it."""
        cfg = OptimizationConfig(seed_prompt="a", task_description="b")
        assert cfg.config_version == 1

    def test_config_version_stamps_current(self) -> None:
        """An older config_version passed in is replaced by the current one."""
        cfg = OptimizationConfig(
            seed_prompt="a",
            task_description="b",
            config_version=0,
        )
        # Migration always stamps the current version.
        assert cfg.config_version == 1
|
||||
|
||||
|
||||
class TestConfigLoading:
    """Tests for loading OptimizationConfig from YAML via YamlPersistence."""

    @staticmethod
    def _write_yaml(directory: Path, filename: str, data: dict) -> Path:
        """Dump *data* as YAML into ``directory / filename`` and return the path.

        The file is always written as UTF-8 so the tests do not depend on the
        platform's default text encoding.
        """
        config_file = directory / filename
        with config_file.open("w", encoding="utf-8") as handle:
            yaml.dump(data, handle)
        return config_file

    def test_minimal_config_loads(self, tmp_path: Path) -> None:
        """A file with only the required fields loads and picks up defaults."""
        config_file = self._write_yaml(
            tmp_path,
            "config.yaml",
            {
                "seed_prompt": "You are helpful.",
                "task_description": "Answer questions.",
            },
        )

        raw = YamlPersistence().read_config(str(config_file))
        config = OptimizationConfig.model_validate(raw)

        assert config.seed_prompt == "You are helpful."
        assert config.task_description == "Answer questions."
        assert config.max_iterations == 30  # default

    def test_full_config_loads(self, tmp_path: Path) -> None:
        """Values present in the YAML file override the model defaults."""
        config_file = self._write_yaml(
            tmp_path,
            "config.yaml",
            {
                "seed_prompt": "You are helpful.",
                "task_description": "Answer questions.",
                "task_model": "openai/gpt-4o",
                "judge_model": "openai/gpt-4o-mini",
                "max_iterations": 50,
                "n_synthetic_inputs": 30,
                "minibatch_size": 8,
                "seed": 99,
                "verbose": True,
            },
        )

        raw = YamlPersistence().read_config(str(config_file))
        config = OptimizationConfig.model_validate(raw)

        assert config.task_model == "openai/gpt-4o"
        assert config.max_iterations == 50
        assert config.verbose is True

    def test_missing_seed_prompt_raises_validation_error(self, tmp_path: Path) -> None:
        """Omitting seed_prompt is reported as a ValidationError naming the field."""
        config_file = self._write_yaml(
            tmp_path,
            "config.yaml",
            {"task_description": "Answer questions."},
        )

        raw = YamlPersistence().read_config(str(config_file))

        with pytest.raises(ValidationError, match="seed_prompt"):
            OptimizationConfig.model_validate(raw)

    def test_missing_task_description_raises_validation_error(self, tmp_path: Path) -> None:
        """Omitting task_description is reported as a ValidationError naming the field."""
        config_file = self._write_yaml(
            tmp_path,
            "config.yaml",
            {"seed_prompt": "You are helpful."},
        )

        raw = YamlPersistence().read_config(str(config_file))

        with pytest.raises(ValidationError, match="task_description"):
            OptimizationConfig.model_validate(raw)

    def test_empty_yaml_raises_on_required_fields(self, tmp_path: Path) -> None:
        """An empty mapping fails validation because required fields are absent."""
        config_file = tmp_path / "empty.yaml"
        # Written literally (not via yaml.dump) to pin the exact file content.
        config_file.write_text("{}", encoding="utf-8")

        raw = YamlPersistence().read_config(str(config_file))

        with pytest.raises(ValidationError):
            OptimizationConfig.model_validate(raw)

    def test_partial_config_uses_defaults(self, tmp_path: Path) -> None:
        """Fields missing from the file fall back to their defaults."""
        config_file = self._write_yaml(
            tmp_path,
            "partial.yaml",
            {
                "seed_prompt": "test",
                "task_description": "task",
                "max_iterations": 10,
            },
        )

        raw = YamlPersistence().read_config(str(config_file))
        config = OptimizationConfig.model_validate(raw)

        assert config.max_iterations == 10
        assert config.n_synthetic_inputs == 20  # default
        assert config.minibatch_size == 5  # default
|
||||
|
||||
|
||||
class TestConfigValidation:
    """Tests for Pydantic validation edge cases."""

    @staticmethod
    def _build(**overrides):
        """Construct an OptimizationConfig from minimal valid fields plus *overrides*.

        The required fields default to ``seed_prompt="a"`` and
        ``task_description="b"``; any keyword in *overrides* replaces or
        extends them, so a test only spells out the field under scrutiny.
        """
        params = {"seed_prompt": "a", "task_description": "b", **overrides}
        return OptimizationConfig(**params)

    def test_wrong_type_max_iterations(self) -> None:
        """A non-numeric string for max_iterations is rejected."""
        with pytest.raises(ValidationError, match="max_iterations"):
            self._build(max_iterations="not_a_number")

    def test_wrong_type_seed(self) -> None:
        """A non-numeric string for seed is rejected, and the error points at seed."""
        with pytest.raises(ValidationError, match="seed") as excinfo:
            self._build(seed="not_a_number")
        # "seed" is also a substring of "seed_prompt", so the text match alone
        # cannot tell the two fields apart; check the error location as well.
        assert any(err["loc"] == ("seed",) for err in excinfo.value.errors())

    def test_negative_max_iterations(self) -> None:
        """max_iterations below 1 is rejected, both zero and a negative value."""
        for bad_value in (0, -1):
            with pytest.raises(ValidationError, match="greater than or equal to 1"):
                self._build(max_iterations=bad_value)

    def test_negative_seed(self) -> None:
        """A negative seed is rejected."""
        with pytest.raises(ValidationError, match="greater than or equal to 0"):
            self._build(seed=-1)

    def test_zero_minibatch_size(self) -> None:
        """minibatch_size of zero is rejected."""
        with pytest.raises(ValidationError, match="greater than or equal to 1"):
            self._build(minibatch_size=0)

    def test_zero_max_concurrency(self) -> None:
        """max_concurrency of zero is rejected."""
        with pytest.raises(ValidationError, match="greater than or equal to 1"):
            self._build(max_concurrency=0)

    def test_negative_max_retries(self) -> None:
        """A negative max_retries is rejected."""
        with pytest.raises(ValidationError, match="greater than or equal to 0"):
            self._build(max_retries=-1)

    def test_zero_retry_delay_base(self) -> None:
        """retry_delay_base of zero is rejected (must be strictly positive)."""
        with pytest.raises(ValidationError, match="greater than 0"):
            self._build(retry_delay_base=0)

    def test_negative_retry_delay_base(self) -> None:
        """A negative retry_delay_base is rejected."""
        with pytest.raises(ValidationError, match="greater than 0"):
            self._build(retry_delay_base=-0.5)

    def test_zero_circuit_breaker_threshold(self) -> None:
        """circuit_breaker_threshold of zero is rejected."""
        with pytest.raises(ValidationError, match="greater than or equal to 1"):
            self._build(circuit_breaker_threshold=0)

    def test_perfect_score_above_one(self) -> None:
        """perfect_score above 1 is rejected."""
        with pytest.raises(ValidationError, match="less than or equal to 1"):
            self._build(perfect_score=1.5)

    def test_perfect_score_negative(self) -> None:
        """A negative perfect_score is rejected."""
        with pytest.raises(ValidationError, match="greater than or equal to 0"):
            self._build(perfect_score=-0.1)

    def test_invalid_error_strategy(self) -> None:
        """An unknown error_strategy value is rejected with the custom message."""
        with pytest.raises(ValidationError, match="error_strategy must be one of"):
            self._build(error_strategy="invalid")

    def test_valid_error_strategies(self) -> None:
        """Each supported error_strategy value is accepted unchanged."""
        for strategy in ("skip", "retry", "abort"):
            config = self._build(error_strategy=strategy)
            assert config.error_strategy == strategy

    def test_empty_seed_prompt(self) -> None:
        """An empty seed_prompt is rejected."""
        with pytest.raises(ValidationError, match="seed_prompt"):
            self._build(seed_prompt="")

    def test_empty_task_description(self) -> None:
        """An empty task_description is rejected."""
        with pytest.raises(ValidationError, match="task_description"):
            self._build(task_description="")

    def test_empty_model_string(self) -> None:
        """An empty task_model string is rejected."""
        with pytest.raises(ValidationError, match="task_model"):
            self._build(task_model="")

    def test_extra_fields_rejected(self) -> None:
        """Unknown fields are rejected rather than silently ignored."""
        with pytest.raises(ValidationError, match="extra"):
            self._build(nonexistent_field="value")

    def test_boundary_values_accepted(self) -> None:
        """The smallest value used for each constrained field passes validation."""
        config = self._build(
            max_iterations=1,
            n_synthetic_inputs=1,
            minibatch_size=1,
            max_concurrency=1,
            max_retries=0,
            retry_delay_base=0.001,
            circuit_breaker_threshold=1,
            perfect_score=0.0,
            seed=0,
        )
        assert config.max_iterations == 1
        assert config.perfect_score == 0.0
|
||||
Reference in New Issue
Block a user