feat: Pydantic config validation with clear CLI error messages
Convert OptimizationConfig from dataclass to Pydantic BaseModel with field validators for ranges, types, and enum values. Missing/invalid fields now produce actionable CLI errors instead of cryptic KeyErrors. - Range validators: max_iterations>=1, minibatch_size>=1, seed>=0, etc. - Enum validator: error_strategy must be skip|retry|abort - Config migration hook via config_version field - CLI catches ValidationError and prints per-field error messages - Remove unused AppSettings class (Bug #7) - 30 unit tests covering all validation edge cases Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -4,20 +4,48 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OptimizationConfig:
|
|
||||||
"""Complete configuration for a PROMETHEUS run."""
|
|
||||||
|
|
||||||
# --- Prompt ---
|
# Current config schema version.
|
||||||
seed_prompt: str
|
CONFIG_VERSION = 1
|
||||||
task_description: str
|
|
||||||
|
_ERROR_STRATEGY_VALUES = {"skip", "retry", "abort"}
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizationConfig(BaseModel):
|
||||||
|
"""Complete configuration for a PROMETHEUS run.
|
||||||
|
|
||||||
|
Validated with Pydantic so missing or wrong-type fields produce clear,
|
||||||
|
actionable error messages at the CLI boundary instead of cryptic failures
|
||||||
|
deep in the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = {"extra": "forbid"}
|
||||||
|
|
||||||
|
# --- Schema version (for migration support) ---
|
||||||
|
config_version: int = Field(
|
||||||
|
default=CONFIG_VERSION,
|
||||||
|
description="Config schema version. Set automatically; used for migration.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Prompt (required) ---
|
||||||
|
seed_prompt: str = Field(
|
||||||
|
...,
|
||||||
|
min_length=1,
|
||||||
|
description="The initial prompt to optimize.",
|
||||||
|
)
|
||||||
|
task_description: str = Field(
|
||||||
|
...,
|
||||||
|
min_length=1,
|
||||||
|
description="Description of the task the prompt should accomplish.",
|
||||||
|
)
|
||||||
|
|
||||||
# --- Models ---
|
# --- Models ---
|
||||||
task_model: str = "openai/gpt-4o-mini"
|
task_model: str = Field(default="openai/gpt-4o-mini", min_length=1)
|
||||||
judge_model: str = "openai/gpt-4o"
|
judge_model: str = Field(default="openai/gpt-4o", min_length=1)
|
||||||
proposer_model: str = "openai/gpt-4o"
|
proposer_model: str = Field(default="openai/gpt-4o", min_length=1)
|
||||||
synth_model: str = "openai/gpt-4o"
|
synth_model: str = Field(default="openai/gpt-4o", min_length=1)
|
||||||
|
|
||||||
# --- Per-model API overrides (optional, fall back to global api_base/api_key_env) ---
|
# --- Per-model API overrides (optional, fall back to global api_base/api_key_env) ---
|
||||||
task_api_base: str | None = None
|
task_api_base: str | None = None
|
||||||
@@ -29,28 +57,54 @@ class OptimizationConfig:
|
|||||||
synth_api_base: str | None = None
|
synth_api_base: str | None = None
|
||||||
synth_api_key_env: str | None = None
|
synth_api_key_env: str | None = None
|
||||||
|
|
||||||
|
# --- Global API settings (optional) ---
|
||||||
|
api_base: str | None = None
|
||||||
|
api_key_env: str | None = None
|
||||||
|
|
||||||
# --- Evolution parameters ---
|
# --- Evolution parameters ---
|
||||||
max_iterations: int = 30
|
max_iterations: int = Field(default=30, ge=1, description="Maximum evolution iterations.")
|
||||||
n_synthetic_inputs: int = 20
|
n_synthetic_inputs: int = Field(default=20, ge=1, description="Number of synthetic inputs to generate.")
|
||||||
minibatch_size: int = 5
|
minibatch_size: int = Field(default=5, ge=1, description="Inputs per evaluation minibatch.")
|
||||||
perfect_score: float = 1.0
|
perfect_score: float = Field(default=1.0, ge=0.0, le=1.0)
|
||||||
|
|
||||||
# --- Reproducibility ---
|
# --- Reproducibility ---
|
||||||
seed: int = 42
|
seed: int = Field(default=42, ge=0)
|
||||||
|
|
||||||
# --- Concurrency ---
|
# --- Concurrency ---
|
||||||
max_concurrency: int = 5
|
max_concurrency: int = Field(default=5, ge=1, description="Max parallel LLM calls.")
|
||||||
|
|
||||||
# --- Error handling ---
|
# --- Error handling ---
|
||||||
max_retries: int = 3
|
max_retries: int = Field(default=3, ge=0, description="Max retry attempts for transient errors.")
|
||||||
retry_delay_base: float = 1.0
|
retry_delay_base: float = Field(default=1.0, gt=0, description="Base delay in seconds for retry backoff.")
|
||||||
circuit_breaker_threshold: int = 5
|
circuit_breaker_threshold: int = Field(default=5, ge=1, description="Consecutive failures before circuit opens.")
|
||||||
error_strategy: str = "retry" # skip | retry | abort
|
error_strategy: str = Field(default="retry", description="Error handling strategy: skip | retry | abort.")
|
||||||
|
|
||||||
# --- Output ---
|
# --- Output ---
|
||||||
output_path: str = "output.yaml"
|
output_path: str = Field(default="output.yaml", min_length=1)
|
||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
|
|
||||||
|
@field_validator("error_strategy")
|
||||||
|
@classmethod
|
||||||
|
def _validate_error_strategy(cls, v: str) -> str:
|
||||||
|
if v not in _ERROR_STRATEGY_VALUES:
|
||||||
|
raise ValueError(
|
||||||
|
f"error_strategy must be one of {sorted(_ERROR_STRATEGY_VALUES)}, got '{v}'"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _migrate_config(cls, data: Any) -> Any:
|
||||||
|
"""Apply migration transforms for older config versions."""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
version = data.get("config_version", CONFIG_VERSION)
|
||||||
|
# Future migrations go here, e.g.:
|
||||||
|
# if version < 2:
|
||||||
|
# data = _migrate_v1_to_v2(data)
|
||||||
|
# Always stamp current version.
|
||||||
|
data["config_version"] = CONFIG_VERSION
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OptimizationResult:
|
class OptimizationResult:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from dataclasses import asdict
|
|||||||
|
|
||||||
import dspy
|
import dspy
|
||||||
import typer
|
import typer
|
||||||
|
from pydantic import ValidationError
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
@@ -102,75 +103,58 @@ async def _async_optimize(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. Load config
|
# 1. Load & validate config
|
||||||
persistence = YamlPersistence()
|
persistence = YamlPersistence()
|
||||||
raw_config = persistence.read_config(input)
|
raw_config = persistence.read_config(input)
|
||||||
|
|
||||||
|
# CLI flags override config file values
|
||||||
|
raw_config.setdefault("max_retries", max_retries)
|
||||||
|
raw_config.setdefault("error_strategy", error_strategy)
|
||||||
|
raw_config.setdefault("max_concurrency", max_concurrency)
|
||||||
|
raw_config["output_path"] = output
|
||||||
|
raw_config["verbose"] = verbose
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = OptimizationConfig.model_validate(raw_config)
|
||||||
|
except ValidationError as exc:
|
||||||
|
console.print("[bold red]Configuration error:[/bold red]\n")
|
||||||
|
for err in exc.errors():
|
||||||
|
loc = " → ".join(str(l) for l in err["loc"])
|
||||||
|
console.print(f" [red]• {loc}: {err['msg']}[/red]")
|
||||||
|
raise typer.Exit(code=1) from exc
|
||||||
|
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
|
||||||
|
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
|
||||||
|
|
||||||
|
# 2. Create per-model DSPy LM instances
|
||||||
def _model_lm_kwargs(
|
def _model_lm_kwargs(
|
||||||
model_api_base: str | None,
|
model_api_base: str | None,
|
||||||
model_api_key_env: str | None,
|
model_api_key_env: str | None,
|
||||||
global_api_base: str | None,
|
|
||||||
global_api_key_env: str | None,
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Build kwargs for dspy.LM, using per-model overrides with global fallback."""
|
"""Build kwargs for dspy.LM, using per-model overrides with global fallback."""
|
||||||
kwargs: dict = {}
|
kwargs: dict = {}
|
||||||
api_base = model_api_base or global_api_base
|
api_base = model_api_base or config.api_base
|
||||||
api_key_env = model_api_key_env or global_api_key_env
|
api_key_env = model_api_key_env or config.api_key_env
|
||||||
if api_base:
|
if api_base:
|
||||||
kwargs["api_base"] = api_base
|
kwargs["api_base"] = api_base
|
||||||
if api_key_env:
|
if api_key_env:
|
||||||
kwargs["api_key"] = os.environ.get(api_key_env, "")
|
kwargs["api_key"] = os.environ.get(api_key_env, "")
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
global_api_base = raw_config.get("api_base")
|
|
||||||
global_api_key_env = raw_config.get("api_key_env")
|
|
||||||
|
|
||||||
config = OptimizationConfig(
|
|
||||||
seed_prompt=raw_config["seed_prompt"],
|
|
||||||
task_description=raw_config["task_description"],
|
|
||||||
task_model=raw_config.get("task_model", "openai/gpt-4o-mini"),
|
|
||||||
judge_model=raw_config.get("judge_model", "openai/gpt-4o"),
|
|
||||||
proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"),
|
|
||||||
synth_model=raw_config.get("synth_model", "openai/gpt-4o"),
|
|
||||||
task_api_base=raw_config.get("task_api_base"),
|
|
||||||
task_api_key_env=raw_config.get("task_api_key_env"),
|
|
||||||
judge_api_base=raw_config.get("judge_api_base"),
|
|
||||||
judge_api_key_env=raw_config.get("judge_api_key_env"),
|
|
||||||
proposer_api_base=raw_config.get("proposer_api_base"),
|
|
||||||
proposer_api_key_env=raw_config.get("proposer_api_key_env"),
|
|
||||||
synth_api_base=raw_config.get("synth_api_base"),
|
|
||||||
synth_api_key_env=raw_config.get("synth_api_key_env"),
|
|
||||||
max_iterations=raw_config.get("max_iterations", 30),
|
|
||||||
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
|
|
||||||
minibatch_size=raw_config.get("minibatch_size", 5),
|
|
||||||
seed=raw_config.get("seed", 42),
|
|
||||||
max_retries=raw_config.get("max_retries", max_retries),
|
|
||||||
retry_delay_base=raw_config.get("retry_delay_base", 1.0),
|
|
||||||
circuit_breaker_threshold=raw_config.get("circuit_breaker_threshold", 5),
|
|
||||||
error_strategy=raw_config.get("error_strategy", error_strategy),
|
|
||||||
max_concurrency=raw_config.get("max_concurrency", max_concurrency),
|
|
||||||
output_path=output,
|
|
||||||
verbose=verbose,
|
|
||||||
)
|
|
||||||
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
|
|
||||||
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
|
|
||||||
|
|
||||||
# 2. Create per-model DSPy LM instances
|
|
||||||
task_lm = dspy.LM(
|
task_lm = dspy.LM(
|
||||||
config.task_model,
|
config.task_model,
|
||||||
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env, global_api_base, global_api_key_env),
|
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env),
|
||||||
)
|
)
|
||||||
judge_lm = dspy.LM(
|
judge_lm = dspy.LM(
|
||||||
config.judge_model,
|
config.judge_model,
|
||||||
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env, global_api_base, global_api_key_env),
|
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env),
|
||||||
)
|
)
|
||||||
proposer_lm = dspy.LM(
|
proposer_lm = dspy.LM(
|
||||||
config.proposer_model,
|
config.proposer_model,
|
||||||
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env, global_api_base, global_api_key_env),
|
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env),
|
||||||
)
|
)
|
||||||
synth_lm = dspy.LM(
|
synth_lm = dspy.LM(
|
||||||
config.synth_model,
|
config.synth_model,
|
||||||
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env),
|
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Build adapters (Dependency Injection — each gets its own LM + retry config)
|
# 3. Build adapters (Dependency Injection — each gets its own LM + retry config)
|
||||||
|
|||||||
@@ -1,12 +1,2 @@
|
|||||||
"""Application settings."""
|
"""Application settings."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AppSettings:
|
|
||||||
"""Non-sensitive settings, hardcoded for the MVP."""
|
|
||||||
|
|
||||||
app_name: str = "prometheus"
|
|
||||||
version: str = "0.1.0"
|
|
||||||
|
|||||||
302
tests/unit/test_config.py
Normal file
302
tests/unit/test_config.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
"""Unit tests for config and config loading via CLI.
|
||||||
|
|
||||||
|
Tests config validation scenarios: missing fields, wrong types, defaults.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from prometheus.application.dto import OptimizationConfig
|
||||||
|
from prometheus.infrastructure.file_io import YamlPersistence
|
||||||
|
|
||||||
|
|
||||||
|
class TestOptimizationConfig:
|
||||||
|
"""Tests for OptimizationConfig Pydantic model defaults."""
|
||||||
|
|
||||||
|
def test_default_values(self) -> None:
|
||||||
|
config = OptimizationConfig(
|
||||||
|
seed_prompt="test prompt",
|
||||||
|
task_description="test task",
|
||||||
|
)
|
||||||
|
assert config.task_model == "openai/gpt-4o-mini"
|
||||||
|
assert config.judge_model == "openai/gpt-4o"
|
||||||
|
assert config.proposer_model == "openai/gpt-4o"
|
||||||
|
assert config.synth_model == "openai/gpt-4o"
|
||||||
|
assert config.max_iterations == 30
|
||||||
|
assert config.n_synthetic_inputs == 20
|
||||||
|
assert config.minibatch_size == 5
|
||||||
|
assert config.perfect_score == 1.0
|
||||||
|
assert config.seed == 42
|
||||||
|
assert config.output_path == "output.yaml"
|
||||||
|
assert config.verbose is False
|
||||||
|
assert config.error_strategy == "retry"
|
||||||
|
assert config.max_retries == 3
|
||||||
|
assert config.max_concurrency == 5
|
||||||
|
|
||||||
|
def test_custom_values(self) -> None:
|
||||||
|
config = OptimizationConfig(
|
||||||
|
seed_prompt="custom prompt",
|
||||||
|
task_description="custom task",
|
||||||
|
max_iterations=100,
|
||||||
|
minibatch_size=10,
|
||||||
|
seed=123,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
assert config.max_iterations == 100
|
||||||
|
assert config.minibatch_size == 10
|
||||||
|
assert config.seed == 123
|
||||||
|
assert config.verbose is True
|
||||||
|
|
||||||
|
def test_roundtrip_to_dict(self) -> None:
|
||||||
|
config = OptimizationConfig(
|
||||||
|
seed_prompt="test",
|
||||||
|
task_description="task",
|
||||||
|
)
|
||||||
|
d = config.model_dump()
|
||||||
|
assert d["seed_prompt"] == "test"
|
||||||
|
assert d["task_description"] == "task"
|
||||||
|
assert "history" not in d # OptimizationResult has history, not config
|
||||||
|
|
||||||
|
def test_config_version_defaults(self) -> None:
|
||||||
|
config = OptimizationConfig(seed_prompt="a", task_description="b")
|
||||||
|
assert config.config_version == 1
|
||||||
|
|
||||||
|
def test_config_version_stamps_current(self) -> None:
|
||||||
|
config = OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", config_version=0,
|
||||||
|
)
|
||||||
|
# Migration always stamps current version
|
||||||
|
assert config.config_version == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigLoading:
|
||||||
|
"""Tests for loading OptimizationConfig from YAML via YamlPersistence."""
|
||||||
|
|
||||||
|
def test_minimal_config_loads(self, tmp_path: Path) -> None:
|
||||||
|
persistence = YamlPersistence()
|
||||||
|
data = {
|
||||||
|
"seed_prompt": "You are helpful.",
|
||||||
|
"task_description": "Answer questions.",
|
||||||
|
}
|
||||||
|
config_file = tmp_path / "config.yaml"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(data, f)
|
||||||
|
|
||||||
|
raw = persistence.read_config(str(config_file))
|
||||||
|
config = OptimizationConfig.model_validate(raw)
|
||||||
|
|
||||||
|
assert config.seed_prompt == "You are helpful."
|
||||||
|
assert config.task_description == "Answer questions."
|
||||||
|
assert config.max_iterations == 30 # default
|
||||||
|
|
||||||
|
def test_full_config_loads(self, tmp_path: Path) -> None:
|
||||||
|
persistence = YamlPersistence()
|
||||||
|
data = {
|
||||||
|
"seed_prompt": "You are helpful.",
|
||||||
|
"task_description": "Answer questions.",
|
||||||
|
"task_model": "openai/gpt-4o",
|
||||||
|
"judge_model": "openai/gpt-4o-mini",
|
||||||
|
"max_iterations": 50,
|
||||||
|
"n_synthetic_inputs": 30,
|
||||||
|
"minibatch_size": 8,
|
||||||
|
"seed": 99,
|
||||||
|
"verbose": True,
|
||||||
|
}
|
||||||
|
config_file = tmp_path / "config.yaml"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(data, f)
|
||||||
|
|
||||||
|
raw = persistence.read_config(str(config_file))
|
||||||
|
config = OptimizationConfig.model_validate(raw)
|
||||||
|
|
||||||
|
assert config.task_model == "openai/gpt-4o"
|
||||||
|
assert config.max_iterations == 50
|
||||||
|
assert config.verbose is True
|
||||||
|
|
||||||
|
def test_missing_seed_prompt_raises_validation_error(self, tmp_path: Path) -> None:
|
||||||
|
persistence = YamlPersistence()
|
||||||
|
data = {"task_description": "Answer questions."}
|
||||||
|
config_file = tmp_path / "config.yaml"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(data, f)
|
||||||
|
|
||||||
|
raw = persistence.read_config(str(config_file))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError, match="seed_prompt"):
|
||||||
|
OptimizationConfig.model_validate(raw)
|
||||||
|
|
||||||
|
def test_missing_task_description_raises_validation_error(self, tmp_path: Path) -> None:
|
||||||
|
persistence = YamlPersistence()
|
||||||
|
data = {"seed_prompt": "You are helpful."}
|
||||||
|
config_file = tmp_path / "config.yaml"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(data, f)
|
||||||
|
|
||||||
|
raw = persistence.read_config(str(config_file))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError, match="task_description"):
|
||||||
|
OptimizationConfig.model_validate(raw)
|
||||||
|
|
||||||
|
def test_empty_yaml_raises_on_required_fields(self, tmp_path: Path) -> None:
|
||||||
|
persistence = YamlPersistence()
|
||||||
|
config_file = tmp_path / "empty.yaml"
|
||||||
|
config_file.write_text("{}", encoding="utf-8")
|
||||||
|
|
||||||
|
raw = persistence.read_config(str(config_file))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
OptimizationConfig.model_validate(raw)
|
||||||
|
|
||||||
|
def test_partial_config_uses_defaults(self, tmp_path: Path) -> None:
|
||||||
|
persistence = YamlPersistence()
|
||||||
|
data = {
|
||||||
|
"seed_prompt": "test",
|
||||||
|
"task_description": "task",
|
||||||
|
"max_iterations": 10,
|
||||||
|
}
|
||||||
|
config_file = tmp_path / "partial.yaml"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(data, f)
|
||||||
|
|
||||||
|
raw = persistence.read_config(str(config_file))
|
||||||
|
config = OptimizationConfig.model_validate(raw)
|
||||||
|
|
||||||
|
assert config.max_iterations == 10
|
||||||
|
assert config.n_synthetic_inputs == 20 # default
|
||||||
|
assert config.minibatch_size == 5 # default
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigValidation:
|
||||||
|
"""Tests for Pydantic validation edge cases."""
|
||||||
|
|
||||||
|
def test_wrong_type_max_iterations(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="max_iterations"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a",
|
||||||
|
task_description="b",
|
||||||
|
max_iterations="not_a_number", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_wrong_type_seed(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="seed"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a",
|
||||||
|
task_description="b",
|
||||||
|
seed="not_a_number", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_negative_max_iterations(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="greater than or equal to 1"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", max_iterations=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_negative_seed(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="greater than or equal to 0"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", seed=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_zero_minibatch_size(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="greater than or equal to 1"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", minibatch_size=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_zero_max_concurrency(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="greater than or equal to 1"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", max_concurrency=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_negative_max_retries(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="greater than or equal to 0"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", max_retries=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_zero_retry_delay_base(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="greater than 0"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", retry_delay_base=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_negative_retry_delay_base(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="greater than 0"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", retry_delay_base=-0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_zero_circuit_breaker_threshold(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="greater than or equal to 1"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", circuit_breaker_threshold=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_perfect_score_above_one(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="less than or equal to 1"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", perfect_score=1.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_perfect_score_negative(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="greater than or equal to 0"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", perfect_score=-0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_invalid_error_strategy(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="error_strategy must be one of"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", error_strategy="invalid",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_valid_error_strategies(self) -> None:
|
||||||
|
for strategy in ("skip", "retry", "abort"):
|
||||||
|
config = OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", error_strategy=strategy,
|
||||||
|
)
|
||||||
|
assert config.error_strategy == strategy
|
||||||
|
|
||||||
|
def test_empty_seed_prompt(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="seed_prompt"):
|
||||||
|
OptimizationConfig(seed_prompt="", task_description="b")
|
||||||
|
|
||||||
|
def test_empty_task_description(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="task_description"):
|
||||||
|
OptimizationConfig(seed_prompt="a", task_description="")
|
||||||
|
|
||||||
|
def test_empty_model_string(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="task_model"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", task_model="",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_extra_fields_rejected(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="extra"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a",
|
||||||
|
task_description="b",
|
||||||
|
nonexistent_field="value", # type: ignore[call-arg]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_boundary_values_accepted(self) -> None:
|
||||||
|
config = OptimizationConfig(
|
||||||
|
seed_prompt="a",
|
||||||
|
task_description="b",
|
||||||
|
max_iterations=1,
|
||||||
|
n_synthetic_inputs=1,
|
||||||
|
minibatch_size=1,
|
||||||
|
max_concurrency=1,
|
||||||
|
max_retries=0,
|
||||||
|
retry_delay_base=0.001,
|
||||||
|
circuit_breaker_threshold=1,
|
||||||
|
perfect_score=0.0,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
assert config.max_iterations == 1
|
||||||
|
assert config.perfect_score == 0.0
|
||||||
Reference in New Issue
Block a user