Compare commits

...

6 Commits

Author SHA1 Message Date
FullStackDev
a5bf2ad59c feat: v0.2.0 sprint — ground truth eval, crossover/mutation, checkpointing, similarity guards, dataset loader, CLI commands, extended test coverage
Aggregates all v0.2.0 sprint work (GARAA-30 through GARAA-40) and fixes
2 integration tests that broke when the codebase went async (DSPyLLMAdapter
and full pipeline tests now properly await coroutines).

277 tests pass (260 unit + 17 integration).

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-03-29 19:13:50 +00:00
FullStackDev
b9745566c8 feat: custom judge criteria and multi-dimensional scoring
Add configurable judge rubrics and multi-dimensional scoring with
weighted aggregation. New config fields: judge_criteria (free text)
and judge_dimensions (list of {name, weight, description}). CLI
--judge-criteria flag provides quick overrides. The judge adapter
computes weighted aggregate scores and enriches feedback with
per-dimension breakdowns.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-03-29 15:40:21 +00:00
FullStackDev
336774a164 feat: Pydantic config validation with clear CLI error messages
Convert OptimizationConfig from dataclass to Pydantic BaseModel with
field validators for ranges, types, and enum values. Missing/invalid
fields now produce actionable CLI errors instead of cryptic KeyErrors.

- Range validators: max_iterations>=1, minibatch_size>=1, seed>=0, etc.
- Enum validator: error_strategy must be skip|retry|abort
- Config migration hook via config_version field
- CLI catches ValidationError and prints per-field error messages
- Remove unused AppSettings class (Bug #7)
- 30 unit tests covering all validation edge cases

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-03-29 13:25:44 +00:00
FullStackDev
c92ca4a2b8 feat: async/parallel execution with configurable concurrency
Parallelize LLM calls across minibatches to reduce wall-clock time.
All domain ports (LLMPort, JudgePort, ProposerPort) are now async.
Adapter implementations wrap synchronous DSPy calls with asyncio.to_thread.
Judge calls run in parallel within a batch using asyncio.gather + semaphore.
Evaluator parallelizes minibatch execution with configurable concurrency.
Evolution loop and use case are fully async. Proposer stays sequential.
Added --max-concurrency CLI flag and max_concurrency YAML config field.
Added async_retry_with_backoff for async error handling.
All 139 unit tests pass.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-03-29 13:15:34 +00:00
FullStackDev
e2d111ce5b feat: error handling, retry with backoff, and circuit breaker
Add robust error handling to the evolution loop and LLM adapters:
- Retry utility with exponential backoff for transient errors (429, 5xx, timeouts)
- Per-call error isolation in evaluator and judge adapter
- Circuit breaker in EvolutionLoop (trips after N consecutive failures)
- CLI flags: --max-retries, --error-strategy (skip|retry|abort)
- Config fields: max_retries, retry_delay_base, circuit_breaker_threshold, error_strategy
- 16 new unit tests covering all error handling paths

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-03-29 12:47:55 +00:00
FullStackDev
f516ca4be6 fix: multi-model routing — each adapter uses own dspy.LM instance
- DSPyLLMAdapter now accepts dspy.LM instead of model string, uses dspy.context(lm=...)
- DSPyJudgeAdapter, DSPyProposerAdapter, DSPySyntheticAdapter each accept and use own LM
- OptimizationConfig gains per-model api_base/api_key_env override fields
- cli/app.py creates separate dspy.LM per adapter with per-model overrides
- New unit tests verify each adapter isolates its LM from global config

Fixes Bug #1 (multi-model config not wired) and Bug #2 (DSPyLLMAdapter ignores model param).

Co-Authored-By: Paperclip <noreply@paperclip.ing>
2026-03-29 12:31:48 +00:00
51 changed files with 7102 additions and 450 deletions

369
docs/FEATURE_ROADMAP.md Normal file
View File

@@ -0,0 +1,369 @@
# PROMETHEUS Feature Roadmap
> Complete codebase review — features needed for production-grade prompt optimization.
> Generated from v0.1.0 architecture review (2026-03-29).
---
## Legend
| Marker | Meaning |
|--------|---------|
| **CLI** | Exposed as a CLI option/flag |
| **Config** | YAML config field |
| **Internal** | No user-facing surface, architectural improvement |
| **P1** | Critical / must-have for reliability |
| **P2** | High value, should-have |
| **P3** | Nice-to-have, deferred to later versions |
---
## 1. Multi-Model Routing (P1)
**Current state:** `OptimizationConfig` defines four model slots (`task_model`, `judge_model`, `proposer_model`, `synth_model`), but `cli/app.py` only configures a single global DSPy LM from `task_model`. All adapters silently use the same model regardless of config.
**Feature:**
- Each adapter (`DSPyLLMAdapter`, `DSPyJudgeAdapter`, `DSPyProposerAdapter`, `DSPySyntheticAdapter`) must instantiate its own `dspy.LM` from the corresponding config field.
- Support per-model `api_base` and `api_key_env` overrides (e.g., judge on GPT-4o, propose on a cheaper model).
**Surface:** Config (already partially defined) — `judge_model`, `proposer_model`, `synth_model` become functional. No new CLI flags needed; the YAML already has the fields.
**Scope:** Infrastructure layer (`llm_adapter.py`, `judge_adapter.py`, `proposer_adapter.py`, `synth_adapter.py`) + `cli/app.py` DI wiring.
---
## 2. Async / Parallel Execution (P1)
**Current state:** All LLM calls (execute, judge, propose) are sequential. A single iteration with `minibatch_size=5` makes ~11 sequential LLM calls. Wall-clock time scales linearly with minibatch size.
**Feature:**
- Parallelize execution of the prompt across a minibatch (`asyncio.gather` or `dspy.Parallel`).
- Parallelize judge calls within a batch.
- Keep the proposer sequential (single call per iteration).
**Surface:** Internal. Optionally exposed via `--max-concurrency` CLI flag and `max_concurrency` YAML field.
**Scope:** `evaluator.py`, `judge_adapter.py`, `llm_adapter.py`.
---
## 3. Robust Error Handling & Retry (P1)
**Current state:** The evolution loop catches broad `Exception` per iteration and logs it, then continues. Individual LLM call failures (timeouts, rate limits, malformed responses) are not retried. DSPy module fallbacks only cover parsing, not network errors.
**Feature:**
- Retry with exponential backoff for transient errors (rate limits, timeouts, 5xx).
- Configurable `max_retries` and `retry_delay_base`.
- Circuit breaker: if N consecutive iterations fail, pause and alert.
- Per-call error isolation: one bad minibatch item shouldn't fail the whole evaluation.
**Surface:** `--max-retries` CLI flag, `max_retries` Config field. `--error-strategy` (skip | retry | abort) CLI flag.
**Scope:** Infrastructure adapters + evolution loop.
---
## 4. Checkpoint & Resume (P2)
**Current state:** If a long optimization run crashes or is interrupted, all progress is lost. There is no intermediate state persistence.
**Feature:**
- Save `OptimizationState` to disk every K iterations (or every accepted improvement).
- Resume from the latest checkpoint file on restart.
- Checkpoint includes: current best candidate, all candidates, iteration number, LLM call count, RNG seed state.
**Surface:** `--checkpoint-dir` CLI flag (default: `.prometheus/checkpoints/`). `--resume` CLI flag to resume from latest checkpoint. `checkpoint_interval` Config field.
**Scope:** New `CheckpointPort` in domain, `JsonCheckpointPersistence` in infrastructure, modifications to `EvolutionLoop.run()`.
---
## 5. Population-Based Evolution (P2)
**Current state:** The evolution loop keeps only a single best candidate (hill climbing). No diversity, no crossover, no population dynamics. The `Candidate` entity has `generation` and `parent_id` fields that suggest population support was planned.
**Feature:**
- Maintain a population of K candidates (e.g., top-K by score or Pareto front).
- Crossover: combine instructions from two parent candidates.
- Mutation operators: paraphrase, constrain, generalize, specialize.
- Diversity maintenance: penalize candidates too similar to existing ones (cosine similarity or edit distance).
**Surface:** `--population-size` CLI flag, `population_size` Config field. `--crossover-rate`, `--mutation-rate` CLI flags.
**Scope:** `EvolutionLoop` refactor, new `CrossoverPort` and `MutationPort` in domain, new DSPy signatures for crossover/mutation in infrastructure.
---
## 6. Hold-Out Validation (P2)
**Current state:** The same synthetic inputs are used for both optimization and evaluation. No train/test split. Risk of overfitting to synthetic inputs.
**Feature:**
- Split synthetic pool into train (e.g., 70%) and validation (30%) sets.
- Evolution uses train minibatches for accept/reject decisions.
- After each iteration, evaluate the best candidate on the hold-out set.
- Report both train and validation scores in results.
- Optional early stopping if validation score degrades for K consecutive iterations.
**Surface:** `--validation-split` CLI flag (default: 0.3). `--early-stop-patience` CLI flag (default: 5). Config fields: `validation_split`, `early_stop_patience`.
**Scope:** `SyntheticBootstrap`, `EvolutionLoop`, `OptimizationResult` (add validation metrics).
---
## 7. Custom Judge Criteria (P2)
**Current state:** The judge uses a hardcoded rubric in `JudgeOutput` DSPy signature ("score 0.0-1.0" with generic quality assessment). Users cannot customize evaluation criteria.
**Feature:**
- Allow users to define custom judge rubrics, criteria, and scoring scales.
- Support multi-dimensional scoring (e.g., accuracy: 0-10, clarity: 0-10, safety: 0-10) with configurable weights.
- Allow `perfect_score` to reflect the custom scale.
**Surface:** `judge_criteria` YAML field (free text). `judge_dimensions` YAML field (list of `{name, weight, description}`). CLI: `--judge-criteria` for quick overrides.
**Scope:** `JudgeOutput` signature (dynamic instructions), `JudgePort`, `DSPyJudgeAdapter`, `scoring.py` (weighted aggregation).
---
## 8. Real-World Evaluation Harness (P2)
**Current state:** The system only evaluates against synthetic inputs. There is no way to test optimized prompts against real inputs with known-good outputs.
**Feature:**
- Accept an optional evaluation dataset (CSV/JSON with `input` and `expected_output` columns).
- When provided, use exact/semantic similarity matching against expected outputs instead of (or in addition to) LLM-as-Judge.
- Report metrics: accuracy, BLEU, ROUGE, or embedding cosine similarity vs expected.
**Surface:** `--eval-dataset` CLI flag. `eval_dataset_path` Config field. `--eval-metric` CLI flag (exact | semantic | llm_judge).
**Scope:** New `GroundTruthEvaluator` in application, new `SimilarityPort` in domain, dataset loader in infrastructure.
---
## 9. Logging & Observability (P2)
**Current state:** Verbose mode (`-v`) configures Python's `logging` module but no handler is attached (Bug #4 in TEST_REPORT.md). No structured logging, no tracing.
**Feature:**
- Proper structured logging with configurable levels (DEBUG, INFO, WARNING, ERROR).
- JSON-formatted log output for machine parsing.
- Per-iteration trace: minibatch sample IDs, execution outputs, judge scores, proposer prompt diff.
- Optional OpenTelemetry export for distributed tracing.
**Surface:** `-v` / `--verbose` enables INFO level. `--debug` enables DEBUG level. `--log-format` (text | json). `--log-file` for file output. Config fields: `log_level`, `log_format`, `log_file`.
**Scope:** `cli/app.py` (logging setup), `evolution.py` (structured traces), new `TracingPort` in domain.
---
## 10. CLI Improvements (P2)
**Current state:** Single `optimize` command. Known Typer 0.24 bug absorbing subcommands (Bug #1). No `version`, `init`, or `list-results` commands.
**Feature:**
- Fix Typer subcommand routing.
- `prometheus version` — show version.
- `prometheus init` — scaffold a config YAML interactively.
- `prometheus list` — list past optimization runs.
- `prometheus diff` — compare two result files (before/after prompt diff, score improvement).
- `prometheus eval` — evaluate a prompt against a dataset without optimization.
**Surface:** CLI subcommands.
**Scope:** `cli/app.py` restructured into `cli/commands/` with one module per command.
---
## 11. Input Validation & Schema Enforcement (P2)
**Current state:** Config YAML is parsed as a raw dict with no schema validation. Missing or wrong-type fields cause cryptic errors deep in the pipeline.
**Feature:**
- Validate input YAML against a Pydantic schema (leveraging the existing `pydantic` dependency).
- Provide clear, actionable error messages for missing/invalid fields.
- Support config migration/upgrade from older versions.
**Surface:** Internal. Errors surface as clear CLI messages.
**Scope:** `OptimizationConfig` converted to Pydantic model with validators, `cli/app.py` validation step before pipeline execution.
---
## 12. Adaptive Minibatch Sizing (P3)
**Current state:** Minibatch size is static throughout the run. Small batches are noisy; large batches are expensive.
**Feature:**
- Start with a small minibatch for quick early iterations.
- Increase minibatch size as the prompt improves (higher confidence needed for marginal gains).
- Shrink if too many evaluations fail (cost optimization).
**Surface:** `--adaptive-minibatch` CLI flag (boolean toggle). `minibatch_size` becomes `minibatch_size_min` and `minibatch_size_max` in config.
**Scope:** `EvolutionLoop`, `SyntheticBootstrap`.
---
## 13. Prompt Diversity Tracking (P3)
**Current state:** No visibility into how much the prompt is actually changing between iterations. A "successful" optimization might just rephrase without structural change.
**Feature:**
- Compute edit distance (Levenshtein) or embedding cosine similarity between consecutive prompts.
- Report diversity metrics in the result.
- Flag stagnation (N iterations with <epsilon change).
**Surface:** Internal. Reported in `OptimizationResult.history` entries.
**Scope:** `EvolutionLoop`, `OptimizationResult` (add diversity field per history entry).
---
## 14. Temperature & Sampling Control (P3)
**Current state:** No way to control LLM temperature, top_p, or other sampling parameters for any of the four model slots. DSPy defaults apply.
**Feature:**
- Per-model-slot temperature and sampling parameters.
- Higher temperature for proposer (creativity), lower for judge (consistency).
**Surface:** `task_temperature`, `judge_temperature`, `proposer_temperature`, `synth_temperature` Config fields. `--temperature` CLI flag for global override.
**Scope:** `cli/app.py` (DSPy LM configuration), infrastructure adapters.
---
## 15. Cost Estimation & Budget Caps (P3)
**Current state:** `total_llm_calls` is tracked (inaccurately). No cost estimation, no budget caps.
**Feature:**
- Estimate cost per run based on model pricing and approximate token counts.
- Allow users to set a budget cap (`--max-cost-usd`).
- Report estimated cost in the result.
**Surface:** `--max-cost-usd` CLI flag. `max_cost_usd` Config field. Cost breakdown in result output.
**Scope:** `cli/app.py`, `OptimizationResult` (add cost fields), token counting in adapters.
---
## 16. Multi-Objective Optimization (P3)
**Current state:** Single scalar score from the judge. The `Prompt` entity comment mentions "Pareto tracking" but it's not implemented.
**Feature:**
- Optimize for multiple objectives simultaneously (quality, latency, token efficiency, safety).
- Maintain a Pareto front of non-dominated candidates.
- Allow users to set objective weights or constraints.
**Surface:** `objectives` Config field (list of `{name, weight, judge_criteria}`). CLI: `--objective` repeatable flag.
**Scope:** `EvolutionLoop` (Pareto front), `scoring.py` (multi-objective acceptance), `OptimizationResult` (Pareto set).
---
## 17. Export Optimized Prompt (P3)
**Current state:** The optimized prompt is embedded in the YAML result file. No easy way to extract it for use.
**Feature:**
- `prometheus export` command to extract the optimized prompt as plain text.
- Support multiple export formats: plain text, Markdown, JSON, LangChain template, DSPy module.
- Copy to clipboard option.
**Surface:** `prometheus export --format <txt|md|json|langchain|dspy>` CLI subcommand. `--clipboard` flag.
**Scope:** New `cli/commands/export.py`, format renderers in infrastructure.
---
## 18. Config Profiles / Presets (P3)
**Current state:** Every run requires a full config YAML. Common patterns (fast iterate, thorough optimize, cheap run) are not captured.
**Feature:**
- Named profiles: `fast`, `thorough`, `economy`, `research`.
- Profile overrides individual config fields.
- User-defined profiles stored in `~/.prometheus/profiles/`.
**Surface:** `--profile` CLI flag. `prometheus profile list` / `prometheus profile create` subcommands.
**Scope:** `cli/app.py`, new `ProfileManager` in application.
---
## Summary Table
| # | Feature | Priority | CLI Surface | Config Surface | Estimated Scope |
|---|---------|----------|-------------|----------------|-----------------|
| 1 | Multi-Model Routing | P1 | Existing | Existing | Small |
| 2 | Async / Parallel Execution | P1 | `--max-concurrency` | `max_concurrency` | Medium |
| 3 | Error Handling & Retry | P1 | `--max-retries`, `--error-strategy` | `max_retries`, `error_strategy` | Medium |
| 4 | Checkpoint & Resume | P2 | `--checkpoint-dir`, `--resume` | `checkpoint_interval` | Medium |
| 5 | Population-Based Evolution | P2 | `--population-size`, `--crossover-rate` | `population_size`, `crossover_rate` | Large |
| 6 | Hold-Out Validation | P2 | `--validation-split`, `--early-stop-patience` | `validation_split`, `early_stop_patience` | Medium |
| 7 | Custom Judge Criteria | P2 | `--judge-criteria` | `judge_criteria`, `judge_dimensions` | Medium |
| 8 | Real-World Eval Harness | P2 | `--eval-dataset`, `--eval-metric` | `eval_dataset_path` | Large |
| 9 | Logging & Observability | P2 | `--debug`, `--log-format`, `--log-file` | `log_level`, `log_format` | Medium |
| 10 | CLI Improvements | P2 | Subcommands | — | Medium |
| 11 | Input Validation | P2 | — (error messages) | — | Small |
| 12 | Adaptive Minibatch | P3 | `--adaptive-minibatch` | `minibatch_size_min/max` | Small |
| 13 | Prompt Diversity Tracking | P3 | — | — | Small |
| 14 | Temperature & Sampling | P3 | `--temperature` | `*_temperature` | Small |
| 15 | Cost Estimation | P3 | `--max-cost-usd` | `max_cost_usd` | Small |
| 16 | Multi-Objective Optimization | P3 | `--objective` | `objectives` | Large |
| 17 | Export Optimized Prompt | P3 | `prometheus export` | — | Small |
| 18 | Config Profiles / Presets | P3 | `--profile` | — | Small |
---
## Known Bugs (from TEST_REPORT.md and code review)
| # | Bug | Severity | File |
|---|-----|----------|------|
| 1 | Multi-model config not wired — all adapters use single global LM | HIGH | `cli/app.py`, all adapters |
| 2 | `DSPyLLMAdapter` accepts `model` param but never uses it | HIGH | `infrastructure/llm_adapter.py` |
| 3 | CLI subcommand `optimize` absorbed by Typer 0.24 | HIGH | `cli/app.py` |
| 4 | Verbose logging produces no output — no handler configured | MEDIUM | `cli/app.py` |
| 5 | `total_llm_calls` counter is inaccurate | LOW | `application/use_cases.py`, `evolution.py` |
| 6 | `normalize_score()` is dead code — never called | LOW | `domain/scoring.py` |
| 7 | `AppSettings` is never imported or used | LOW | `config.py` |
| 8 | No LLM error handling in evolution loop | MEDIUM | `evolution.py` |
| 9 | Unpinned dependencies (dspy, typer) | LOW | `pyproject.toml` |
---
## Test Coverage Gaps
| Area | Current | Needed |
|------|---------|--------|
| CLI commands | 0 tests | Unit + integration for each subcommand |
| Config validation | 0 tests | Schema validation, missing fields, type errors |
| Evolution loop | 3 tests (single iteration each) | Multi-iteration, mixed accept/reject, failure recovery |
| Integration pipeline | 1 test (happy path only) | Error paths, mixed results, real adapters |
| Adapter coverage | 1 adapter tested | All 4 adapters + error scenarios |
| Use case orchestration | 1 indirect test | Direct unit tests for `OptimizePromptUseCase` |
---
## Recommended Implementation Order
### Phase 1 — Production Reliability (P1)
1. Fix multi-model routing (#1) — highest impact, smallest scope
2. Add error handling & retry (#3) — essential for production runs
3. Implement async/parallel execution (#2) — biggest wall-clock improvement
### Phase 2 — Optimization Quality (P2)
4. Input validation (#11) — small scope, high reliability gain
5. Logging & observability (#9) — enables debugging long runs
6. CLI improvements (#10) — fix Typer bug, add basic commands
7. Hold-out validation (#6) — prevents overfitting
8. Checkpoint & resume (#4) — essential for long runs
9. Custom judge criteria (#7) — enables domain-specific optimization
### Phase 3 — Advanced Features (P3)
10. Population-based evolution (#5)
11. Real-world eval harness (#8)
12. Remaining P3 features as demand dictates

View File

@@ -5,17 +5,18 @@ description = "Prompt evolution without reference data"
readme = "README.md" readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
"dspy>=2.6,<3.0", "dspy==2.6.27",
"typer>=0.15,<0.20", "typer==0.19.2",
"pydantic>=2.10", "pydantic==2.12.5",
"pydantic-settings>=2.7", "pydantic-settings==2.13.1",
"pyyaml>=6.0", "pyyaml==6.0.3",
"rich>=13.9", "rich==14.3.3",
] ]
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"pytest>=8.3", "pytest>=8.3",
"pytest-asyncio>=0.24",
"pytest-cov>=6.0", "pytest-cov>=6.0",
"ruff>=0.9", "ruff>=0.9",
"mypy>=1.14", "mypy>=1.14",
@@ -37,11 +38,14 @@ target-version = "py312"
python_version = "3.12" python_version = "3.12"
strict = true strict = true
[tool.pytest.ini_options]
asyncio_mode = "auto"
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = ["dspy", "dspy.*"] module = ["dspy", "dspy.*"]
ignore_missing_imports = true ignore_missing_imports = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = ["prometheus.infrastructure.*", "prometheus.cli.app"] module = ["prometheus.infrastructure.*", "prometheus.cli.app", "prometheus.cli.commands.*"]
disable_error_code = ["misc", "import-untyped"] disable_error_code = ["misc", "import-untyped"]

View File

@@ -22,6 +22,24 @@ class SyntheticBootstrap:
self._generator = generator self._generator = generator
self._rng = random.Random(seed) self._rng = random.Random(seed)
@staticmethod
def split_pool(
pool: list[SyntheticExample],
validation_fraction: float,
rng: random.Random | None = None,
) -> tuple[list[SyntheticExample], list[SyntheticExample]]:
"""Split *pool* into (train, validation) sets.
Returns (pool, []) when *validation_fraction* is 0.
"""
if validation_fraction <= 0.0 or len(pool) < 2:
return pool, []
n_val = max(1, int(len(pool) * validation_fraction))
shuffled = list(pool)
_rng = rng or random.Random(42)
_rng.shuffle(shuffled)
return shuffled[:-n_val], shuffled[-n_val:]
def run(self, task_description: str, n_examples: int) -> list[SyntheticExample]: def run(self, task_description: str, n_examples: int) -> list[SyntheticExample]:
"""Generate the synthetic pool in a single call. """Generate the synthetic pool in a single call.

View File

@@ -4,34 +4,211 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from pydantic import BaseModel, Field, field_validator, model_validator
@dataclass
class OptimizationConfig:
"""Complete configuration for a PROMETHEUS run."""
# --- Prompt --- # Current config schema version.
seed_prompt: str CONFIG_VERSION = 1
task_description: str
_ERROR_STRATEGY_VALUES = {"skip", "retry", "abort"}
_EVAL_METRIC_VALUES = {"exact", "bleu", "rouge_l", "cosine", "llm_judge"}
class JudgeDimension(BaseModel):
"""A single evaluation dimension for multi-dimensional scoring."""
name: str = Field(min_length=1, description="Dimension name (e.g. accuracy, clarity, safety).")
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="Weight for this dimension (0.01.0).")
description: str = Field(default="", description="What this dimension measures.")
class OptimizationConfig(BaseModel):
"""Complete configuration for a PROMETHEUS run.
Validated with Pydantic so missing or wrong-type fields produce clear,
actionable error messages at the CLI boundary instead of cryptic failures
deep in the pipeline.
"""
model_config = {"extra": "forbid"}
# --- Schema version (for migration support) ---
config_version: int = Field(
default=CONFIG_VERSION,
description="Config schema version. Set automatically; used for migration.",
)
# --- Prompt (required) ---
seed_prompt: str = Field(
...,
min_length=1,
description="The initial prompt to optimize.",
)
task_description: str = Field(
...,
min_length=1,
description="Description of the task the prompt should accomplish.",
)
# --- Models --- # --- Models ---
task_model: str = "openai/gpt-4o-mini" task_model: str = Field(default="openai/gpt-4o-mini", min_length=1)
judge_model: str = "openai/gpt-4o" judge_model: str = Field(default="openai/gpt-4o", min_length=1)
proposer_model: str = "openai/gpt-4o" proposer_model: str = Field(default="openai/gpt-4o", min_length=1)
synth_model: str = "openai/gpt-4o" synth_model: str = Field(default="openai/gpt-4o", min_length=1)
# --- Per-model API overrides (optional, fall back to global api_base/api_key_env) ---
task_api_base: str | None = None
task_api_key_env: str | None = None
judge_api_base: str | None = None
judge_api_key_env: str | None = None
proposer_api_base: str | None = None
proposer_api_key_env: str | None = None
synth_api_base: str | None = None
synth_api_key_env: str | None = None
# --- Global API settings (optional) ---
api_base: str | None = None
api_key_env: str | None = None
# --- Evolution parameters --- # --- Evolution parameters ---
max_iterations: int = 30 max_iterations: int = Field(default=30, ge=1, description="Maximum evolution iterations.")
n_synthetic_inputs: int = 20 n_synthetic_inputs: int = Field(default=20, ge=1, description="Number of synthetic inputs to generate.")
minibatch_size: int = 5 minibatch_size: int = Field(default=5, ge=1, description="Inputs per evaluation minibatch.")
perfect_score: float = 1.0 perfect_score: float = Field(default=1.0, ge=0.0, le=1.0)
# --- Population-based evolution ---
population_size: int = Field(
default=1,
ge=1,
description="Number of candidates in the evolution population. 1 = single-candidate hill climbing (backward compat).",
)
crossover_rate: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Probability of applying crossover vs reflective mutation.",
)
mutation_rate: float = Field(
default=0.3,
ge=0.0,
le=1.0,
description="Probability of applying a mutation operator after crossover or proposal.",
)
diversity_penalty: float = Field(
default=0.1,
ge=0.0,
description="Penalty weight for similarity to existing population members.",
)
# --- Reproducibility --- # --- Reproducibility ---
seed: int = 42 seed: int = Field(default=42, ge=0)
# --- Concurrency ---
max_concurrency: int = Field(default=5, ge=1, description="Max parallel LLM calls.")
# --- Error handling ---
max_retries: int = Field(default=3, ge=0, description="Max retry attempts for transient errors.")
retry_delay_base: float = Field(default=1.0, gt=0, description="Base delay in seconds for retry backoff.")
circuit_breaker_threshold: int = Field(default=5, ge=1, description="Consecutive failures before circuit opens.")
error_strategy: str = Field(default="retry", description="Error handling strategy: skip | retry | abort.")
# --- Logging & observability ---
debug: bool = Field(default=False, description="Enable DEBUG-level logging.")
log_format: str = Field(default="text", description="Log output format: text | json.")
log_file: str | None = Field(default=None, description="Optional file path for log output.")
# --- Checkpoint & resume ---
checkpoint_dir: str | None = Field(
default=None,
description="Directory for checkpoint files. Set to enable checkpointing.",
)
checkpoint_interval: int = Field(
default=5,
ge=1,
description="Save a checkpoint every N iterations (and on every accepted improvement).",
)
resume: bool = Field(
default=False,
description="Resume from the latest checkpoint in checkpoint_dir.",
)
# --- Output --- # --- Output ---
output_path: str = "output.yaml" output_path: str = Field(default="output.yaml", min_length=1)
verbose: bool = False verbose: bool = False
# --- Hold-out validation ---
validation_split: float = Field(
default=0.3,
ge=0.0,
lt=1.0,
description="Fraction of synthetic pool reserved for validation (0 = disabled).",
)
early_stop_patience: int = Field(
default=5,
ge=1,
description="Stop if validation score degrades for this many consecutive iterations.",
)
# --- Judge criteria & multi-dimensional scoring ---
judge_criteria: str | None = Field(
default=None,
description="Custom judge rubric or evaluation criteria override (free text).",
)
judge_dimensions: list[JudgeDimension] | None = Field(
default=None,
description="Multi-dimensional scoring dimensions with configurable weights.",
)
# --- Ground-truth evaluation ---
eval_dataset_path: str | None = Field(
default=None,
min_length=1,
description="Path to a CSV/JSON dataset with 'input' and 'expected_output' columns.",
)
eval_metric: str = Field(
default="bleu",
description="Similarity metric for ground-truth eval: exact | bleu | rouge_l | cosine | llm_judge.",
)
@field_validator("log_format")
@classmethod
def _validate_log_format(cls, v: str) -> str:
allowed = {"text", "json"}
if v not in allowed:
raise ValueError(f"log_format must be one of {sorted(allowed)}, got '{v}'")
return v
@field_validator("error_strategy")
@classmethod
def _validate_error_strategy(cls, v: str) -> str:
if v not in _ERROR_STRATEGY_VALUES:
raise ValueError(
f"error_strategy must be one of {sorted(_ERROR_STRATEGY_VALUES)}, got '{v}'"
)
return v
@field_validator("eval_metric")
@classmethod
def _validate_eval_metric(cls, v: str) -> str:
if v not in _EVAL_METRIC_VALUES:
raise ValueError(
f"eval_metric must be one of {sorted(_EVAL_METRIC_VALUES)}, got '{v}'"
)
return v
@model_validator(mode="before")
@classmethod
def _migrate_config(cls, data: Any) -> Any:
"""Apply migration transforms for older config versions."""
if isinstance(data, dict):
version = data.get("config_version", CONFIG_VERSION)
# Future migrations go here, e.g.:
# if version < 2:
# data = _migrate_v1_to_v2(data)
# Always stamp current version.
data["config_version"] = CONFIG_VERSION
return data
@dataclass @dataclass
class OptimizationResult: class OptimizationResult:
@@ -45,3 +222,7 @@ class OptimizationResult:
final_score: float final_score: float
improvement: float improvement: float
history: list[dict[str, Any]] = field(default_factory=list) history: list[dict[str, Any]] = field(default_factory=list)
# Hold-out validation metrics (populated when validation_split > 0)
final_validation_score: float | None = None
best_validation_score: float | None = None
early_stopped: bool = False

View File

@@ -6,6 +6,9 @@ Combines candidate prompt execution + LLM-as-Judge evaluation.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import logging
from prometheus.domain.entities import ( from prometheus.domain.entities import (
EvalResult, EvalResult,
Prompt, Prompt,
@@ -13,6 +16,9 @@ from prometheus.domain.entities import (
Trajectory, Trajectory,
) )
from prometheus.domain.ports import JudgePort, LLMPort from prometheus.domain.ports import JudgePort, LLMPort
from prometheus.domain.scoring import normalize_score
logger = logging.getLogger(__name__)
class PromptEvaluator: class PromptEvaluator:
@@ -21,13 +27,23 @@ class PromptEvaluator:
Pipeline: execute → judge → build trajectories. Pipeline: execute → judge → build trajectories.
Replaces GEPA's EvaluatorFn. Instead of comparing to ground truth, Replaces GEPA's EvaluatorFn. Instead of comparing to ground truth,
uses an LLM-as-Judge. uses an LLM-as-Judge.
Execution and judge calls run in parallel (bounded by *max_concurrency*).
Per-call isolation: a failure on one minibatch item produces a
zero-score trajectory instead of crashing the whole batch.
""" """
def __init__(self, executor: LLMPort, judge: JudgePort): def __init__(
self,
executor: LLMPort,
judge: JudgePort,
max_concurrency: int = 5,
):
self._executor = executor self._executor = executor
self._judge = judge self._judge = judge
self._semaphore = asyncio.Semaphore(max_concurrency)
def evaluate( async def evaluate(
self, self,
prompt: Prompt, prompt: Prompt,
minibatch: list[SyntheticExample], minibatch: list[SyntheticExample],
@@ -36,19 +52,20 @@ class PromptEvaluator:
"""Evaluate the prompt on the minibatch. """Evaluate the prompt on the minibatch.
Steps: Steps:
1. Execute the prompt on each input in the minibatch 1. Execute the prompt on each input in the minibatch (parallel)
2. Judge each (input, output) pair 2. Judge each (input, output) pair
3. Build trajectories with feedback 3. Build trajectories with feedback
""" """
# Step 1: Execution # Step 1: Parallel execution (per-item isolation)
outputs: list[str] = [] output_coros = [
for example in minibatch: self._execute_single(prompt, example)
raw_output = self._executor.execute(prompt, example.input_text) for example in minibatch
outputs.append(raw_output) ]
outputs = await asyncio.gather(*output_coros)
# Step 2: Judgement # Step 2: Judgement (judge_adapter handles its own per-call isolation + parallelism)
pairs = [(ex.input_text, out) for ex, out in zip(minibatch, outputs)] pairs = [(ex.input_text, out) for ex, out in zip(minibatch, outputs)]
judge_results = self._judge.judge_batch(task_description, pairs) judge_results = await self._judge.judge_batch(task_description, pairs)
# Step 3: Build trajectories # Step 3: Build trajectories
scores: list[float] = [] scores: list[float] = []
@@ -56,6 +73,7 @@ class PromptEvaluator:
trajectories: list[Trajectory] = [] trajectories: list[Trajectory] = []
for i, (example, output) in enumerate(zip(minibatch, outputs)): for i, (example, output) in enumerate(zip(minibatch, outputs)):
score, feedback = judge_results[i] score, feedback = judge_results[i]
score = normalize_score(score)
scores.append(score) scores.append(score)
feedbacks.append(feedback) feedbacks.append(feedback)
trajectories.append( trajectories.append(
@@ -73,3 +91,17 @@ class PromptEvaluator:
feedbacks=feedbacks, feedbacks=feedbacks,
trajectories=trajectories, trajectories=trajectories,
) )
async def _execute_single(
self, prompt: Prompt, example: SyntheticExample
) -> str:
async with self._semaphore:
try:
return await self._executor.execute(prompt, example.input_text)
except Exception as exc:
logger.warning(
"Execution failed for input '%s': %s",
example.input_text[:40],
exc,
)
return f"[execution error: {exc}]"

View File

@@ -2,33 +2,51 @@
Evolution loop — core PROMETHEUS engine. Evolution loop — core PROMETHEUS engine.
Orchestrates the select → evaluate → propose → accept cycle. Orchestrates the select → evaluate → propose → accept cycle.
Equivalent to GEPAEngine.run(), adapted to work without a valset. Supports two modes:
- Single-candidate hill climbing (population_size=1, backward compat)
- Population-based evolution with crossover & mutation (population_size>1)
""" """
from __future__ import annotations from __future__ import annotations
import logging import logging
import random
from prometheus.application.bootstrap import SyntheticBootstrap from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.evaluator import PromptEvaluator from prometheus.application.evaluator import PromptEvaluator
from prometheus.cli.logging_setup import get_logger
from prometheus.domain.entities import ( from prometheus.domain.entities import (
Candidate, Candidate,
OptimizationState, OptimizationState,
Prompt, Prompt,
SyntheticExample, SyntheticExample,
) )
from prometheus.domain.ports import ProposerPort from prometheus.domain.ports import (
CheckpointPort,
CrossoverPort,
MutationPort,
ProposerPort,
)
from prometheus.domain.scoring import should_accept from prometheus.domain.scoring import should_accept
logger = logging.getLogger(__name__) logger = get_logger("evolution")
class CircuitBreakerOpen(Exception):
"""Raised when the circuit breaker trips due to too many consecutive failures."""
class EvolutionLoop: class EvolutionLoop:
"""Main evolution loop. """Main evolution loop.
Design: Design:
- Keeps only the best candidate (no full population). - population_size=1: classic single-candidate hill climbing (backward compat).
- Simplifies vs GEPA (no Pareto, no merge). - population_size>1: population-based evolution with crossover, mutation,
- Population support deferred to v2. and diversity maintenance.
Error handling:
- Transient errors are retried by adapters.
- Circuit breaker trips after N consecutive iteration failures.
- error_strategy controls what happens on non-transient errors.
""" """
def __init__( def __init__(
@@ -40,6 +58,19 @@ class EvolutionLoop:
minibatch_size: int = 5, minibatch_size: int = 5,
perfect_score: float = 1.0, perfect_score: float = 1.0,
verbose: bool = False, verbose: bool = False,
circuit_breaker_threshold: int = 5,
error_strategy: str = "retry",
checkpoint_port: CheckpointPort | None = None,
checkpoint_interval: int = 5,
# --- Population-based evolution params ---
population_size: int = 1,
crossover_rate: float = 0.5,
mutation_rate: float = 0.3,
diversity_penalty: float = 0.1,
crossover_port: CrossoverPort | None = None,
mutation_port: MutationPort | None = None,
# --- Hold-out validation params ---
early_stop_patience: int = 5,
): ):
self._evaluator = evaluator self._evaluator = evaluator
self._proposer = proposer self._proposer = proposer
@@ -48,53 +79,358 @@ class EvolutionLoop:
self._minibatch_size = minibatch_size self._minibatch_size = minibatch_size
self._perfect_score = perfect_score self._perfect_score = perfect_score
self._verbose = verbose self._verbose = verbose
self._circuit_breaker_threshold = circuit_breaker_threshold
self._error_strategy = error_strategy
self._checkpoint_port = checkpoint_port
self._checkpoint_interval = checkpoint_interval
self._population_size = population_size
self._crossover_rate = crossover_rate
self._mutation_rate = mutation_rate
self._diversity_penalty = diversity_penalty
self._crossover_port = crossover_port
self._mutation_port = mutation_port
self._early_stop_patience = early_stop_patience
def run( async def run(
self, self,
seed_prompt: Prompt, seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample], synthetic_pool: list[SyntheticExample],
task_description: str, task_description: str,
initial_state: OptimizationState | None = None,
validation_pool: list[SyntheticExample] | None = None,
) -> OptimizationState: ) -> OptimizationState:
"""Execute the complete evolution loop.""" """Execute the complete evolution loop.
state = OptimizationState()
# Evaluate the seed If *initial_state* is provided (from a checkpoint), resume from that
point — skipping the seed evaluation and continuing at the saved iteration.
If *validation_pool* is provided (non-empty), the best candidate is
evaluated on the hold-out set after each iteration and early stopping
is applied when validation score degrades for ``early_stop_patience``
consecutive iterations.
"""
state = initial_state or OptimizationState()
consecutive_failures = 0
# Hold-out validation tracking
has_validation = bool(validation_pool)
best_validation_score: float = -1.0
validation_patience_counter: int = 0
# Only evaluate the seed when starting fresh (no checkpoint resume)
if initial_state is None:
initial_batch = self._bootstrap.sample_minibatch( initial_batch = self._bootstrap.sample_minibatch(
synthetic_pool, self._minibatch_size synthetic_pool, self._minibatch_size
) )
initial_eval = self._evaluator.evaluate( initial_eval = await self._evaluator.evaluate(
seed_prompt, initial_batch, task_description seed_prompt, initial_batch, task_description
) )
state.total_llm_calls += 2 * self._minibatch_size # N executions + N judge calls state.total_llm_calls += 2 * self._minibatch_size # N executions + N judge calls
best_candidate = Candidate( seed_candidate = Candidate(
prompt=seed_prompt, prompt=seed_prompt,
best_score=initial_eval.total_score, best_score=initial_eval.total_score,
generation=0, generation=0,
) )
state.best_candidate = best_candidate state.best_candidate = seed_candidate
state.candidates.append(best_candidate) state.candidates.append(seed_candidate)
self._log(f"Initial score: {initial_eval.total_score:.2f}") logger.info(
"Initial evaluation complete",
extra={
"structured": {
"event": "initial_eval",
"score": round(initial_eval.total_score, 4),
"minibatch_size": self._minibatch_size,
"sample_ids": [ex.id for ex in initial_batch],
},
},
)
# Evaluate seed on validation set
if has_validation and state.best_candidate is not None:
val_eval = await self._evaluator.evaluate(
state.best_candidate.prompt, validation_pool, task_description
)
state.total_llm_calls += 2 * len(validation_pool)
best_validation_score = val_eval.mean_score
logger.info(
"Initial validation evaluation",
extra={
"structured": {
"event": "validation_eval",
"iteration": 0,
"validation_score": round(best_validation_score, 4),
"validation_pool_size": len(validation_pool),
},
},
)
# Population initialization: seed the population with mutations
if self._population_size > 1:
await self._initialize_population(
state, seed_prompt, seed_candidate, task_description
)
else:
logger.info(
"Resuming from checkpoint",
extra={
"structured": {
"event": "resume",
"iteration": state.iteration,
"total_llm_calls": state.total_llm_calls,
},
},
)
# Restore validation tracking from state history
if has_validation:
for entry in reversed(state.history):
if entry.get("event") == "validation_eval":
best_validation_score = entry.get("best_validation_score", -1.0)
validation_patience_counter = entry.get("validation_patience", 0)
break
# Determine starting iteration
start_iteration = state.iteration + 1
# Main loop # Main loop
for i in range(1, self._max_iterations + 1): for i in range(start_iteration, self._max_iterations + 1):
state.iteration = i state.iteration = i
try: try:
if self._population_size > 1 and len(state.candidates) > 1:
await self._run_population_iteration(
i, state, synthetic_pool, task_description
)
else:
await self._run_single_iteration(
i, state, synthetic_pool, task_description
)
consecutive_failures = 0
# Hold-out validation: evaluate best candidate on validation set
if has_validation and state.best_candidate is not None:
val_eval = await self._evaluator.evaluate(
state.best_candidate.prompt, validation_pool, task_description
)
state.total_llm_calls += 2 * len(validation_pool)
current_val_score = val_eval.mean_score
if current_val_score > best_validation_score:
best_validation_score = current_val_score
validation_patience_counter = 0
else:
validation_patience_counter += 1
state.history.append({
"iteration": i,
"event": "validation_eval",
"validation_score": round(current_val_score, 4),
"best_validation_score": round(best_validation_score, 4),
"validation_patience": validation_patience_counter,
})
logger.info(
"Validation evaluation",
extra={
"structured": {
"event": "validation_eval",
"iteration": i,
"validation_score": round(current_val_score, 4),
"best_validation_score": round(best_validation_score, 4),
"patience": f"{validation_patience_counter}/{self._early_stop_patience}",
},
},
)
if validation_patience_counter >= self._early_stop_patience:
logger.warning(
"Early stopping triggered — validation score did not improve for %d iterations",
self._early_stop_patience,
extra={
"structured": {
"event": "early_stop",
"iteration": i,
"best_validation_score": round(best_validation_score, 4),
"patience": self._early_stop_patience,
},
},
)
state.history.append({
"iteration": i,
"event": "early_stop",
"best_validation_score": round(best_validation_score, 4),
"patience": self._early_stop_patience,
})
state.best_validation_score = best_validation_score
state.early_stopped = True
if self._checkpoint_port is not None:
self._checkpoint_port.save(state)
break
# Checkpoint on accepted improvement (detected via state change)
self._maybe_checkpoint(state)
except Exception as exc:
consecutive_failures += 1
logger.error(
"Iteration error",
extra={
"structured": {
"event": "iteration_error",
"iteration": i,
"consecutive_failures": consecutive_failures,
"error": str(exc),
},
},
exc_info=True,
)
state.history.append(
{
"iteration": i,
"event": "error",
"error": str(exc),
"consecutive_failures": consecutive_failures,
}
)
# Check circuit breaker
if consecutive_failures >= self._circuit_breaker_threshold:
logger.warning(
"Circuit breaker tripped",
extra={
"structured": {
"event": "circuit_breaker",
"iteration": i,
"consecutive_failures": consecutive_failures,
"error_strategy": self._error_strategy,
},
},
)
state.history.append(
{
"iteration": i,
"event": "circuit_breaker",
"consecutive_failures": consecutive_failures,
}
)
if self._error_strategy == "abort":
raise CircuitBreakerOpen(
f"Circuit breaker tripped after "
f"{consecutive_failures} consecutive failures"
) from exc
# skip / retry strategies: save checkpoint, then stop the loop gracefully
if self._checkpoint_port is not None:
self._checkpoint_port.save(state)
break
if self._error_strategy == "abort":
raise
# skip / retry: continue to next iteration
continue
# Store final validation metadata on state
if has_validation:
state.best_validation_score = best_validation_score
return state
# ------------------------------------------------------------------
# Population initialization
# ------------------------------------------------------------------
async def _initialize_population(
self,
state: OptimizationState,
seed_prompt: Prompt,
seed_candidate: Candidate,
task_description: str,
) -> None:
"""Fill the population with mutated variants of the seed prompt."""
n_needed = self._population_size - 1
mutation_types = ["paraphrase", "constrain", "generalize", "specialize"]
for idx in range(n_needed):
mutation_type = mutation_types[idx % len(mutation_types)]
if self._mutation_port is not None:
new_prompt = await self._mutation_port.mutate(
seed_prompt, task_description, mutation_type
)
else:
# Fallback: use proposer for reflective mutation
new_prompt = await self._proposer.propose(
seed_prompt, [], task_description
)
state.total_llm_calls += 1
new_candidate = Candidate(
prompt=new_prompt,
best_score=seed_candidate.best_score, # estimate until evaluated
generation=0,
parent_id=id(seed_candidate),
)
state.candidates.append(new_candidate)
logger.info(
"Population initialized",
extra={
"structured": {
"event": "population_init",
"population_size": len(state.candidates),
},
},
)
# ------------------------------------------------------------------
# Single-candidate iteration (original hill-climbing)
# ------------------------------------------------------------------
async def _run_single_iteration(
self,
i: int,
state: OptimizationState,
synthetic_pool: list[SyntheticExample],
task_description: str,
) -> None:
"""Execute a single-candidate iteration. Mutates *state* in-place."""
best_candidate = state.best_candidate # type: ignore[assignment]
# 1. Sample a fresh minibatch # 1. Sample a fresh minibatch
batch = self._bootstrap.sample_minibatch( batch = self._bootstrap.sample_minibatch(
synthetic_pool, self._minibatch_size synthetic_pool, self._minibatch_size
) )
sample_ids = [ex.id for ex in batch]
# 2. Evaluate the current candidate # 2. Evaluate the current candidate
current_eval = self._evaluator.evaluate( current_eval = await self._evaluator.evaluate(
best_candidate.prompt, batch, task_description best_candidate.prompt, batch, task_description
) )
state.total_llm_calls += 2 * self._minibatch_size state.total_llm_calls += 2 * self._minibatch_size
logger.debug(
"Iteration minibatch evaluated",
extra={
"structured": {
"event": "minibatch_eval",
"iteration": i,
"sample_ids": sample_ids,
"scores": [round(s, 4) for s in current_eval.scores],
"total_score": round(current_eval.total_score, 4),
},
},
)
# 3. Skip if perfect # 3. Skip if perfect
if all(s >= self._perfect_score for s in current_eval.scores): if all(s >= self._perfect_score for s in current_eval.scores):
self._log(f"Iter {i}: All scores perfect, skipping.") logger.info(
"Iteration skipped — all scores perfect",
extra={
"structured": {
"event": "skip_perfect",
"iteration": i,
"total_score": round(current_eval.total_score, 4),
},
},
)
state.history.append( state.history.append(
{ {
"iteration": i, "iteration": i,
@@ -102,35 +438,62 @@ class EvolutionLoop:
"current_score": current_eval.total_score, "current_score": current_eval.total_score,
} }
) )
continue return
# 4. Propose a new prompt (reflective mutation) # 4. Propose a new prompt (reflective mutation) — sequential
new_prompt = self._proposer.propose( state.total_llm_calls += 1
new_prompt = await self._proposer.propose(
best_candidate.prompt, best_candidate.prompt,
current_eval.trajectories, current_eval.trajectories,
task_description, task_description,
) )
state.total_llm_calls += 1 # 1 proposition call
prompt_diff = self._compute_prompt_diff(
best_candidate.prompt.text, new_prompt.text
)
logger.debug(
"Proposed new prompt",
extra={
"structured": {
"event": "proposer_output",
"iteration": i,
"prompt_diff": prompt_diff,
},
},
)
# 5. Evaluate the new prompt on the same minibatch # 5. Evaluate the new prompt on the same minibatch
new_eval = self._evaluator.evaluate( new_eval = await self._evaluator.evaluate(
new_prompt, batch, task_description new_prompt, batch, task_description
) )
state.total_llm_calls += 2 * self._minibatch_size state.total_llm_calls += 2 * self._minibatch_size
# 6. Accept or reject # 6. Accept or reject
if should_accept(current_eval, new_eval): if should_accept(current_eval, new_eval):
best_candidate = Candidate( new_candidate = Candidate(
prompt=new_prompt, prompt=new_prompt,
best_score=new_eval.total_score, best_score=new_eval.total_score,
generation=i, generation=i,
parent_id=id(best_candidate), parent_id=id(best_candidate),
) )
state.best_candidate = best_candidate state.best_candidate = new_candidate
state.candidates.append(best_candidate) state.candidates.append(new_candidate)
self._log( logger.info(
f"Iter {i}: ACCEPTED " "Iteration accepted",
f"({current_eval.total_score:.2f} -> {new_eval.total_score:.2f})" extra={
"structured": {
"event": "accepted",
"iteration": i,
"old_score": round(current_eval.total_score, 4),
"new_score": round(new_eval.total_score, 4),
"improvement": round(
new_eval.total_score - current_eval.total_score, 4
),
"sample_ids": sample_ids,
"new_scores": [round(s, 4) for s in new_eval.scores],
"prompt_diff": prompt_diff,
},
},
) )
state.history.append( state.history.append(
{ {
@@ -143,9 +506,19 @@ class EvolutionLoop:
} }
) )
else: else:
self._log( logger.info(
f"Iter {i}: REJECTED " "Iteration rejected",
f"({new_eval.total_score:.2f} <= {current_eval.total_score:.2f})" extra={
"structured": {
"event": "rejected",
"iteration": i,
"old_score": round(current_eval.total_score, 4),
"new_score": round(new_eval.total_score, 4),
"sample_ids": sample_ids,
"new_scores": [round(s, 4) for s in new_eval.scores],
"prompt_diff": prompt_diff,
},
},
) )
state.history.append( state.history.append(
{ {
@@ -156,19 +529,213 @@ class EvolutionLoop:
} }
) )
except Exception as exc: # ------------------------------------------------------------------
self._log(f"Iter {i}: ERROR — {exc}. Skipping iteration.") # Population-based iteration
# ------------------------------------------------------------------
async def _run_population_iteration(
self,
i: int,
state: OptimizationState,
synthetic_pool: list[SyntheticExample],
task_description: str,
) -> None:
"""Execute a population-based iteration. Mutates *state* in-place."""
population = state.candidates
# 1. Sample a fresh minibatch
batch = self._bootstrap.sample_minibatch(
synthetic_pool, self._minibatch_size
)
sample_ids = [ex.id for ex in batch]
# 2. Select two parents via tournament selection
parent_a = self._tournament_select(population)
parent_b = self._tournament_select(population)
# 3. Generate child: crossover or reflective mutation
use_crossover = (
random.random() < self._crossover_rate
and self._crossover_port is not None
)
if use_crossover:
state.total_llm_calls += 1
child_prompt = await self._crossover_port.crossover(
parent_a.prompt, parent_b.prompt, task_description
)
origin = "crossover"
else:
# Reflective mutation: evaluate a parent, propose improvement
state.total_llm_calls += 2 * self._minibatch_size
parent_eval = await self._evaluator.evaluate(
parent_a.prompt, batch, task_description
)
state.total_llm_calls += 1
child_prompt = await self._proposer.propose(
parent_a.prompt,
parent_eval.trajectories,
task_description,
)
origin = "reflective"
# 4. Optional mutation
if random.random() < self._mutation_rate and self._mutation_port is not None:
mutation_type = random.choice(
["paraphrase", "constrain", "generalize", "specialize"]
)
state.total_llm_calls += 1
child_prompt = await self._mutation_port.mutate(
child_prompt, task_description, mutation_type
)
origin += f"+mutation({mutation_type})"
# 5. Evaluate the child
state.total_llm_calls += 2 * self._minibatch_size
child_eval = await self._evaluator.evaluate(
child_prompt, batch, task_description
)
# 6. Compute fitness with diversity penalty
child_score = child_eval.total_score
diversity_sim = self._compute_diversity_score(
child_prompt, population
)
child_fitness = child_score - self._diversity_penalty * (1.0 - diversity_sim)
# 7. Find worst candidate and replace if child is better
worst_idx = min(
range(len(population)),
key=lambda idx: population[idx].best_score,
)
worst_fitness = (
population[worst_idx].best_score
- self._diversity_penalty * (1.0 - self._compute_diversity_score(
population[worst_idx].prompt, population
))
)
accepted = child_fitness > worst_fitness
if accepted:
new_candidate = Candidate(
prompt=child_prompt,
best_score=child_score,
generation=i,
parent_id=id(parent_a),
)
population[worst_idx] = new_candidate
# Update best if this child is the new best
if child_score > (state.best_candidate.best_score if state.best_candidate else 0):
state.best_candidate = new_candidate
logger.info(
"Population iteration accepted",
extra={
"structured": {
"event": "pop_accepted",
"iteration": i,
"origin": origin,
"child_score": round(child_score, 4),
"child_fitness": round(child_fitness, 4),
"diversity_sim": round(diversity_sim, 4),
"replaced_idx": worst_idx,
"sample_ids": sample_ids,
},
},
)
state.history.append( state.history.append(
{ {
"iteration": i, "iteration": i,
"event": "error", "event": "pop_accepted",
"error": str(exc), "origin": origin,
"child_score": child_score,
"child_fitness": child_fitness,
"diversity_sim": diversity_sim,
} }
) )
else:
logger.info(
"Population iteration rejected",
extra={
"structured": {
"event": "pop_rejected",
"iteration": i,
"origin": origin,
"child_score": round(child_score, 4),
"child_fitness": round(child_fitness, 4),
"worst_fitness": round(worst_fitness, 4),
"sample_ids": sample_ids,
},
},
)
state.history.append(
{
"iteration": i,
"event": "pop_rejected",
"origin": origin,
"child_score": child_score,
"child_fitness": child_fitness,
"worst_fitness": worst_fitness,
}
)
# ------------------------------------------------------------------
# Selection and diversity helpers
# ------------------------------------------------------------------
def _tournament_select(
self,
population: list[Candidate],
tournament_size: int = 3,
) -> Candidate:
"""Tournament selection: pick the best from a random subset."""
k = min(tournament_size, len(population))
contestants = random.sample(population, k)
return max(contestants, key=lambda c: c.best_score)
@staticmethod
def _compute_diversity_score(
prompt: Prompt,
population: list[Candidate],
) -> float:
"""Compute the average Jaccard similarity between *prompt* and all
population members. Returns 1.0 when population has only one member
(no diversity penalty)."""
if len(population) <= 1:
return 1.0
prompt_words = set(prompt.text.lower().split())
if not prompt_words:
return 0.0
similarities: list[float] = []
for candidate in population:
other_words = set(candidate.prompt.text.lower().split())
if not other_words:
continue continue
intersection = prompt_words & other_words
union = prompt_words | other_words
sim = len(intersection) / len(union) if union else 0.0
similarities.append(sim)
return state # Average similarity (lower = more diverse)
return sum(similarities) / len(similarities) if similarities else 0.0
def _log(self, msg: str) -> None: @staticmethod
if self._verbose: def _compute_prompt_diff(old: str, new: str) -> dict[str, int]:
logger.info("[PROMETHEUS] %s", msg) """Compute a simple diff summary between two prompts."""
old_lines = set(old.splitlines())
new_lines = set(new.splitlines())
return {
"lines_added": len(new_lines - old_lines),
"lines_removed": len(old_lines - new_lines),
"chars_delta": len(new) - len(old),
}
def _maybe_checkpoint(self, state: OptimizationState) -> None:
"""Save a checkpoint if the interval is met or on accepted improvements."""
if self._checkpoint_port is None:
return
if state.iteration % self._checkpoint_interval == 0:
self._checkpoint_port.save(state)

View File

@@ -0,0 +1,116 @@
"""
Ground-truth evaluator — execution + similarity comparison.
Produces a quality signal *with* ground truth by comparing model outputs
against expected outputs using a configurable similarity metric.
"""
from __future__ import annotations
import asyncio
import logging
from prometheus.domain.entities import (
EvalResult,
GroundTruthExample,
Prompt,
Trajectory,
)
from prometheus.domain.ports import LLMPort, SimilarityPort
logger = logging.getLogger(__name__)
class GroundTruthEvaluator:
"""Evaluates a prompt against a ground-truth dataset.
Pipeline: execute → compare with similarity metric → build trajectories.
Unlike PromptEvaluator (which uses LLM-as-Judge), this compares outputs
directly against known-good expected outputs.
"""
def __init__(
self,
executor: LLMPort,
similarity: SimilarityPort,
max_concurrency: int = 5,
):
self._executor = executor
self._similarity = similarity
self._semaphore = asyncio.Semaphore(max_concurrency)
async def evaluate(
self,
prompt: Prompt,
dataset: list[GroundTruthExample],
) -> EvalResult:
"""Evaluate the prompt on the ground-truth dataset.
Steps:
1. Execute the prompt on each input (parallel, bounded)
2. Compare each output against expected using similarity metric
3. Build trajectories with feedback
"""
# Step 1: Parallel execution (per-item isolation)
output_coros = [
self._execute_single(prompt, example) for example in dataset
]
outputs = await asyncio.gather(*output_coros)
# Step 2: Compute similarity scores
scores: list[float] = []
feedbacks: list[str] = []
trajectories: list[Trajectory] = []
for example, output in zip(dataset, outputs):
score = self._similarity.compute(output, example.expected_output)
score = max(0.0, min(1.0, score)) # normalize to [0, 1]
scores.append(score)
feedback = self._build_feedback(output, example.expected_output, score)
feedbacks.append(feedback)
trajectories.append(
Trajectory(
input_text=example.input_text,
output_text=output,
score=score,
feedback=feedback,
prompt_used=prompt.text,
)
)
logger.info(
"Ground-truth evaluation complete: %d items, mean_score=%.4f",
len(dataset),
sum(scores) / len(scores) if scores else 0.0,
)
return EvalResult(
scores=scores,
feedbacks=feedbacks,
trajectories=trajectories,
)
async def _execute_single(
self, prompt: Prompt, example: GroundTruthExample
) -> str:
async with self._semaphore:
try:
return await self._executor.execute(prompt, example.input_text)
except Exception as exc:
logger.warning(
"Execution failed for input '%s': %s",
example.input_text[:40],
exc,
)
return f"[execution error: {exc}]"
@staticmethod
def _build_feedback(output: str, expected: str, score: float) -> str:
"""Build human-readable feedback for a ground-truth comparison."""
if score >= 0.99:
return "Exact match."
elif score >= 0.7:
return f"Close match (score={score:.2f}). Expected: {expected[:100]}"
elif score >= 0.3:
return f"Partial match (score={score:.2f}). Expected: {expected[:100]}"
else:
return f"Poor match (score={score:.2f}). Expected: {expected[:100]}"

View File

@@ -10,8 +10,16 @@ from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.dto import OptimizationConfig, OptimizationResult from prometheus.application.dto import OptimizationConfig, OptimizationResult
from prometheus.application.evaluator import PromptEvaluator from prometheus.application.evaluator import PromptEvaluator
from prometheus.application.evolution import EvolutionLoop from prometheus.application.evolution import EvolutionLoop
from prometheus.cli.logging_setup import get_logger
from prometheus.domain.entities import Prompt from prometheus.domain.entities import Prompt
from prometheus.domain.ports import ProposerPort from prometheus.domain.ports import (
CheckpointPort,
CrossoverPort,
MutationPort,
ProposerPort,
)
logger = get_logger("use_cases")
class OptimizePromptUseCase: class OptimizePromptUseCase:
@@ -25,24 +33,60 @@ class OptimizePromptUseCase:
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
proposer: ProposerPort, proposer: ProposerPort,
bootstrap: SyntheticBootstrap, bootstrap: SyntheticBootstrap,
checkpoint_port: CheckpointPort | None = None,
crossover_port: CrossoverPort | None = None,
mutation_port: MutationPort | None = None,
): ):
self._evaluator = evaluator self._evaluator = evaluator
self._proposer = proposer self._proposer = proposer
self._bootstrap = bootstrap self._bootstrap = bootstrap
self._checkpoint_port = checkpoint_port
self._crossover_port = crossover_port
self._mutation_port = mutation_port
def execute(self, config: OptimizationConfig) -> OptimizationResult: async def execute(self, config: OptimizationConfig) -> OptimizationResult:
"""Full pipeline: """Full pipeline:
1. Bootstrap → generate synthetic inputs 1. Bootstrap → generate synthetic inputs
2. Evolution → optimization loop 2. Evolution → optimization loop (with optional checkpoint resume)
3. Return result 3. Return result
""" """
# Phase 0: Bootstrap # Phase 0: Bootstrap (skip synthetic generation on resume if pool was saved)
initial_state = None
if config.resume and self._checkpoint_port is not None:
initial_state = self._checkpoint_port.load()
if initial_state is not None and initial_state.synthetic_pool:
synthetic_pool = initial_state.synthetic_pool
logger.info(
"Resumed checkpoint includes %d synthetic inputs — skipping bootstrap",
extra={"structured": {"event": "resume_skip_bootstrap", "pool_size": len(synthetic_pool)}},
)
else:
synthetic_pool = self._bootstrap.run(
task_description=config.task_description,
n_examples=config.n_synthetic_inputs,
)
else:
synthetic_pool = self._bootstrap.run( synthetic_pool = self._bootstrap.run(
task_description=config.task_description, task_description=config.task_description,
n_examples=config.n_synthetic_inputs, n_examples=config.n_synthetic_inputs,
) )
# Phase 1: Evolution # Split into train / validation if configured
validation_pool: list = []
if config.validation_split > 0:
synthetic_pool, validation_pool = SyntheticBootstrap.split_pool(
synthetic_pool, config.validation_split,
)
logger.info(
"Split synthetic pool: %d train, %d validation (%.0f%% hold-out)",
len(synthetic_pool), len(validation_pool),
config.validation_split * 100,
extra={"structured": {
"event": "pool_split",
"train_size": len(synthetic_pool),
"val_size": len(validation_pool),
}},
)
loop = EvolutionLoop( loop = EvolutionLoop(
evaluator=self._evaluator, evaluator=self._evaluator,
proposer=self._proposer, proposer=self._proposer,
@@ -51,9 +95,24 @@ class OptimizePromptUseCase:
minibatch_size=config.minibatch_size, minibatch_size=config.minibatch_size,
perfect_score=config.perfect_score, perfect_score=config.perfect_score,
verbose=config.verbose, verbose=config.verbose,
circuit_breaker_threshold=config.circuit_breaker_threshold,
error_strategy=config.error_strategy,
checkpoint_port=self._checkpoint_port,
checkpoint_interval=config.checkpoint_interval,
population_size=config.population_size,
crossover_rate=config.crossover_rate,
mutation_rate=config.mutation_rate,
diversity_penalty=config.diversity_penalty,
crossover_port=self._crossover_port,
mutation_port=self._mutation_port,
early_stop_patience=config.early_stop_patience,
) )
seed_prompt = Prompt(text=config.seed_prompt) seed_prompt = Prompt(text=config.seed_prompt)
state = loop.run(seed_prompt, synthetic_pool, config.task_description) state = await loop.run(
seed_prompt, synthetic_pool, config.task_description,
initial_state=initial_state,
validation_pool=validation_pool or None,
)
# Phase 2: Result # Phase 2: Result
initial_score = ( initial_score = (
@@ -69,9 +128,12 @@ class OptimizePromptUseCase:
), ),
initial_prompt=config.seed_prompt, initial_prompt=config.seed_prompt,
iterations_used=state.iteration, iterations_used=state.iteration,
total_llm_calls=state.total_llm_calls + 1, # +1 for bootstrap total_llm_calls=state.total_llm_calls + 1, # +1 for bootstrap synthesis call
initial_score=initial_score, initial_score=initial_score,
final_score=final_score, final_score=final_score,
improvement=final_score - initial_score, improvement=final_score - initial_score,
history=state.history, history=state.history,
final_validation_score=state.best_validation_score,
best_validation_score=state.best_validation_score,
early_stopped=state.early_stopped,
) )

View File

@@ -1,29 +1,13 @@
""" """
CLI — user entry point. CLI — user entry point.
Typer interface with -i (input) and -o (output) options. Registers all subcommands and delegates to cli/commands/.
""" """
from __future__ import annotations from __future__ import annotations
import logging
import os
from dataclasses import asdict
import dspy
import typer import typer
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from prometheus.application.bootstrap import SyntheticBootstrap from prometheus.cli.commands import init, list_runs, optimize, version
from prometheus.application.dto import OptimizationConfig, OptimizationResult
from prometheus.application.evaluator import PromptEvaluator
from prometheus.application.use_cases import OptimizePromptUseCase
from prometheus.infrastructure.file_io import YamlPersistence
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
app = typer.Typer( app = typer.Typer(
name="prometheus", name="prometheus",
@@ -31,137 +15,12 @@ app = typer.Typer(
no_args_is_help=True, no_args_is_help=True,
) )
console = Console() # Register all subcommands — having multiple commands fixes the
# Typer 0.24+ bug where a single-command app absorbs the subcommand.
optimize.register(app)
@app.command() version.register(app)
def optimize( init.register(app)
input: str = typer.Option( list_runs.register(app)
...,
"-i",
"--input",
help="Path to input YAML config file.",
exists=True,
readable=True,
),
output: str = typer.Option(
"output.yaml",
"-o",
"--output",
help="Path to output YAML result file.",
),
verbose: bool = typer.Option(
False,
"-v",
"--verbose",
help="Print detailed progress.",
),
) -> None:
"""Optimize a prompt without any reference data.
Usage:
prometheus optimize -i config.yaml -o result.yaml
"""
# Configure verbose logging
if verbose:
logging.basicConfig(level=logging.INFO, format="[PROMETHEUS] %(message)s")
console.print(
Panel.fit(
"PROMETHEUS — Prompt Evolution Engine",
subtitle="No reference data required",
)
)
# 1. Load config
persistence = YamlPersistence()
raw_config = persistence.read_config(input)
config = OptimizationConfig(
seed_prompt=raw_config["seed_prompt"],
task_description=raw_config["task_description"],
task_model=raw_config.get("task_model", "openai/gpt-4o-mini"),
judge_model=raw_config.get("judge_model", "openai/gpt-4o"),
proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"),
synth_model=raw_config.get("synth_model", "openai/gpt-4o"),
max_iterations=raw_config.get("max_iterations", 30),
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
minibatch_size=raw_config.get("minibatch_size", 5),
seed=raw_config.get("seed", 42),
output_path=output,
verbose=verbose,
)
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
# 2. Configure DSPy with optional api_base/api_key from config
lm_kwargs: dict = {}
api_base = raw_config.get("api_base")
api_key_env = raw_config.get("api_key_env")
if api_base:
lm_kwargs["api_base"] = api_base
if api_key_env:
lm_kwargs["api_key"] = os.environ.get(api_key_env, "")
task_lm = dspy.LM(config.task_model, **lm_kwargs)
dspy.configure(lm=task_lm)
# 3. Build adapters (Dependency Injection)
synth_adapter = DSPySyntheticAdapter()
llm_adapter = DSPyLLMAdapter(model=config.task_model)
judge_adapter = DSPyJudgeAdapter()
proposer_adapter = DSPyProposerAdapter()
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter)
use_case = OptimizePromptUseCase(
evaluator=evaluator,
proposer=proposer_adapter,
bootstrap=bootstrap,
)
# 4. Execute
with console.status("[bold green]Evolving prompt..."):
result = use_case.execute(config)
# 5. Display results
_display_result(result)
# 6. Save
_save_result(persistence, output, result)
console.print(f"\n[green]Results saved to {output}[/green]")
def _display_result(result: OptimizationResult) -> None:
"""Display a Rich summary in the terminal."""
console.print()
console.print(
Panel(
f"[bold green]Optimized Prompt[/bold green]\n\n{result.optimized_prompt}",
title="Result",
)
)
table = Table(title="Metrics")
table.add_column("Metric", style="cyan")
table.add_column("Value", style="bold")
table.add_row("Initial Score", f"{result.initial_score:.2f}")
table.add_row("Final Score", f"{result.final_score:.2f}")
table.add_row("Improvement", f"{result.improvement:+.2f}")
table.add_row("Iterations", str(result.iterations_used))
table.add_row("LLM Calls", str(result.total_llm_calls))
console.print(table)
def _save_result(
persistence: YamlPersistence,
path: str,
result: OptimizationResult,
) -> None:
"""Save the result as YAML."""
persistence.write_result(path, asdict(result))
@app.command(hidden=True)
def _help() -> None:
"""Internal placeholder to force multi-command Typer behavior."""
pass
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -0,0 +1 @@
"""CLI command modules."""

View File

@@ -0,0 +1,97 @@
"""prometheus init — scaffold a config YAML interactively."""
from __future__ import annotations
from pathlib import Path
import typer
import yaml
from rich.console import Console
console = Console()
_TEMPLATE = """\
# PROMETHEUS configuration
# Generated by `prometheus init`
# --- Required ---
seed_prompt: {seed_prompt}
task_description: {task_description}
# --- Models ---
task_model: {task_model}
judge_model: {judge_model}
proposer_model: {proposer_model}
synth_model: {synth_model}
# --- Global API settings (optional) ---
# api_base: https://api.openai.com/v1
# api_key_env: OPENAI_API_KEY
# --- Evolution parameters ---
max_iterations: 30
n_synthetic_inputs: 20
minibatch_size: 5
perfect_score: 1.0
seed: 42
# --- Concurrency ---
max_concurrency: 5
# --- Error handling ---
max_retries: 3
retry_delay_base: 1.0
circuit_breaker_threshold: 5
error_strategy: retry
"""
def register(app: typer.Typer) -> None:
"""Register the init command on the Typer app."""
@app.command()
def init(
output: str = typer.Option(
"config.yaml",
"-o",
"--output",
help="Path for the generated config file.",
),
) -> None:
"""Interactively scaffold a PROMETHEUS config YAML.
Prompts for required fields and writes a ready-to-edit config file.
"""
target = Path(output)
if target.exists() and not typer.confirm(
f"{output} already exists. Overwrite?", default=False
):
raise typer.Exit(code=0)
seed_prompt: str = typer.prompt("Seed prompt")
task_description: str = typer.prompt("Task description")
task_model: str = typer.prompt("Task model", default="openai/gpt-4o-mini")
judge_model: str = typer.prompt("Judge model", default="openai/gpt-4o")
proposer_model: str = typer.prompt("Proposer model", default="openai/gpt-4o")
synth_model: str = typer.prompt("Synth model", default="openai/gpt-4o")
content = _TEMPLATE.format(
seed_prompt=_yaml_string(seed_prompt),
task_description=_yaml_string(task_description),
task_model=task_model,
judge_model=judge_model,
proposer_model=proposer_model,
synth_model=synth_model,
)
target.write_text(content, encoding="utf-8")
console.print(f"[green]Config written to {output}[/green]")
console.print("[dim]Edit it as needed, then run: prometheus optimize -i config.yaml[/dim]")
def _yaml_string(value: str) -> str:
"""Quote a string for YAML if it contains special characters."""
if any(ch in value for ch in (":", "#", "'", '"', "\n", "{", "}", "[", "]", ",")):
escaped = value.replace("'", "''")
return f"'{escaped}'"
return value

View File

@@ -0,0 +1,101 @@
"""prometheus list — list past optimization runs."""
from __future__ import annotations
import glob as globmod
from pathlib import Path
import typer
import yaml
from rich.console import Console
from rich.table import Table
console = Console()
_DEFAULT_PATTERNS = ("output.yaml", "results/*.yaml", "*.result.yaml")
def register(app: typer.Typer) -> None:
"""Register the list command on the Typer app."""
@app.command("list")
def list_runs(
directory: str = typer.Option(
".",
"-d",
"--directory",
help="Directory to scan for result YAML files.",
),
) -> None:
"""List past optimization runs found in result YAML files.
Scans the given directory for YAML files that look like PROMETHEUS
output (they contain 'optimized_prompt' and 'final_score' keys) and
displays a summary table.
"""
base = Path(directory)
if not base.is_dir():
console.print(f"[red]Directory not found: {directory}[/red]")
raise typer.Exit(code=1)
runs: list[dict] = []
for pattern in _DEFAULT_PATTERNS:
for path_str in globmod.glob(str(base / pattern), recursive=False):
_try_read_run(path_str, runs)
# Also try nested directories one level deep
for path_str in globmod.glob(str(base / "**/*.yaml"), recursive=True):
_try_read_run(path_str, runs)
if not runs:
console.print("[dim]No optimization runs found.[/dim]")
raise typer.Exit(code=0)
# Deduplicate by path
seen: set[str] = set()
unique_runs: list[dict] = []
for run in runs:
if run["path"] not in seen:
seen.add(run["path"])
unique_runs.append(run)
table = Table(title="PROMETHEUS Runs")
table.add_column("File", style="cyan")
table.add_column("Initial", justify="right")
table.add_column("Final", justify="right", style="green")
table.add_column("Delta", justify="right")
table.add_column("Iters", justify="right")
table.add_column("Prompt (first 60 chars)", style="dim")
for run in sorted(unique_runs, key=lambda r: r["path"]):
table.add_row(
run["path"],
f"{run['initial_score']:.2f}",
f"{run['final_score']:.2f}",
f"{run['improvement']:+.2f}",
str(run["iterations"]),
run["prompt_preview"],
)
console.print(table)
def _try_read_run(path_str: str, runs: list[dict]) -> None:
"""Try to parse a YAML file as a PROMETHEUS result and append metadata."""
try:
with open(path_str, encoding="utf-8") as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
return
if "optimized_prompt" not in data or "final_score" not in data:
return
runs.append({
"path": path_str,
"initial_score": float(data.get("initial_score", 0.0)),
"final_score": float(data.get("final_score", 0.0)),
"improvement": float(data.get("improvement", 0.0)),
"iterations": int(data.get("iterations_used", 0)),
"prompt_preview": str(data.get("optimized_prompt", ""))[:60],
})
except (OSError, yaml.YAMLError, ValueError, TypeError):
pass

View File

@@ -0,0 +1,449 @@
"""prometheus optimize — run prompt optimization."""
from __future__ import annotations
import asyncio
import logging
import os
from dataclasses import asdict
import dspy
import typer
from pydantic import ValidationError
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.dto import JudgeDimension, OptimizationConfig, OptimizationResult
from prometheus.domain.entities import EvalResult, Prompt
from prometheus.application.evaluator import PromptEvaluator
from prometheus.application.ground_truth_evaluator import GroundTruthEvaluator
from prometheus.application.use_cases import OptimizePromptUseCase
from prometheus.cli.logging_setup import configure_logging
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
from prometheus.infrastructure.checkpoint import JsonCheckpointPersistence
from prometheus.infrastructure.file_io import YamlPersistence
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
from prometheus.infrastructure.crossover_adapter import DSPyCrossoverAdapter
from prometheus.infrastructure.mutation_adapter import DSPyMutationAdapter
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
from prometheus.infrastructure.similarity import create_similarity_adapter
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
console = Console()
def register(app: typer.Typer) -> None:
"""Register the optimize command on the Typer app."""
@app.command()
def optimize(
input: str = typer.Option(
...,
"-i",
"--input",
help="Path to input YAML config file.",
exists=True,
readable=True,
),
output: str = typer.Option(
"output.yaml",
"-o",
"--output",
help="Path to output YAML result file.",
),
verbose: bool = typer.Option(
False,
"-v",
"--verbose",
help="Print detailed progress (INFO level).",
),
debug: bool = typer.Option(
False,
"--debug",
help="Enable DEBUG-level logging (overrides -v).",
),
log_format: str = typer.Option(
"text",
"--log-format",
help="Log output format: text | json.",
),
log_file: str | None = typer.Option(
None,
"--log-file",
help="Optional file path to write logs to.",
),
max_retries: int = typer.Option(
3,
"--max-retries",
help="Max retry attempts for transient LLM errors (429, timeout, 5xx).",
),
error_strategy: str = typer.Option(
"retry",
"--error-strategy",
help="How to handle errors: skip | retry | abort.",
),
max_concurrency: int = typer.Option(
5,
"--max-concurrency",
help="Max parallel LLM calls for minibatch execution and judging.",
),
eval_dataset: str | None = typer.Option(
None,
"--eval-dataset",
help="Path to a CSV/JSON dataset with 'input' and 'expected_output' columns.",
),
eval_metric: str = typer.Option(
"bleu",
"--eval-metric",
help="Similarity metric for ground-truth eval: exact | bleu | rouge_l | cosine | llm_judge.",
),
checkpoint_dir: str | None = typer.Option(
None,
"--checkpoint-dir",
help="Directory for checkpoint files. Enables periodic checkpointing.",
),
checkpoint_interval: int = typer.Option(
5,
"--checkpoint-interval",
help="Save a checkpoint every N iterations.",
),
resume: bool = typer.Option(
False,
"--resume",
help="Resume from the latest checkpoint in --checkpoint-dir.",
),
population_size: int = typer.Option(
1,
"--population-size",
help="Number of candidates in the evolution population. 1 = single-candidate hill climbing.",
),
crossover_rate: float = typer.Option(
0.5,
"--crossover-rate",
help="Probability of applying crossover vs reflective mutation (0.01.0). Only used when --population-size > 1.",
),
mutation_rate: float = typer.Option(
0.3,
"--mutation-rate",
help="Probability of applying a mutation operator after crossover/proposal (0.01.0). Only used when --population-size > 1.",
),
validation_split: float = typer.Option(
0.3,
"--validation-split",
help="Fraction of synthetic pool reserved for hold-out validation (0.00.9). 0 disables validation.",
),
early_stop_patience: int = typer.Option(
5,
"--early-stop-patience",
help="Stop if validation score does not improve for this many consecutive iterations.",
),
judge_criteria: str | None = typer.Option(
None,
"--judge-criteria",
help="Custom judge rubric or evaluation criteria override (free text).",
),
) -> None:
"""Optimize a prompt without any reference data.
Usage:
prometheus optimize -i config.yaml -o result.yaml
prometheus optimize -i config.yaml --eval-dataset data.csv --eval-metric bleu
prometheus optimize -i config.yaml --checkpoint-dir .prometheus/checkpoints --resume
"""
asyncio.run(
_async_optimize(
input, output, verbose, debug, log_format, log_file,
max_retries, error_strategy, max_concurrency,
eval_dataset, eval_metric,
checkpoint_dir, checkpoint_interval, resume,
population_size, crossover_rate, mutation_rate,
validation_split, early_stop_patience,
judge_criteria,
)
)
async def _async_optimize(
input: str,
output: str,
verbose: bool,
debug: bool,
log_format: str,
log_file: str | None,
max_retries: int,
error_strategy: str,
max_concurrency: int,
eval_dataset: str | None,
eval_metric: str,
checkpoint_dir: str | None,
checkpoint_interval: int,
resume: bool,
population_size: int = 1,
crossover_rate: float = 0.5,
mutation_rate: float = 0.3,
validation_split: float = 0.3,
early_stop_patience: int = 5,
judge_criteria: str | None = None,
) -> None:
# Configure structured logging
if debug:
log_level = logging.DEBUG
elif verbose:
log_level = logging.INFO
else:
log_level = logging.WARNING
configure_logging(level=log_level, log_format=log_format, log_file=log_file)
console.print(
Panel.fit(
"PROMETHEUS — Prompt Evolution Engine",
subtitle="No reference data required",
)
)
# 1. Load & validate config
persistence = YamlPersistence()
raw_config = persistence.read_config(input)
# CLI flags override config file values
raw_config.setdefault("max_retries", max_retries)
raw_config.setdefault("error_strategy", error_strategy)
raw_config.setdefault("max_concurrency", max_concurrency)
raw_config["output_path"] = output
raw_config["verbose"] = verbose
raw_config["debug"] = debug
raw_config["log_format"] = log_format
raw_config["log_file"] = log_file
if eval_dataset:
raw_config["eval_dataset_path"] = eval_dataset
raw_config.setdefault("eval_metric", eval_metric)
if checkpoint_dir:
raw_config["checkpoint_dir"] = checkpoint_dir
raw_config.setdefault("checkpoint_interval", checkpoint_interval)
if resume:
raw_config["resume"] = True
raw_config.setdefault("population_size", population_size)
raw_config.setdefault("crossover_rate", crossover_rate)
raw_config.setdefault("mutation_rate", mutation_rate)
raw_config.setdefault("validation_split", validation_split)
raw_config.setdefault("early_stop_patience", early_stop_patience)
if judge_criteria:
raw_config["judge_criteria"] = judge_criteria
try:
config = OptimizationConfig.model_validate(raw_config)
except ValidationError as exc:
console.print("[bold red]Configuration error:[/bold red]\n")
for err in exc.errors():
loc = "".join(str(l) for l in err["loc"])
console.print(f" [red]• {loc}: {err['msg']}[/red]")
raise typer.Exit(code=1) from exc
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
# 2. Create per-model DSPy LM instances
def _model_lm_kwargs(
model_api_base: str | None,
model_api_key_env: str | None,
) -> dict:
"""Build kwargs for dspy.LM, using per-model overrides with global fallback."""
kwargs: dict = {}
api_base = model_api_base or config.api_base
api_key_env = model_api_key_env or config.api_key_env
if api_base:
kwargs["api_base"] = api_base
if api_key_env:
kwargs["api_key"] = os.environ.get(api_key_env, "")
return kwargs
task_lm = dspy.LM(
config.task_model,
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env),
)
judge_lm = dspy.LM(
config.judge_model,
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env),
)
proposer_lm = dspy.LM(
config.proposer_model,
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env),
)
synth_lm = dspy.LM(
config.synth_model,
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env),
)
# 3. Build adapters (Dependency Injection — each gets its own LM + retry config)
synth_adapter = DSPySyntheticAdapter(lm=synth_lm)
llm_adapter = DSPyLLMAdapter(
lm=task_lm,
max_retries=config.max_retries,
retry_delay_base=config.retry_delay_base,
)
judge_adapter = DSPyJudgeAdapter(
lm=judge_lm,
max_retries=config.max_retries,
retry_delay_base=config.retry_delay_base,
max_concurrency=config.max_concurrency,
judge_criteria=config.judge_criteria,
judge_dimensions=config.judge_dimensions,
)
proposer_adapter = DSPyProposerAdapter(
lm=proposer_lm,
max_retries=config.max_retries,
retry_delay_base=config.retry_delay_base,
)
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
evaluator = PromptEvaluator(
executor=llm_adapter,
judge=judge_adapter,
max_concurrency=config.max_concurrency,
)
# Build checkpoint port if checkpoint_dir is configured
checkpoint_port = None
if config.checkpoint_dir:
checkpoint_port = JsonCheckpointPersistence(checkpoint_dir=config.checkpoint_dir)
# Build crossover/mutation adapters for population-based evolution
crossover_adapter = None
mutation_adapter = None
if config.population_size > 1:
# Reuse proposer LM for crossover and mutation (same model, same role)
crossover_adapter = DSPyCrossoverAdapter(
lm=proposer_lm,
max_retries=config.max_retries,
retry_delay_base=config.retry_delay_base,
)
mutation_adapter = DSPyMutationAdapter(
lm=proposer_lm,
max_retries=config.max_retries,
retry_delay_base=config.retry_delay_base,
)
use_case = OptimizePromptUseCase(
evaluator=evaluator,
proposer=proposer_adapter,
bootstrap=bootstrap,
checkpoint_port=checkpoint_port,
crossover_port=crossover_adapter,
mutation_port=mutation_adapter,
)
# 4. Execute
with console.status("[bold green]Evolving prompt..."):
result = await use_case.execute(config)
# 4b. Compute actual LLM call count from adapter counters
actual_llm_calls = (
llm_adapter.call_count
+ judge_adapter.call_count
+ proposer_adapter.call_count
+ synth_adapter.call_count
+ (crossover_adapter.call_count if crossover_adapter else 0)
+ (mutation_adapter.call_count if mutation_adapter else 0)
)
result = OptimizationResult(
optimized_prompt=result.optimized_prompt,
initial_prompt=result.initial_prompt,
iterations_used=result.iterations_used,
total_llm_calls=actual_llm_calls,
initial_score=result.initial_score,
final_score=result.final_score,
improvement=result.improvement,
history=result.history,
final_validation_score=result.final_validation_score,
best_validation_score=result.best_validation_score,
early_stopped=result.early_stopped,
)
# 5. Display results
_display_result(result)
# 6. Optional ground-truth evaluation on the optimized prompt
if config.eval_dataset_path:
dataset = FileDatasetLoader().load(config.eval_dataset_path)
if config.eval_metric == "llm_judge":
# llm_judge reuses the existing PromptEvaluator with the LLM judge
from prometheus.domain.entities import SyntheticExample
synth_dataset = [
SyntheticExample(input_text=ex.input_text, id=ex.id) for ex in dataset
]
gt_eval = PromptEvaluator(
executor=llm_adapter,
judge=judge_adapter,
max_concurrency=config.max_concurrency,
)
with console.status("[bold green]Running ground-truth evaluation (llm_judge)..."):
gt_result = await gt_eval.evaluate(
prompt=Prompt(text=result.optimized_prompt),
minibatch=synth_dataset,
task_description=config.task_description,
)
else:
gt_evaluator = GroundTruthEvaluator(
executor=llm_adapter,
similarity=create_similarity_adapter(config.eval_metric),
max_concurrency=config.max_concurrency,
)
with console.status("[bold green]Running ground-truth evaluation..."):
gt_result = await gt_evaluator.evaluate(
prompt=Prompt(text=result.optimized_prompt),
dataset=dataset,
)
_display_ground_truth(gt_result, config.eval_metric, len(dataset))
# 7. Save
_save_result(persistence, output, result)
console.print(f"\n[green]Results saved to {output}[/green]")
def _display_result(result: OptimizationResult) -> None:
"""Display a Rich summary in the terminal."""
console.print()
console.print(
Panel(
f"[bold green]Optimized Prompt[/bold green]\n\n{result.optimized_prompt}",
title="Result",
)
)
table = Table(title="Metrics")
table.add_column("Metric", style="cyan")
table.add_column("Value", style="bold")
table.add_row("Initial Score", f"{result.initial_score:.2f}")
table.add_row("Final Score", f"{result.final_score:.2f}")
table.add_row("Improvement", f"{result.improvement:+.2f}")
if result.best_validation_score is not None:
table.add_row("Best Validation Score", f"{result.best_validation_score:.4f}")
if result.early_stopped:
table.add_row("Early Stopped", "[yellow]Yes[/yellow]")
table.add_row("Iterations", str(result.iterations_used))
table.add_row("LLM Calls", str(result.total_llm_calls))
console.print(table)
def _save_result(
persistence: YamlPersistence,
path: str,
result: OptimizationResult,
) -> None:
"""Save the result as YAML."""
persistence.write_result(path, asdict(result))
def _display_ground_truth(
result: EvalResult, metric: str, dataset_size: int
) -> None:
"""Display ground-truth evaluation results."""
console.print()
table = Table(title=f"Ground-Truth Evaluation (metric: {metric})")
table.add_column("Metric", style="cyan")
table.add_column("Value", style="bold")
table.add_row("Dataset Size", str(dataset_size))
table.add_row("Mean Score", f"{result.mean_score:.4f}")
table.add_row("Total Score", f"{result.total_score:.4f}")
exact_matches = sum(1 for s in result.scores if s >= 0.99)
table.add_row("Exact Matches", f"{exact_matches}/{dataset_size}")
table.add_row("Accuracy", f"{exact_matches / dataset_size:.2%}")
console.print(table)

View File

@@ -0,0 +1,18 @@
"""prometheus version — print the current version."""
from __future__ import annotations
import typer
from rich.console import Console
from prometheus import __version__
console = Console()
def register(app: typer.Typer) -> None:
"""Register the version command on the Typer app."""
@app.command()
def version() -> None:
"""Print the PROMETHEUS version."""
console.print(f"PROMETHEUS {__version__}")

View File

@@ -0,0 +1,96 @@
"""
Structured logging configuration for PROMETHEUS.
Supports text (human-readable) and JSON (machine-parseable) output,
configurable log levels, and optional file output.
Fixes Bug #4: verbose mode now reliably produces output by configuring
handlers explicitly instead of relying on ``logging.basicConfig``.
"""
from __future__ import annotations
import json
import logging
import sys
from datetime import datetime, timezone
from typing import TextIO
_PROMETHEUS_LOGGER = "prometheus"
class _JsonFormatter(logging.Formatter):
"""Emit one JSON object per log line."""
def format(self, record: logging.LogRecord) -> str:
message = record.getMessage()
payload: dict = {
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": message,
}
# Merge any extra structured fields the caller attached.
if hasattr(record, "structured"):
payload["structured"] = record.structured # type: ignore[attr-defined]
if record.exc_info and record.exc_info[1] is not None:
payload["exception"] = self.formatException(record.exc_info)
return json.dumps(payload, default=str)
class _TextFormatter(logging.Formatter):
"""Human-readable format with structured extras appended."""
def format(self, record: logging.LogRecord) -> str:
base = super().format(record)
if hasattr(record, "structured") and record.structured:
extras = " ".join(f"{k}={v}" for k, v in record.structured.items())
base = f"{base} {extras}"
return base
def configure_logging(
*,
level: int = logging.WARNING,
log_format: str = "text",
log_file: str | None = None,
) -> None:
"""Configure the prometheus root logger.
Args:
level: Logging level (e.g. logging.DEBUG, logging.INFO).
log_format: ``"text"`` for human-readable or ``"json"`` for
machine-parseable output.
log_file: Optional path to also write logs to a file.
"""
prom_logger = logging.getLogger(_PROMETHEUS_LOGGER)
prom_logger.setLevel(level)
# Remove any stale handlers so re-configuration is idempotent.
prom_logger.handlers.clear()
if log_format == "json":
fmt = _JsonFormatter()
else:
fmt = _TextFormatter(
fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%H:%M:%S",
)
# Console handler (stderr so it doesn't mix with Rich stdout)
console_handler = logging.StreamHandler(sys.stderr)
console_handler.setFormatter(fmt)
prom_logger.addHandler(console_handler)
# Optional file handler
if log_file:
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(fmt)
prom_logger.addHandler(file_handler)
# Prevent propagation to root logger to avoid duplicate output
prom_logger.propagate = False
def get_logger(name: str) -> logging.Logger:
"""Return a child logger under the prometheus namespace."""
return logging.getLogger(f"{_PROMETHEUS_LOGGER}.{name}")

View File

@@ -1,12 +1,2 @@
"""Application settings.""" """Application settings."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
@dataclass
class AppSettings:
"""Non-sensitive settings, hardcoded for the MVP."""
app_name: str = "prometheus"
version: str = "0.1.0"

View File

@@ -31,6 +31,15 @@ class SyntheticExample:
id: int = 0 id: int = 0
@dataclass(frozen=True)
class GroundTruthExample:
"""A ground-truth evaluation example with a known-good expected output."""
input_text: str
expected_output: str
id: int = 0
@dataclass @dataclass
class Trajectory: class Trajectory:
"""Execution trace of a prompt on an input. """Execution trace of a prompt on an input.
@@ -85,3 +94,6 @@ class OptimizationState:
synthetic_pool: list[SyntheticExample] = field(default_factory=list) synthetic_pool: list[SyntheticExample] = field(default_factory=list)
history: list[dict[str, Any]] = field(default_factory=list) history: list[dict[str, Any]] = field(default_factory=list)
total_llm_calls: int = 0 total_llm_calls: int = 0
# Hold-out validation
best_validation_score: float | None = None
early_stopped: bool = False

View File

@@ -8,7 +8,14 @@ from abc import ABC, abstractmethod
from typing import Any from typing import Any
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory from prometheus.domain.entities import (
Candidate,
GroundTruthExample,
OptimizationState,
Prompt,
SyntheticExample,
Trajectory,
)
class LLMPort(ABC): class LLMPort(ABC):
@@ -18,7 +25,7 @@ class LLMPort(ABC):
""" """
@abstractmethod @abstractmethod
def execute(self, prompt: Prompt, input_text: str) -> str: async def execute(self, prompt: Prompt, input_text: str) -> str:
"""Execute the prompt on the input, return the raw response.""" """Execute the prompt on the input, return the raw response."""
... ...
@@ -31,7 +38,7 @@ class JudgePort(ABC):
""" """
@abstractmethod @abstractmethod
def judge_batch( async def judge_batch(
self, self,
task_description: str, task_description: str,
pairs: list[tuple[str, str]], pairs: list[tuple[str, str]],
@@ -50,7 +57,7 @@ class ProposerPort(ABC):
""" """
@abstractmethod @abstractmethod
def propose( async def propose(
self, self,
current_prompt: Prompt, current_prompt: Prompt,
trajectories: list[Trajectory], trajectories: list[Trajectory],
@@ -73,6 +80,34 @@ class SyntheticGeneratorPort(ABC):
... ...
class CrossoverPort(ABC):
"""Port for crossover — combining instructions from two parent candidates."""
@abstractmethod
async def crossover(
self,
parent_a: Prompt,
parent_b: Prompt,
task_description: str,
) -> Prompt:
"""Combine instructions from two parents into a child prompt."""
...
class MutationPort(ABC):
"""Port for mutating a prompt — paraphrase, constrain, generalize, specialize."""
@abstractmethod
async def mutate(
self,
prompt: Prompt,
task_description: str,
mutation_type: str = "paraphrase",
) -> Prompt:
"""Apply a mutation to the prompt."""
...
class PersistencePort(ABC): class PersistencePort(ABC):
"""Port for reading/writing files.""" """Port for reading/writing files."""
@@ -83,3 +118,49 @@ class PersistencePort(ABC):
@abstractmethod @abstractmethod
def write_result(self, path: str, data: dict[str, Any]) -> None: def write_result(self, path: str, data: dict[str, Any]) -> None:
... ...
class SimilarityPort(ABC):
"""Port for computing similarity between a prediction and expected output.
Infrastructure provides concrete metrics (exact match, BLEU, ROUGE, cosine).
"""
@abstractmethod
def compute(self, prediction: str, expected: str) -> float:
"""Compute similarity score in [0, 1]. 1.0 = perfect match."""
...
class DatasetLoaderPort(ABC):
"""Port for loading ground-truth evaluation datasets."""
@abstractmethod
def load(self, path: str) -> list[GroundTruthExample]:
"""Load a dataset from a CSV or JSON file.
Each row must have 'input' and 'expected_output' fields.
"""
...
class CheckpointPort(ABC):
"""Port for saving and loading optimization checkpoints.
Enables resuming long-running optimizations after interruption.
"""
@abstractmethod
def save(self, state: OptimizationState) -> None:
"""Persist the current optimization state to disk."""
...
@abstractmethod
def load(self) -> OptimizationState | None:
"""Load the latest checkpoint. Returns None if no checkpoint exists."""
...
@abstractmethod
def latest_exists(self) -> bool:
"""Check if a checkpoint file is available for resuming."""
...

View File

@@ -19,3 +19,26 @@ def should_accept(
def normalize_score(raw: float, min_val: float = 0.0, max_val: float = 1.0) -> float: def normalize_score(raw: float, min_val: float = 0.0, max_val: float = 1.0) -> float:
"""Clamp a score within [min_val, max_val].""" """Clamp a score within [min_val, max_val]."""
return max(min_val, min(max_val, raw)) return max(min_val, min(max_val, raw))
def weighted_aggregate(
scores: dict[str, float],
weights: dict[str, float],
) -> float:
"""Compute a weighted average of per-dimension scores.
Args:
scores: Mapping of dimension name → score (0.01.0).
weights: Mapping of dimension name → weight (0.01.0).
Returns:
Weighted average in [0.0, 1.0]. Returns 0.0 if inputs are empty.
"""
if not scores or not weights:
return 0.0
total_weight = sum(weights.get(name, 0.0) for name in scores)
if total_weight == 0.0:
return sum(scores.values()) / len(scores)
return sum(
scores.get(name, 0.0) * weights.get(name, 0.0) for name in scores
) / total_weight

View File

@@ -0,0 +1,149 @@
"""
JSON checkpoint persistence — save/load optimization state to disk.
Implements the CheckpointPort with JSON for human-readable, versionable snapshots.
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict
from pathlib import Path
from prometheus.cli.logging_setup import get_logger
from prometheus.domain.entities import (
Candidate,
OptimizationState,
Prompt,
SyntheticExample,
)
from prometheus.domain.ports import CheckpointPort
logger = get_logger("checkpoint")
_CHECKPOINT_FILE = "latest.json"
class JsonCheckpointPersistence(CheckpointPort):
"""Saves optimization state as JSON to a configurable directory."""
def __init__(self, checkpoint_dir: str | Path = ".prometheus/checkpoints") -> None:
self._dir = Path(checkpoint_dir)
def save(self, state: OptimizationState) -> None:
"""Persist the current optimization state to disk."""
self._dir.mkdir(parents=True, exist_ok=True)
path = self._dir / _CHECKPOINT_FILE
data = _serialize_state(state)
path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
logger.info(
"Checkpoint saved",
extra={
"structured": {
"event": "checkpoint_saved",
"path": str(path),
"iteration": state.iteration,
"total_llm_calls": state.total_llm_calls,
},
},
)
def load(self) -> OptimizationState | None:
"""Load the latest checkpoint. Returns None if no checkpoint exists."""
path = self._dir / _CHECKPOINT_FILE
if not path.exists():
return None
raw = json.loads(path.read_text(encoding="utf-8"))
state = _deserialize_state(raw)
logger.info(
"Checkpoint loaded",
extra={
"structured": {
"event": "checkpoint_loaded",
"path": str(path),
"iteration": state.iteration,
"total_llm_calls": state.total_llm_calls,
},
},
)
return state
def latest_exists(self) -> bool:
"""Check if a checkpoint file is available for resuming."""
return (self._dir / _CHECKPOINT_FILE).exists()
# ---------------------------------------------------------------------------
# Serialization helpers — keep the JSON format stable and self-describing.
# ---------------------------------------------------------------------------
_SCHEMA_VERSION = 1
def _serialize_state(state: OptimizationState) -> dict:
"""Convert OptimizationState to a JSON-safe dict."""
return {
"schema_version": _SCHEMA_VERSION,
"iteration": state.iteration,
"best_candidate": _serialize_candidate(state.best_candidate),
"candidates": [_serialize_candidate(c) for c in state.candidates],
"synthetic_pool": [
{"input_text": ex.input_text, "category": ex.category, "id": ex.id}
for ex in state.synthetic_pool
],
"history": state.history,
"total_llm_calls": state.total_llm_calls,
}
def _serialize_candidate(candidate: Candidate | None) -> dict | None:
if candidate is None:
return None
return {
"prompt_text": candidate.prompt.text,
"prompt_metadata": candidate.prompt.metadata,
"best_score": candidate.best_score,
"generation": candidate.generation,
"parent_id": candidate.parent_id,
}
def _deserialize_state(data: dict) -> OptimizationState:
"""Reconstruct OptimizationState from a checkpoint dict."""
version = data.get("schema_version", 1)
# Future migration hooks go here: if version < 2: ...
best_raw = data.get("best_candidate")
best_candidate = _deserialize_candidate(best_raw)
candidates = [_deserialize_candidate(c) for c in data.get("candidates", [])]
synthetic_pool = [
SyntheticExample(
input_text=ex["input_text"],
category=ex.get("category", "default"),
id=ex.get("id", 0),
)
for ex in data.get("synthetic_pool", [])
]
state = OptimizationState(
iteration=data.get("iteration", 0),
best_candidate=best_candidate,
candidates=candidates,
synthetic_pool=synthetic_pool,
history=data.get("history", []),
total_llm_calls=data.get("total_llm_calls", 0),
)
return state
def _deserialize_candidate(raw: dict | None) -> Candidate | None:
if raw is None:
return None
return Candidate(
prompt=Prompt(text=raw["prompt_text"], metadata=raw.get("prompt_metadata", {})),
best_score=raw.get("best_score", 0.0),
generation=raw.get("generation", 0),
parent_id=raw.get("parent_id"),
)

View File

@@ -0,0 +1,63 @@
"""
Adapter: Instruction Crossover via DSPy.
Implements CrossoverPort — combines two parent prompts into a child.
"""
from __future__ import annotations
import asyncio
import dspy
from prometheus.domain.entities import Prompt
from prometheus.domain.ports import CrossoverPort
from prometheus.infrastructure.dspy_modules import InstructionCrossover
from prometheus.infrastructure.retry import async_retry_with_backoff
class DSPyCrossoverAdapter(CrossoverPort):
"""Uses DSPy to combine two parent instructions into a child."""
def __init__(
self,
lm: dspy.LM,
max_retries: int = 3,
retry_delay_base: float = 1.0,
) -> None:
self._lm = lm
self._crossover = InstructionCrossover()
self._max_retries = max_retries
self._retry_delay_base = retry_delay_base
self.call_count: int = 0
async def crossover(
self,
parent_a: Prompt,
parent_b: Prompt,
task_description: str,
) -> Prompt:
async def _call() -> Prompt:
return await asyncio.to_thread(
self._sync_crossover, parent_a, parent_b, task_description,
)
return await async_retry_with_backoff(
_call,
max_retries=self._max_retries,
retry_delay_base=self._retry_delay_base,
)
def _sync_crossover(
self,
parent_a: Prompt,
parent_b: Prompt,
task_description: str,
) -> Prompt:
with dspy.context(lm=self._lm):
pred = self._crossover(
parent_a=parent_a.text,
parent_b=parent_b.text,
task_description=task_description,
)
self.call_count += 1
return Prompt(text=pred.child_instruction)

View File

@@ -0,0 +1,75 @@
"""Dataset loader — loads ground-truth CSV/JSON datasets."""
from __future__ import annotations
import csv
import json
import logging
from pathlib import Path
from prometheus.domain.entities import GroundTruthExample
from prometheus.domain.ports import DatasetLoaderPort
logger = logging.getLogger(__name__)
class FileDatasetLoader(DatasetLoaderPort):
"""Loads evaluation datasets from CSV or JSON files.
CSV files must have 'input' and 'expected_output' columns.
JSON files must be an array of objects with 'input' and 'expected_output' keys.
"""
def load(self, path: str) -> list[GroundTruthExample]:
suffix = Path(path).suffix.lower()
if suffix == ".csv":
return self._load_csv(path)
elif suffix in (".json", ".jsonl"):
return self._load_json(path)
else:
raise ValueError(
f"Unsupported dataset format '{suffix}'. Use .csv, .json, or .jsonl."
)
def _load_csv(self, path: str) -> list[GroundTruthExample]:
examples: list[GroundTruthExample] = []
with open(path, newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
for i, row in enumerate(reader):
input_text = row.get("input", "").strip()
expected = row.get("expected_output", "").strip()
if not input_text:
logger.warning("Skipping CSV row %d: empty 'input' field", i + 1)
continue
examples.append(
GroundTruthExample(
input_text=input_text,
expected_output=expected,
id=i,
)
)
logger.info("Loaded %d examples from CSV: %s", len(examples), path)
return examples
def _load_json(self, path: str) -> list[GroundTruthExample]:
with open(path, encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError("JSON dataset must be an array of objects.")
examples: list[GroundTruthExample] = []
for i, item in enumerate(data):
input_text = item.get("input", "").strip() if isinstance(item, dict) else ""
expected = (
item.get("expected_output", "").strip() if isinstance(item, dict) else ""
)
if not input_text:
logger.warning("Skipping JSON item %d: empty 'input' field", i)
continue
examples.append(
GroundTruthExample(
input_text=input_text,
expected_output=expected,
id=i,
)
)
logger.info("Loaded %d examples from JSON: %s", len(examples), path)
return examples

View File

@@ -11,8 +11,10 @@ import re
import dspy import dspy
from prometheus.infrastructure.dspy_signatures import ( from prometheus.infrastructure.dspy_signatures import (
CrossoverInstructions,
GenerateSyntheticInputs, GenerateSyntheticInputs,
JudgeOutput, JudgeOutput,
MutateInstruction,
ProposeInstruction, ProposeInstruction,
) )
@@ -53,19 +55,30 @@ class OutputJudge(dspy.Module):
self.judge = dspy.ChainOfThought(JudgeOutput) self.judge = dspy.ChainOfThought(JudgeOutput)
def forward( def forward(
self, task_description: str, input_text: str, output_text: str self,
task_description: str,
input_text: str,
output_text: str,
judge_criteria: str = "",
dimension_names: str = "",
) -> dspy.Prediction: ) -> dspy.Prediction:
result = self.judge( result = self.judge(
task_description=task_description, task_description=task_description,
input_text=input_text, input_text=input_text,
output_text=output_text, output_text=output_text,
judge_criteria=judge_criteria,
dimension_names=dimension_names,
) )
try: try:
score = float(result.score) score = float(result.score)
except (ValueError, TypeError): except (ValueError, TypeError):
score = 0.5 # neutral fallback score = 0.5 # neutral fallback
score = max(0.0, min(1.0, score)) score = max(0.0, min(1.0, score))
return dspy.Prediction(score=score, feedback=result.feedback) return dspy.Prediction(
score=score,
feedback=result.feedback,
dimension_scores=getattr(result, "dimension_scores", "{}"),
)
class InstructionProposer(dspy.Module): class InstructionProposer(dspy.Module):
@@ -90,3 +103,45 @@ class InstructionProposer(dspy.Module):
failure_examples=failure_examples, failure_examples=failure_examples,
) )
return dspy.Prediction(new_instruction=result.new_instruction) return dspy.Prediction(new_instruction=result.new_instruction)
class InstructionCrossover(dspy.Module):
"""Crossover: combines two parent instructions into a child."""
def __init__(self) -> None:
super().__init__()
self.crossover = dspy.ChainOfThought(CrossoverInstructions)
def forward(
self,
parent_a: str,
parent_b: str,
task_description: str,
) -> dspy.Prediction:
result = self.crossover(
parent_a=parent_a,
parent_b=parent_b,
task_description=task_description,
)
return dspy.Prediction(child_instruction=result.child_instruction)
class InstructionMutator(dspy.Module):
"""Mutator: applies a typed mutation to an instruction."""
def __init__(self) -> None:
super().__init__()
self.mutate = dspy.ChainOfThought(MutateInstruction)
def forward(
self,
current_instruction: str,
task_description: str,
mutation_type: str,
) -> dspy.Prediction:
result = self.mutate(
current_instruction=current_instruction,
task_description=task_description,
mutation_type=mutation_type,
)
return dspy.Prediction(mutated_instruction=result.mutated_instruction)

View File

@@ -44,6 +44,12 @@ class JudgeOutput(dspy.Signature):
output_text: str = dspy.InputField( output_text: str = dspy.InputField(
desc="The assistant's response to evaluate." desc="The assistant's response to evaluate."
) )
judge_criteria: str = dspy.InputField(
desc="Custom evaluation rubric or criteria. Empty string = use default judging criteria."
)
dimension_names: str = dspy.InputField(
desc="Comma-separated dimension names for multi-dimensional scoring. Empty string = single overall score."
)
score: float = dspy.OutputField( score: float = dspy.OutputField(
desc="Quality score from 0.0 (wrong) to 1.0 (perfect)." desc="Quality score from 0.0 (wrong) to 1.0 (perfect)."
) )
@@ -53,6 +59,12 @@ class JudgeOutput(dspy.Signature):
"with the output and how to improve it. Be critical." "with the output and how to improve it. Be critical."
), ),
) )
dimension_scores: str = dspy.OutputField(
desc=(
"JSON object mapping dimension names to scores (0.0-1.0). "
'Empty object {} if no dimensions specified.'
),
)
class ProposeInstruction(dspy.Signature): class ProposeInstruction(dspy.Signature):
@@ -77,3 +89,52 @@ class ProposeInstruction(dspy.Signature):
new_instruction: str = dspy.OutputField( new_instruction: str = dspy.OutputField(
desc="An improved version of the instruction." desc="An improved version of the instruction."
) )
class CrossoverInstructions(dspy.Signature):
"""Combine two instruction prompts into a single improved instruction.
Take the strongest elements from each parent — structure, phrasing,
constraints, examples — and merge them into a coherent child instruction
that is strictly better than either parent alone.
"""
parent_a: str = dspy.InputField(
desc="First parent instruction."
)
parent_b: str = dspy.InputField(
desc="Second parent instruction."
)
task_description: str = dspy.InputField(
desc="Description of the task."
)
child_instruction: str = dspy.OutputField(
desc=(
"A combined instruction that takes the best elements from "
"both parents into a single, coherent instruction."
),
)
class MutateInstruction(dspy.Signature):
"""Apply a specific mutation to an instruction prompt.
The mutation_type determines the transformation:
- paraphrase: restate the instruction in different words
- constrain: add specificity, constraints, or guard-rails
- generalize: broaden the instruction to cover more cases
- specialize: narrow the instruction for better focus on the task
"""
current_instruction: str = dspy.InputField(
desc="The instruction to mutate."
)
task_description: str = dspy.InputField(
desc="Description of the task."
)
mutation_type: str = dspy.InputField(
desc="Type of mutation: paraphrase, constrain, generalize, or specialize."
)
mutated_instruction: str = dspy.OutputField(
desc="The mutated instruction, preserving core intent but altered per the mutation type."
)

View File

@@ -5,30 +5,141 @@ Implements the JudgePort via the DSPy OutputJudge module.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
import dspy
from prometheus.application.dto import JudgeDimension
from prometheus.domain.ports import JudgePort from prometheus.domain.ports import JudgePort
from prometheus.domain.scoring import weighted_aggregate
from prometheus.infrastructure.dspy_modules import OutputJudge from prometheus.infrastructure.dspy_modules import OutputJudge
from prometheus.infrastructure.retry import async_retry_with_backoff
logger = logging.getLogger(__name__)
class DSPyJudgeAdapter(JudgePort): class DSPyJudgeAdapter(JudgePort):
"""Evaluates a batch of (input, output) pairs by calling the Judge for each. """Evaluates a batch of (input, output) pairs by calling the Judge for each.
Sequential for MVP. Future: parallelize via dspy.Parallel. Per-call isolation: a failure on one item returns a zero-score sentinel
instead of crashing the whole batch.
Judge calls run in parallel (bounded by *max_concurrency*).
When *judge_criteria* or *judge_dimensions* are provided, the judge applies
custom rubrics and/or multi-dimensional scoring with weighted aggregation.
""" """
def __init__(self) -> None: def __init__(
self,
lm: dspy.LM,
max_retries: int = 3,
retry_delay_base: float = 1.0,
max_concurrency: int = 5,
judge_criteria: str | None = None,
judge_dimensions: list[JudgeDimension] | None = None,
) -> None:
self._lm = lm
self._judge = OutputJudge() self._judge = OutputJudge()
self._max_retries = max_retries
self._retry_delay_base = retry_delay_base
self._semaphore = asyncio.Semaphore(max_concurrency)
self._judge_criteria = judge_criteria or ""
self._judge_dimensions = judge_dimensions or []
self._dimension_names = (
",".join(d.name for d in self._judge_dimensions)
if self._judge_dimensions
else ""
)
self._weights: dict[str, float] = (
{d.name: d.weight for d in self._judge_dimensions}
if self._judge_dimensions
else {}
)
self.call_count: int = 0
def judge_batch( async def judge_batch(
self, self,
task_description: str, task_description: str,
pairs: list[tuple[str, str]], pairs: list[tuple[str, str]],
) -> list[tuple[float, str]]: ) -> list[tuple[float, str]]:
results: list[tuple[float, str]] = [] tasks = [
for input_text, output_text in pairs: self._judge_single_safe(task_description, input_text, output_text)
pred = self._judge( for input_text, output_text in pairs
]
return list(await asyncio.gather(*tasks))
async def _judge_single_safe(
self,
task_description: str,
input_text: str,
output_text: str,
) -> tuple[float, str]:
async with self._semaphore:
try:
return await self._judge_single(task_description, input_text, output_text)
except Exception as exc:
logger.warning("Judge call failed for input '%s': %s", input_text[:40], exc)
return (0.0, f"[judge error: {exc}]")
async def _judge_single(
self,
task_description: str,
input_text: str,
output_text: str,
) -> tuple[float, str]:
async def _call() -> tuple[float, str]:
pred = await asyncio.to_thread(
self._sync_judge, task_description, input_text, output_text,
)
return self._aggregate_result(pred)
return await async_retry_with_backoff(
_call,
max_retries=self._max_retries,
retry_delay_base=self._retry_delay_base,
)
def _sync_judge(self, task_description: str, input_text: str, output_text: str):
with dspy.context(lm=self._lm):
result = self._judge(
task_description=task_description, task_description=task_description,
input_text=input_text, input_text=input_text,
output_text=output_text, output_text=output_text,
judge_criteria=self._judge_criteria,
dimension_names=self._dimension_names,
) )
results.append((pred.score, pred.feedback)) self.call_count += 1
return results return result
def _aggregate_result(self, pred: Any) -> tuple[float, str]:
"""Compute weighted aggregate score from dimension scores if available."""
if not self._judge_dimensions:
return (pred.score, pred.feedback)
# Parse per-dimension scores from LLM output
dim_scores: dict[str, float] = {}
try:
raw = json.loads(pred.dimension_scores)
if isinstance(raw, dict):
for name in self._weights:
val = raw.get(name)
if val is not None:
dim_scores[name] = max(0.0, min(1.0, float(val)))
except (json.JSONDecodeError, ValueError, TypeError):
logger.debug("Failed to parse dimension_scores, falling back to overall score")
if not dim_scores:
return (pred.score, pred.feedback)
aggregate = weighted_aggregate(dim_scores, self._weights)
# Enrich feedback with per-dimension breakdown
dim_breakdown = ", ".join(
f"{name}={dim_scores.get(name, 0.0):.2f}"
for name in self._weights
)
feedback = f"{pred.feedback} [{dim_breakdown}]"
return (aggregate, feedback)

View File

@@ -5,10 +5,13 @@ Implements the LLMPort via DSPy.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import dspy import dspy
from prometheus.domain.entities import Prompt from prometheus.domain.entities import Prompt
from prometheus.domain.ports import LLMPort from prometheus.domain.ports import LLMPort
from prometheus.infrastructure.retry import async_retry_with_backoff
class DSPyLLMAdapter(LLMPort): class DSPyLLMAdapter(LLMPort):
@@ -21,12 +24,34 @@ class DSPyLLMAdapter(LLMPort):
input_text: str = dspy.InputField(desc="The input to process.") input_text: str = dspy.InputField(desc="The input to process.")
output: str = dspy.OutputField(desc="The response following the instruction.") output: str = dspy.OutputField(desc="The response following the instruction.")
def __init__(self, model: str) -> None: def __init__(
self,
lm: dspy.LM,
max_retries: int = 3,
retry_delay_base: float = 1.0,
) -> None:
self._lm = lm
self._predictor = dspy.Predict(self._ExecuteSignature) self._predictor = dspy.Predict(self._ExecuteSignature)
self._max_retries = max_retries
self._retry_delay_base = retry_delay_base
self.call_count: int = 0
def execute(self, prompt: Prompt, input_text: str) -> str: async def execute(self, prompt: Prompt, input_text: str) -> str:
async def _call() -> str:
# DSPy is synchronous — run in a thread to avoid blocking the event loop.
return await asyncio.to_thread(self._sync_execute, prompt, input_text)
return await async_retry_with_backoff(
_call,
max_retries=self._max_retries,
retry_delay_base=self._retry_delay_base,
)
def _sync_execute(self, prompt: Prompt, input_text: str) -> str:
with dspy.context(lm=self._lm):
result = self._predictor( result = self._predictor(
instruction=prompt.text, instruction=prompt.text,
input_text=input_text, input_text=input_text,
) )
self.call_count += 1
return str(result.output) return str(result.output)

View File

@@ -0,0 +1,70 @@
"""
Adapter: Instruction Mutation via DSPy.
Implements MutationPort — applies typed mutations (paraphrase, constrain,
generalize, specialize) to a prompt.
"""
from __future__ import annotations
import asyncio
import random
import dspy
from prometheus.domain.entities import Prompt
from prometheus.domain.ports import MutationPort
from prometheus.infrastructure.dspy_modules import InstructionMutator
from prometheus.infrastructure.retry import async_retry_with_backoff
_MUTATION_TYPES = ("paraphrase", "constrain", "generalize", "specialize")
class DSPyMutationAdapter(MutationPort):
"""Uses DSPy to apply typed mutations to an instruction."""
def __init__(
self,
lm: dspy.LM,
max_retries: int = 3,
retry_delay_base: float = 1.0,
) -> None:
self._lm = lm
self._mutator = InstructionMutator()
self._max_retries = max_retries
self._retry_delay_base = retry_delay_base
self.call_count: int = 0
async def mutate(
self,
prompt: Prompt,
task_description: str,
mutation_type: str = "paraphrase",
) -> Prompt:
if mutation_type not in _MUTATION_TYPES:
mutation_type = random.choice(_MUTATION_TYPES)
async def _call() -> Prompt:
return await asyncio.to_thread(
self._sync_mutate, prompt, task_description, mutation_type,
)
return await async_retry_with_backoff(
_call,
max_retries=self._max_retries,
retry_delay_base=self._retry_delay_base,
)
def _sync_mutate(
self,
prompt: Prompt,
task_description: str,
mutation_type: str,
) -> Prompt:
with dspy.context(lm=self._lm):
pred = self._mutator(
current_instruction=prompt.text,
task_description=task_description,
mutation_type=mutation_type,
)
self.call_count += 1
return Prompt(text=pred.mutated_instruction)

View File

@@ -6,29 +6,58 @@ Converts trajectories into readable format for the LLM proposer.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import dspy
from prometheus.domain.entities import Prompt, Trajectory from prometheus.domain.entities import Prompt, Trajectory
from prometheus.domain.ports import ProposerPort from prometheus.domain.ports import ProposerPort
from prometheus.infrastructure.dspy_modules import InstructionProposer from prometheus.infrastructure.dspy_modules import InstructionProposer
from prometheus.infrastructure.retry import async_retry_with_backoff
class DSPyProposerAdapter(ProposerPort): class DSPyProposerAdapter(ProposerPort):
"""Uses evaluation trajectories to build a failure report and propose a new prompt.""" """Uses evaluation trajectories to build a failure report and propose a new prompt."""
def __init__(self) -> None: def __init__(
self,
lm: dspy.LM,
max_retries: int = 3,
retry_delay_base: float = 1.0,
) -> None:
self._lm = lm
self._proposer = InstructionProposer() self._proposer = InstructionProposer()
self._max_retries = max_retries
self._retry_delay_base = retry_delay_base
self.call_count: int = 0
def propose( async def propose(
self, self,
current_prompt: Prompt, current_prompt: Prompt,
trajectories: list[Trajectory], trajectories: list[Trajectory],
task_description: str, task_description: str,
) -> Prompt: ) -> Prompt:
failure_examples = self._format_failures(trajectories) failure_examples = self._format_failures(trajectories)
async def _call() -> Prompt:
return await asyncio.to_thread(
self._sync_propose, current_prompt, task_description, failure_examples,
)
return await async_retry_with_backoff(
_call,
max_retries=self._max_retries,
retry_delay_base=self._retry_delay_base,
)
def _sync_propose(self, current_prompt: Prompt, task_description: str, failure_examples: str) -> Prompt:
with dspy.context(lm=self._lm):
pred = self._proposer( pred = self._proposer(
current_instruction=current_prompt.text, current_instruction=current_prompt.text,
task_description=task_description, task_description=task_description,
failure_examples=failure_examples, failure_examples=failure_examples,
) )
self.call_count += 1
return Prompt(text=pred.new_instruction) return Prompt(text=pred.new_instruction)
@staticmethod @staticmethod

View File

@@ -0,0 +1,102 @@
"""Retry with exponential backoff for transient LLM errors."""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Any, Callable, Coroutine, TypeVar
logger = logging.getLogger(__name__)
T = TypeVar("T")
# Status codes / keywords that indicate a transient error worth retrying.
_TRANSIENT_PATTERNS = (
"429",
"rate limit",
"rate_limit",
"too many requests",
"500",
"502",
"503",
"504",
"timeout",
"timed out",
"connection error",
"connection refused",
"overloaded",
)
def is_transient_error(exc: Exception) -> bool:
"""Return True if the exception looks like a transient LLM/API error."""
msg = str(exc).lower()
if any(p in msg for p in _TRANSIENT_PATTERNS):
return True
if isinstance(exc, (ConnectionError, TimeoutError, OSError)):
return True
return False
class TransientError(RuntimeError):
"""Raised when all retry attempts are exhausted for a transient error."""
def retry_with_backoff(
fn: Callable[..., T],
*args: Any,
max_retries: int = 3,
retry_delay_base: float = 1.0,
**kwargs: Any,
) -> T:
"""Call *fn* with exponential-backoff retry on transient errors.
Delay per attempt: ``retry_delay_base * 2 ** attempt`` seconds.
"""
last_exc: Exception | None = None
for attempt in range(max_retries + 1):
try:
return fn(*args, **kwargs)
except Exception as exc:
last_exc = exc
if not is_transient_error(exc) or attempt == max_retries:
raise
delay = retry_delay_base * (2 ** attempt)
logger.warning(
"Transient error (attempt %d/%d): %s — retrying in %.1fs",
attempt + 1,
max_retries + 1,
exc,
delay,
)
time.sleep(delay)
# Should not reach here, but satisfy type-checker.
raise TransientError(str(last_exc)) from last_exc
async def async_retry_with_backoff(
fn: Callable[..., Coroutine[Any, Any, T]],
*args: Any,
max_retries: int = 3,
retry_delay_base: float = 1.0,
**kwargs: Any,
) -> T:
"""Async version of retry_with_backoff — uses asyncio.sleep instead of time.sleep."""
last_exc: Exception | None = None
for attempt in range(max_retries + 1):
try:
return await fn(*args, **kwargs)
except Exception as exc:
last_exc = exc
if not is_transient_error(exc) or attempt == max_retries:
raise
delay = retry_delay_base * (2 ** attempt)
logger.warning(
"Transient error (attempt %d/%d): %s — retrying in %.1fs",
attempt + 1,
max_retries + 1,
exc,
delay,
)
await asyncio.sleep(delay)
raise TransientError(str(last_exc)) from last_exc

View File

@@ -0,0 +1,153 @@
"""Similarity adapters — concrete metrics for comparing prediction vs expected."""
from __future__ import annotations
import math
import re
from collections import Counter
from prometheus.domain.ports import SimilarityPort
class ExactMatchSimilarity(SimilarityPort):
"""Case-insensitive exact string match. Returns 1.0 or 0.0."""
def compute(self, prediction: str, expected: str) -> float:
return 1.0 if prediction.strip().lower() == expected.strip().lower() else 0.0
class BleuSimilarity(SimilarityPort):
"""BLEU-style n-gram precision (up to 4-grams).
Simplified implementation using sentence-level BLEU with brevity penalty.
Returns a score in [0, 1].
"""
def __init__(self, max_n: int = 4):
self._max_n = max_n
def compute(self, prediction: str, expected: str) -> float:
pred_tokens = _tokenize(prediction)
ref_tokens = _tokenize(expected)
if not pred_tokens or not ref_tokens:
return 0.0 if not ref_tokens else 0.0
# Modified precision for each n-gram
precisions: list[float] = []
for n in range(1, self._max_n + 1):
pred_ngrams = _ngrams(pred_tokens, n)
ref_ngrams = _ngrams(ref_tokens, n)
if not pred_ngrams:
break
clipped = sum(min(pred_ngrams[ng], ref_ngrams.get(ng, 0)) for ng in pred_ngrams)
total = sum(pred_ngrams.values())
precisions.append(clipped / total if total > 0 else 0.0)
if not precisions:
return 0.0
# Geometric mean of precisions
log_avg = sum(math.log(p) for p in precisions if p > 0)
n_nonzero = sum(1 for p in precisions if p > 0)
if n_nonzero == 0:
return 0.0
geo_mean = math.exp(log_avg / n_nonzero)
# Brevity penalty
bp = 1.0
if len(pred_tokens) < len(ref_tokens):
bp = math.exp(1 - len(ref_tokens) / len(pred_tokens))
return min(bp * geo_mean, 1.0)
class RougeLSimilarity(SimilarityPort):
"""ROUGE-L using Longest Common Subsequence.
Returns F1 score combining precision and recall in [0, 1].
"""
def compute(self, prediction: str, expected: str) -> float:
pred_tokens = _tokenize(prediction)
ref_tokens = _tokenize(expected)
if not pred_tokens or not ref_tokens:
return 0.0
lcs_len = _lcs_length(pred_tokens, ref_tokens)
precision = lcs_len / len(pred_tokens)
recall = lcs_len / len(ref_tokens)
if precision + recall == 0:
return 0.0
f1 = 2 * precision * recall / (precision + recall)
return f1
class CosineSimilarity(SimilarityPort):
"""TF-IDF cosine similarity between bag-of-words vectors.
Lightweight semantic similarity without external embedding models.
"""
def compute(self, prediction: str, expected: str) -> float:
pred_tokens = _tokenize(prediction)
ref_tokens = _tokenize(expected)
if not pred_tokens or not ref_tokens:
return 0.0
pred_counts = Counter(pred_tokens)
ref_counts = Counter(ref_tokens)
all_tokens = set(pred_counts) | set(ref_counts)
dot = sum(pred_counts.get(t, 0) * ref_counts.get(t, 0) for t in all_tokens)
norm_pred = math.sqrt(sum(v * v for v in pred_counts.values()))
norm_ref = math.sqrt(sum(v * v for v in ref_counts.values()))
if norm_pred == 0 or norm_ref == 0:
return 0.0
return dot / (norm_pred * norm_ref)
# --- Helpers ---
def _tokenize(text: str) -> list[str]:
"""Simple whitespace + punctuation tokenizer."""
return re.findall(r"\w+", text.lower())
def _ngrams(tokens: list[str], n: int) -> Counter:
"""Count n-grams in a token list."""
return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))
def _lcs_length(a: list[str], b: list[str]) -> int:
"""Compute length of the Longest Common Subsequence."""
m, n = len(a), len(b)
prev = [0] * (n + 1)
for i in range(1, m + 1):
curr = [0] * (n + 1)
for j in range(1, n + 1):
if a[i - 1] == b[j - 1]:
curr[j] = prev[j - 1] + 1
else:
curr[j] = max(prev[j], curr[j - 1])
prev = curr
return prev[n]
def create_similarity_adapter(metric: str) -> SimilarityPort:
"""Factory: create a SimilarityPort by metric name.
Supported metrics: exact, bleu, rouge_l, cosine.
"""
adapters = {
"exact": ExactMatchSimilarity,
"bleu": BleuSimilarity,
"rouge_l": RougeLSimilarity,
"cosine": CosineSimilarity,
}
cls = adapters.get(metric)
if cls is None:
raise ValueError(
f"Unknown eval metric '{metric}'. Choose from: {sorted(adapters)}"
)
return cls()

View File

@@ -5,6 +5,8 @@ Implements the SyntheticGeneratorPort via DSPy.
""" """
from __future__ import annotations from __future__ import annotations
import dspy
from prometheus.domain.entities import SyntheticExample from prometheus.domain.entities import SyntheticExample
from prometheus.domain.ports import SyntheticGeneratorPort from prometheus.domain.ports import SyntheticGeneratorPort
from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator
@@ -13,18 +15,22 @@ from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator
class DSPySyntheticAdapter(SyntheticGeneratorPort): class DSPySyntheticAdapter(SyntheticGeneratorPort):
"""Generates synthetic inputs in a single batch call via DSPy.""" """Generates synthetic inputs in a single batch call via DSPy."""
def __init__(self) -> None: def __init__(self, lm: dspy.LM) -> None:
self._lm = lm
self._generator = SyntheticInputGenerator() self._generator = SyntheticInputGenerator()
self.call_count: int = 0
def generate_inputs( def generate_inputs(
self, self,
task_description: str, task_description: str,
n_examples: int, n_examples: int,
) -> list[SyntheticExample]: ) -> list[SyntheticExample]:
with dspy.context(lm=self._lm):
pred = self._generator( pred = self._generator(
task_description=task_description, task_description=task_description,
n_examples=n_examples, n_examples=n_examples,
) )
self.call_count += 1
return [ return [
SyntheticExample( SyntheticExample(
input_text=text, input_text=text,

View File

@@ -1,7 +1,7 @@
"""Shared test fixtures.""" """Shared test fixtures."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@@ -66,17 +66,17 @@ def mock_eval_result() -> EvalResult:
@pytest.fixture @pytest.fixture
def mock_llm_port() -> MagicMock: def mock_llm_port() -> AsyncMock:
"""Mock LLMPort that returns canned responses.""" """Mock LLMPort that returns canned responses."""
port = MagicMock() port = AsyncMock()
port.execute.return_value = "This is a mock response." port.execute.return_value = "This is a mock response."
return port return port
@pytest.fixture @pytest.fixture
def mock_judge_port() -> MagicMock: def mock_judge_port() -> AsyncMock:
"""Mock JudgePort that returns moderate scores.""" """Mock JudgePort that returns moderate scores."""
port = MagicMock() port = AsyncMock()
port.judge_batch.return_value = [ port.judge_batch.return_value = [
(0.5, "Moderate quality, needs improvement."), (0.5, "Moderate quality, needs improvement."),
] * 5 ] * 5
@@ -84,10 +84,34 @@ def mock_judge_port() -> MagicMock:
@pytest.fixture @pytest.fixture
def mock_proposer_port() -> MagicMock: def mock_proposer_port() -> AsyncMock:
"""Mock ProposerPort that returns a slightly modified prompt.""" """Mock ProposerPort that returns a slightly modified prompt."""
port = MagicMock() port = AsyncMock()
port.propose.return_value = Prompt( port.propose.return_value = Prompt(
text="You are a very helpful assistant. Answer the question precisely." text="You are a very helpful assistant. Answer the question precisely."
) )
return port return port
@pytest.fixture
def mock_crossover_port() -> AsyncMock:
"""Mock CrossoverPort that combines two parent prompts."""
port = AsyncMock()
async def _crossover(parent_a: Prompt, parent_b: Prompt, task_description: str) -> Prompt:
return Prompt(text=f"{parent_a.text} Also, {parent_b.text.lower()}")
port.crossover = AsyncMock(side_effect=_crossover)
return port
@pytest.fixture
def mock_mutation_port() -> AsyncMock:
"""Mock MutationPort that paraphrases a prompt."""
port = AsyncMock()
async def _mutate(prompt: Prompt, task_description: str, mutation_type: str = "paraphrase") -> Prompt:
return Prompt(text=f"[{mutation_type}] {prompt.text}")
port.mutate = AsyncMock(side_effect=_mutate)
return port

View File

@@ -16,14 +16,14 @@ def mock_lm() -> dspy.LM:
{"output": "Mock output response"}, {"output": "Mock output response"},
] ]
) )
dspy.configure(lm=lm)
return lm return lm
class TestDSPyLLMAdapter: class TestDSPyLLMAdapter:
def test_execute_returns_response(self, mock_lm: dspy.LM) -> None: @pytest.mark.asyncio
adapter = DSPyLLMAdapter(model="openai/gpt-4o-mini") async def test_execute_returns_response(self, mock_lm: dspy.LM) -> None:
adapter = DSPyLLMAdapter(lm=mock_lm)
prompt = Prompt(text="Answer the question.") prompt = Prompt(text="Answer the question.")
result = adapter.execute(prompt, "What is 2+2?") result = await adapter.execute(prompt, "What is 2+2?")
assert isinstance(result, str) assert isinstance(result, str)
assert len(result) > 0 assert len(result) > 0

View File

@@ -0,0 +1,300 @@
"""Integration tests for multi-iteration evolution with mixed accept/reject."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.evaluator import PromptEvaluator
from prometheus.application.evolution import EvolutionLoop
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory
from prometheus.domain.ports import JudgePort, LLMPort, ProposerPort
def _make_eval(scores: list[float]) -> EvalResult:
return EvalResult(
scores=scores,
feedbacks=["feedback"] * len(scores),
trajectories=[
Trajectory(f"in{i}", f"out{i}", s, "feedback", "prompt")
for i, s in enumerate(scores)
],
)
class TestMultiIterationEvolution:
"""Tests for the evolution loop across multiple iterations."""
@pytest.fixture
def seed_prompt(self) -> Prompt:
return Prompt(text="You are a helpful assistant.")
@pytest.fixture
def task_description(self) -> str:
return "Answer factual questions."
@pytest.fixture
def synthetic_pool(self) -> list[SyntheticExample]:
return [SyntheticExample(input_text=f"input {i}", id=i) for i in range(20)]
@pytest.mark.asyncio
async def test_mixed_accept_reject(
self,
seed_prompt: Prompt,
task_description: str,
synthetic_pool: list[SyntheticExample],
) -> None:
"""Iteration 1: accept, iteration 2: reject, iteration 3: accept."""
mock_llm = MagicMock(spec=LLMPort)
mock_judge = MagicMock(spec=JudgePort)
mock_proposer = MagicMock(spec=ProposerPort)
evaluator = PromptEvaluator(mock_llm, mock_judge)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:3]
# Build eval sequence: initial, then per-iteration (current, new)
evals = [
_make_eval([0.3, 0.3, 0.3]), # initial seed eval
# Iter 1: accept (old=0.4, new=0.8)
_make_eval([0.4, 0.4, 0.4]),
_make_eval([0.8, 0.8, 0.8]),
# Iter 2: reject (old=0.7, new=0.2)
_make_eval([0.7, 0.7, 0.7]),
_make_eval([0.2, 0.2, 0.2]),
# Iter 3: accept (old=0.5, new=0.9)
_make_eval([0.5, 0.5, 0.5]),
_make_eval([0.9, 0.9, 0.9]),
]
evaluator.evaluate = AsyncMock(side_effect=evals)
mock_proposer.propose.side_effect = [
Prompt(text="Better prompt v1"),
Prompt(text="Worse prompt v2"),
Prompt(text="Best prompt v3"),
]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer,
bootstrap=bootstrap,
max_iterations=3,
minibatch_size=3,
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
assert state.iteration == 3
assert state.best_candidate is not None
assert state.best_candidate.best_score == pytest.approx(2.7) # 0.9*3
assert len(state.history) == 3
assert state.history[0]["event"] == "accepted"
assert state.history[1]["event"] == "rejected"
assert state.history[2]["event"] == "accepted"
@pytest.mark.asyncio
async def test_all_rejected_keeps_seed(
self,
seed_prompt: Prompt,
task_description: str,
synthetic_pool: list[SyntheticExample],
) -> None:
"""When all proposals are rejected, the seed prompt stays as best."""
mock_llm = MagicMock(spec=LLMPort)
mock_judge = MagicMock(spec=JudgePort)
mock_proposer = MagicMock(spec=ProposerPort)
evaluator = PromptEvaluator(mock_llm, mock_judge)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:3]
evals = [
_make_eval([0.5, 0.5, 0.5]), # initial
]
for _ in range(3):
evals.append(_make_eval([0.5, 0.5, 0.5])) # current
evals.append(_make_eval([0.1, 0.1, 0.1])) # worse proposal
evaluator.evaluate = AsyncMock(side_effect=evals)
mock_proposer.propose.side_effect = [
Prompt(text="bad v1"),
Prompt(text="bad v2"),
Prompt(text="bad v3"),
]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer,
bootstrap=bootstrap,
max_iterations=3,
minibatch_size=3,
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
assert state.best_candidate.prompt.text == seed_prompt.text
assert state.best_candidate.best_score == pytest.approx(1.5) # 0.5*3
@pytest.mark.asyncio
async def test_all_accepted_chain(
self,
seed_prompt: Prompt,
task_description: str,
synthetic_pool: list[SyntheticExample],
) -> None:
"""All iterations accept, forming an improvement chain."""
mock_llm = MagicMock(spec=LLMPort)
mock_judge = MagicMock(spec=JudgePort)
mock_proposer = MagicMock(spec=ProposerPort)
evaluator = PromptEvaluator(mock_llm, mock_judge)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:2]
evals = [
_make_eval([0.2, 0.2]), # initial
]
for i in range(1, 5):
score = 0.2 + i * 0.15
evals.append(_make_eval([score, score])) # current
evals.append(_make_eval([score + 0.1, score + 0.1])) # new (accepted)
evaluator.evaluate = AsyncMock(side_effect=evals)
mock_proposer.propose.side_effect = [
Prompt(text=f"Improved v{i}") for i in range(4)
]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer,
bootstrap=bootstrap,
max_iterations=4,
minibatch_size=2,
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
assert len(state.candidates) == 5 # seed + 4 accepted
assert all(h["event"] == "accepted" for h in state.history)
@pytest.mark.asyncio
async def test_error_recovery_continues_loop(
self,
seed_prompt: Prompt,
task_description: str,
synthetic_pool: list[SyntheticExample],
) -> None:
"""When an iteration errors, the loop continues."""
mock_llm = MagicMock(spec=LLMPort)
mock_judge = MagicMock(spec=JudgePort)
mock_proposer = MagicMock(spec=ProposerPort)
evaluator = PromptEvaluator(mock_llm, mock_judge)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:2]
# Eval sequence for 3 iterations:
# - iter 1: evaluate current → propose → evaluate new (accepted)
# - iter 2: evaluate current → propose (ERROR, no new eval)
# - iter 3: evaluate current → propose → evaluate new (accepted)
evals = [
_make_eval([0.3, 0.3]), # initial
_make_eval([0.5, 0.5]), # iter 1 current
_make_eval([0.9, 0.9]), # iter 1 new (accepted)
_make_eval([0.5, 0.5]), # iter 2 current (proposer errors after this)
_make_eval([0.5, 0.5]), # iter 3 current
_make_eval([0.8, 0.8]), # iter 3 new (accepted)
]
evaluator.evaluate = AsyncMock(side_effect=evals)
# Proposer raises on iter 2
mock_proposer.propose.side_effect = [
Prompt(text="good v1"),
RuntimeError("LLM timeout"),
Prompt(text="good v3"),
]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer,
bootstrap=bootstrap,
max_iterations=3,
minibatch_size=2,
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
assert state.iteration == 3
assert state.history[1]["event"] == "error"
assert "LLM timeout" in state.history[1]["error"]
assert state.history[0]["event"] == "accepted"
assert state.history[2]["event"] == "accepted"
@pytest.mark.asyncio
async def test_perfect_score_skips_proposer(
self,
seed_prompt: Prompt,
task_description: str,
synthetic_pool: list[SyntheticExample],
) -> None:
"""When all scores are perfect, no proposition is made."""
mock_llm = MagicMock(spec=LLMPort)
mock_judge = MagicMock(spec=JudgePort)
mock_proposer = MagicMock(spec=ProposerPort)
evaluator = PromptEvaluator(mock_llm, mock_judge)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:2]
perfect_eval = _make_eval([1.0, 1.0])
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer,
bootstrap=bootstrap,
max_iterations=5,
minibatch_size=2,
perfect_score=1.0,
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
mock_proposer.propose.assert_not_called()
assert all(h["event"] == "skip_perfect" for h in state.history)
@pytest.mark.asyncio
async def test_llm_call_counting(
self,
seed_prompt: Prompt,
task_description: str,
synthetic_pool: list[SyntheticExample],
) -> None:
"""Verify LLM call counting: 2*N per eval (execute + judge) + 1 per propose."""
mock_llm = MagicMock(spec=LLMPort)
mock_judge = MagicMock(spec=JudgePort)
mock_proposer = MagicMock(spec=ProposerPort)
evaluator = PromptEvaluator(mock_llm, mock_judge)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:3]
evals = [_make_eval([0.3, 0.3, 0.3])] # initial
for _ in range(2):
evals.append(_make_eval([0.4, 0.4, 0.4]))
evals.append(_make_eval([0.6, 0.6, 0.6]))
evaluator.evaluate = AsyncMock(side_effect=evals)
mock_proposer.propose.side_effect = [
Prompt(text="v1"),
Prompt(text="v2"),
]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer,
bootstrap=bootstrap,
max_iterations=2,
minibatch_size=3,
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
# Initial: 2*3=6, Iter1: 2*3 + 1 + 2*3 = 13, Iter2: same = 13
# Total: 6 + 13 + 13 = 32
assert state.total_llm_calls == 32

View File

@@ -1,7 +1,9 @@
"""End-to-end pipeline test with mocked LLM calls.""" """End-to-end pipeline test with mocked LLM calls."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest
from prometheus.application.bootstrap import SyntheticBootstrap from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.dto import OptimizationConfig from prometheus.application.dto import OptimizationConfig
@@ -23,9 +25,10 @@ def _make_eval(scores: list[float]) -> EvalResult:
class TestFullPipeline: class TestFullPipeline:
def test_pipeline_produces_result(self) -> None: @pytest.mark.asyncio
async def test_pipeline_produces_result(self) -> None:
"""Full pipeline with mocked ports produces an OptimizationResult.""" """Full pipeline with mocked ports produces an OptimizationResult."""
mock_llm = MagicMock(spec=LLMPort) mock_llm = AsyncMock(spec=LLMPort)
mock_llm.execute.return_value = "mock response" mock_llm.execute.return_value = "mock response"
mock_judge = MagicMock(spec=JudgePort) mock_judge = MagicMock(spec=JudgePort)
@@ -38,11 +41,11 @@ class TestFullPipeline:
eval_sequence.append(_make_eval([0.6, 0.6, 0.6, 0.6, 0.6])) # new eval (accepted) eval_sequence.append(_make_eval([0.6, 0.6, 0.6, 0.6, 0.6])) # new eval (accepted)
mock_judge.judge_batch.return_value = [(0.5, "ok")] * 5 mock_judge.judge_batch.return_value = [(0.5, "ok")] * 5
mock_proposer = MagicMock(spec=ProposerPort) mock_proposer = AsyncMock(spec=ProposerPort)
mock_proposer.propose.return_value = Prompt(text="Improved prompt") mock_proposer.propose.return_value = Prompt(text="Improved prompt")
evaluator = PromptEvaluator(mock_llm, mock_judge) evaluator = PromptEvaluator(mock_llm, mock_judge)
evaluator.evaluate = MagicMock(side_effect=eval_sequence) evaluator.evaluate = AsyncMock(side_effect=eval_sequence)
mock_gen = MagicMock() mock_gen = MagicMock()
mock_gen.generate_inputs.return_value = [ mock_gen.generate_inputs.return_value = [
@@ -65,7 +68,7 @@ class TestFullPipeline:
seed=42, seed=42,
) )
result = use_case.execute(config) result = await use_case.execute(config)
assert result.initial_prompt == "Answer questions." assert result.initial_prompt == "Answer questions."
assert result.optimized_prompt == "Improved prompt" assert result.optimized_prompt == "Improved prompt"

View File

@@ -0,0 +1,199 @@
"""Integration test — ground-truth evaluation end-to-end with real similarity metrics."""
from __future__ import annotations
import asyncio
import json
import pytest
from unittest.mock import AsyncMock
from prometheus.application.ground_truth_evaluator import GroundTruthEvaluator
from prometheus.domain.entities import GroundTruthExample, Prompt
from prometheus.domain.ports import LLMPort
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
from prometheus.infrastructure.similarity import (
BleuSimilarity,
CosineSimilarity,
ExactMatchSimilarity,
RougeLSimilarity,
create_similarity_adapter,
)
def _make_dataset(items: list[tuple[str, str]]) -> list[GroundTruthExample]:
return [
GroundTruthExample(input_text=inp, expected_output=exp, id=i)
for i, (inp, exp) in enumerate(items)
]
@pytest.fixture
def qa_dataset():
return _make_dataset([
("What is the capital of France?", "Paris"),
("What is 2+2?", "4"),
("What color is the sky?", "blue"),
])
@pytest.fixture
def prompt():
return Prompt(text="Answer the following question concisely.")
@pytest.fixture
def mock_executor():
"""Returns responses that partially match the ground truth."""
port = AsyncMock(spec=LLMPort)
port.execute.side_effect = [
"Paris is the capital of France.",
"The answer is 4.",
"The sky is blue.",
]
return port
class TestGroundTruthIntegrationWithExactMatch:
@pytest.mark.asyncio
async def test_exact_match_on_qa(self, mock_executor, qa_dataset, prompt):
evaluator = GroundTruthEvaluator(
executor=mock_executor,
similarity=ExactMatchSimilarity(),
)
result = await evaluator.evaluate(prompt, qa_dataset)
# None of the outputs are exact matches with expected outputs
assert all(s == 0.0 for s in result.scores)
@pytest.mark.asyncio
async def test_exact_match_with_exact_outputs(self, qa_dataset, prompt):
exact_executor = AsyncMock(spec=LLMPort)
exact_executor.execute.side_effect = ["Paris", "4", "blue"]
evaluator = GroundTruthEvaluator(
executor=exact_executor,
similarity=ExactMatchSimilarity(),
)
result = await evaluator.evaluate(prompt, qa_dataset)
assert all(s == 1.0 for s in result.scores)
class TestGroundTruthIntegrationWithBleu:
@pytest.mark.asyncio
async def test_bleu_scores_partial_match(self, mock_executor, qa_dataset, prompt):
evaluator = GroundTruthEvaluator(
executor=mock_executor,
similarity=BleuSimilarity(),
)
result = await evaluator.evaluate(prompt, qa_dataset)
assert all(0.0 < s < 1.0 for s in result.scores)
assert result.mean_score > 0.0
@pytest.mark.asyncio
async def test_bleu_perfect_match(self, qa_dataset, prompt):
perfect_executor = AsyncMock(spec=LLMPort)
perfect_executor.execute.side_effect = ["Paris", "4", "blue"]
evaluator = GroundTruthEvaluator(
executor=perfect_executor,
similarity=BleuSimilarity(),
)
result = await evaluator.evaluate(prompt, qa_dataset)
assert all(s > 0.0 for s in result.scores)
class TestGroundTruthIntegrationWithRouge:
@pytest.mark.asyncio
async def test_rouge_l_scores(self, mock_executor, qa_dataset, prompt):
evaluator = GroundTruthEvaluator(
executor=mock_executor,
similarity=RougeLSimilarity(),
)
result = await evaluator.evaluate(prompt, qa_dataset)
assert all(s > 0.0 for s in result.scores)
class TestGroundTruthIntegrationWithCosine:
@pytest.mark.asyncio
async def test_cosine_scores(self, mock_executor, qa_dataset, prompt):
evaluator = GroundTruthEvaluator(
executor=mock_executor,
similarity=CosineSimilarity(),
)
result = await evaluator.evaluate(prompt, qa_dataset)
assert all(s > 0.0 for s in result.scores)
class TestDatasetLoaderIntegration:
@pytest.mark.asyncio
async def test_load_csv_and_evaluate(self, tmp_path, prompt):
csv_file = tmp_path / "eval.csv"
csv_file.write_text("input,expected_output\nWhat is 2+2?,4\nWhat color is grass?,green\n")
loader = FileDatasetLoader()
dataset = loader.load(str(csv_file))
assert len(dataset) == 2
executor = AsyncMock(spec=LLMPort)
executor.execute.side_effect = ["4", "green"]
evaluator = GroundTruthEvaluator(
executor=executor,
similarity=ExactMatchSimilarity(),
)
result = await evaluator.evaluate(prompt, dataset)
assert all(s == 1.0 for s in result.scores)
@pytest.mark.asyncio
async def test_load_json_and_evaluate(self, tmp_path, prompt):
json_file = tmp_path / "eval.json"
data = [
{"input": "What is 2+2?", "expected_output": "4"},
{"input": "What color is grass?", "expected_output": "green"},
]
json_file.write_text(json.dumps(data))
loader = FileDatasetLoader()
dataset = loader.load(str(json_file))
assert len(dataset) == 2
executor = AsyncMock(spec=LLMPort)
executor.execute.side_effect = ["4", "not green"]
evaluator = GroundTruthEvaluator(
executor=executor,
similarity=create_similarity_adapter("bleu"),
)
result = await evaluator.evaluate(prompt, dataset)
# First item should score well, second poorly
assert result.scores[0] > result.scores[1]
class TestMetricComparison:
"""Compare different metrics on the same outputs to ensure they behave differently."""
@pytest.mark.asyncio
async def test_metrics_give_different_scores(self, qa_dataset, prompt):
results = {}
for metric_name, metric_cls in [
("exact", ExactMatchSimilarity),
("bleu", BleuSimilarity),
("rouge_l", RougeLSimilarity),
("cosine", CosineSimilarity),
]:
executor = AsyncMock(spec=LLMPort)
executor.execute.side_effect = [
"Paris is the capital of France.",
"The answer is 4.",
"The sky is blue.",
]
evaluator = GroundTruthEvaluator(
executor=executor,
similarity=metric_cls(),
)
result = await evaluator.evaluate(prompt, qa_dataset)
results[metric_name] = result.mean_score
# Exact match should be 0 (no exact matches)
assert results["exact"] == 0.0
# All other metrics should give partial credit
assert results["bleu"] > 0.0
assert results["rouge_l"] > 0.0
assert results["cosine"] > 0.0

View File

@@ -0,0 +1,213 @@
"""Unit tests for multi-model adapter configuration.
Verifies that each adapter uses its own dspy.LM instance and
that per-model api_base/api_key_env overrides are wired correctly.
"""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import dspy
import pytest
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
@pytest.fixture
def task_lm() -> dspy.LM:
"""Dummy LM for task execution."""
return dspy.utils.DummyLM([{"output": "task model output"}])
@pytest.fixture
def judge_lm() -> dspy.LM:
"""Dummy LM for judging (ChainOfThought requires reasoning field)."""
return dspy.utils.DummyLM(
[
{"reasoning": "Evaluating output.", "score": "0.8", "feedback": "Good response.", "dimension_scores": "{}"},
]
)
@pytest.fixture
def proposer_lm() -> dspy.LM:
"""Dummy LM for proposing (ChainOfThought requires reasoning field)."""
return dspy.utils.DummyLM(
[
{"reasoning": "Analyzing failures.", "new_instruction": "Improved prompt: be more specific."},
]
)
@pytest.fixture
def synth_lm() -> dspy.LM:
"""Dummy LM for synthetic generation (ChainOfThought requires reasoning field)."""
return dspy.utils.DummyLM(
[
{"reasoning": "Generating examples.", "examples": json.dumps(["input 1", "input 2", "input 3"])},
]
)
class TestDSPyLLMAdapterOwnLM:
"""Bug #2 fix: DSPyLLMAdapter must use the LM it receives, not the global one."""
@pytest.mark.asyncio
async def test_uses_provided_lm_not_global(self) -> None:
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
dspy.configure(lm=global_lm)
adapter = DSPyLLMAdapter(lm=local_lm)
result = await adapter.execute(Prompt(text="test"), "input")
assert result == "local response"
@pytest.mark.asyncio
async def test_does_not_affect_global_lm(self) -> None:
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
dspy.configure(lm=global_lm)
adapter = DSPyLLMAdapter(lm=local_lm)
await adapter.execute(Prompt(text="test"), "input")
# Global LM should still be the same
assert dspy.settings.lm is global_lm
class TestDSPyJudgeAdapterOwnLM:
"""DSPyJudgeAdapter must use its own LM instance."""
@pytest.mark.asyncio
async def test_uses_provided_lm(self, judge_lm: dspy.LM) -> None:
adapter = DSPyJudgeAdapter(lm=judge_lm)
results = await adapter.judge_batch(
task_description="Test task",
pairs=[("input 1", "output 1")],
)
assert len(results) == 1
score, feedback = results[0]
assert score == 0.8
assert feedback == "Good response."
@pytest.mark.asyncio
async def test_does_not_use_global_lm(self) -> None:
judge_lm = dspy.utils.DummyLM(
[{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response", "dimension_scores": "{}"}]
)
global_lm = dspy.utils.DummyLM([{"reasoning": "no", "score": "0.1", "feedback": "Wrong LM!", "dimension_scores": "{}"}])
dspy.configure(lm=global_lm)
adapter = DSPyJudgeAdapter(lm=judge_lm)
results = await adapter.judge_batch("task", [("in", "out")])
assert results[0][0] == 0.9
class TestDSPyProposerAdapterOwnLM:
"""DSPyProposerAdapter must use its own LM instance."""
@pytest.mark.asyncio
async def test_uses_provided_lm(self, proposer_lm: dspy.LM) -> None:
adapter = DSPyProposerAdapter(lm=proposer_lm)
trajectories = [
Trajectory(
input_text="test input",
output_text="test output",
score=0.3,
feedback="bad",
prompt_used="old prompt",
)
]
result = await adapter.propose(
current_prompt=Prompt(text="old prompt"),
trajectories=trajectories,
task_description="Test task",
)
assert "Improved prompt" in result.text
@pytest.mark.asyncio
async def test_does_not_use_global_lm(self) -> None:
proposer_lm = dspy.utils.DummyLM(
[{"reasoning": "ok", "new_instruction": "proposer-specific"}]
)
global_lm = dspy.utils.DummyLM(
[{"reasoning": "no", "new_instruction": "wrong-global"}]
)
dspy.configure(lm=global_lm)
adapter = DSPyProposerAdapter(lm=proposer_lm)
result = await adapter.propose(
current_prompt=Prompt(text="test"),
trajectories=[],
task_description="task",
)
assert result.text == "proposer-specific"
class TestDSPySyntheticAdapterOwnLM:
"""DSPySyntheticAdapter must use its own LM instance."""
def test_uses_provided_lm(self, synth_lm: dspy.LM) -> None:
adapter = DSPySyntheticAdapter(lm=synth_lm)
results = adapter.generate_inputs("Test task", 3)
assert len(results) == 3
assert all(isinstance(ex, SyntheticExample) for ex in results)
def test_does_not_use_global_lm(self) -> None:
synth_lm = dspy.utils.DummyLM(
[{"reasoning": "ok", "examples": json.dumps(["synth-specific"])}]
)
global_lm = dspy.utils.DummyLM(
[{"reasoning": "no", "examples": json.dumps(["wrong-global"])}]
)
dspy.configure(lm=global_lm)
adapter = DSPySyntheticAdapter(lm=synth_lm)
results = adapter.generate_inputs("task", 1)
assert results[0].input_text == "synth-specific"
class TestPerModelOverrides:
"""Verify that per-model api_base/api_key_env are passed through to dspy.LM."""
@patch("prometheus.cli.commands.optimize.dspy.LM")
def test_per_model_api_base_override(self, mock_lm_cls: MagicMock) -> None:
"""Per-model api_base should be used instead of global."""
mock_lm_cls.return_value = MagicMock()
from prometheus.application.dto import OptimizationConfig
config = OptimizationConfig(
seed_prompt="test",
task_description="test",
task_model="openai/gpt-4o-mini",
judge_model="openai/gpt-4o",
proposer_model="openai/gpt-4o",
synth_model="openai/gpt-4o",
judge_api_base="https://judge.example.com/v1",
judge_api_key_env="JUDGE_API_KEY",
)
# Verify config carries the overrides
assert config.judge_api_base == "https://judge.example.com/v1"
assert config.judge_api_key_env == "JUDGE_API_KEY"
assert config.task_api_base is None
def test_config_defaults_to_none(self) -> None:
from prometheus.application.dto import OptimizationConfig
config = OptimizationConfig(seed_prompt="test", task_description="test")
assert config.task_api_base is None
assert config.task_api_key_env is None
assert config.judge_api_base is None
assert config.judge_api_key_env is None
assert config.proposer_api_base is None
assert config.proposer_api_key_env is None
assert config.synth_api_base is None
assert config.synth_api_key_env is None

294
tests/unit/test_adapters.py Normal file
View File

@@ -0,0 +1,294 @@
"""Unit tests for infrastructure adapters — LLM, Judge, Proposer, Synthetic.
Uses mocked DSPy modules to isolate adapter logic from LLM calls.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import dspy
import pytest
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
# --- LLM Adapter ---
class TestDSPyLLMAdapter:
"""Tests for DSPyLLMAdapter.execute()."""
@pytest.fixture
def mock_lm(self) -> MagicMock:
return MagicMock(spec=dspy.LM)
@pytest.fixture
def adapter(self, mock_lm: MagicMock) -> DSPyLLMAdapter:
return DSPyLLMAdapter(lm=mock_lm)
@pytest.mark.asyncio
async def test_execute_returns_output_string(
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
) -> None:
mock_predictor = MagicMock()
mock_predictor.return_value = MagicMock(output="Hello response")
adapter._predictor = mock_predictor
prompt = Prompt(text="Say hello.")
result = await adapter.execute(prompt, "Hi there")
assert result == "Hello response"
@pytest.mark.asyncio
async def test_execute_passes_prompt_text_and_input(
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
) -> None:
mock_predictor = MagicMock()
mock_predictor.return_value = MagicMock(output="response")
adapter._predictor = mock_predictor
prompt = Prompt(text="Translate this.")
await adapter.execute(prompt, "Hello world")
mock_predictor.assert_called_once_with(
instruction="Translate this.",
input_text="Hello world",
)
@pytest.mark.asyncio
async def test_execute_uses_dspy_context(
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
) -> None:
mock_predictor = MagicMock()
mock_predictor.return_value = MagicMock(output="ok")
adapter._predictor = mock_predictor
with patch("prometheus.infrastructure.llm_adapter.dspy.context") as mock_ctx:
await adapter.execute(Prompt(text="test"), "input")
mock_ctx.assert_called_once_with(lm=mock_lm)
@pytest.mark.asyncio
async def test_execute_converts_output_to_str(
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
) -> None:
mock_predictor = MagicMock()
mock_predictor.return_value = MagicMock(output=42)
adapter._predictor = mock_predictor
result = await adapter.execute(Prompt(text="test"), "input")
assert isinstance(result, str)
assert result == "42"
# --- Judge Adapter ---
class TestDSPyJudgeAdapter:
"""Tests for DSPyJudgeAdapter.judge_batch()."""
@pytest.fixture
def mock_lm(self) -> MagicMock:
return MagicMock(spec=dspy.LM)
@pytest.fixture
def adapter(self, mock_lm: MagicMock) -> DSPyJudgeAdapter:
return DSPyJudgeAdapter(lm=mock_lm)
@pytest.mark.asyncio
async def test_judge_batch_returns_scores_and_feedback(
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
) -> None:
adapter._judge = MagicMock()
adapter._judge.side_effect = [
MagicMock(score=0.9, feedback="Excellent."),
MagicMock(score=0.4, feedback="Incomplete."),
]
pairs = [("What is 2+2?", "4"), ("Capital of France?", "London")]
result = await adapter.judge_batch("math and geography", pairs)
assert len(result) == 2
assert result[0] == (0.9, "Excellent.")
assert result[1] == (0.4, "Incomplete.")
@pytest.mark.asyncio
async def test_judge_batch_empty_pairs(
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
) -> None:
result = await adapter.judge_batch("task", [])
assert result == []
@pytest.mark.asyncio
async def test_judge_batch_uses_dspy_context(
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
) -> None:
adapter._judge = MagicMock()
adapter._judge.return_value = MagicMock(score=0.5, feedback="ok")
with patch("prometheus.infrastructure.judge_adapter.dspy.context") as mock_ctx:
await adapter.judge_batch("task", [("in", "out")])
mock_ctx.assert_called_once_with(lm=mock_lm)
@pytest.mark.asyncio
async def test_judge_batch_returns_all_results(
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
) -> None:
"""Judge calls run in parallel but all results are returned."""
adapter._judge = MagicMock()
adapter._judge.side_effect = [
MagicMock(score=0.5, feedback="ok"),
MagicMock(score=0.7, feedback="better"),
MagicMock(score=0.3, feedback="worse"),
]
pairs = [("first", "out1"), ("second", "out2"), ("third", "out3")]
results = await adapter.judge_batch("task", pairs)
assert len(results) == 3
scores = [r[0] for r in results]
assert 0.5 in scores
assert 0.7 in scores
assert 0.3 in scores
# --- Proposer Adapter ---
class TestDSPyProposerAdapter:
"""Tests for DSPyProposerAdapter.propose()."""
@pytest.fixture
def mock_lm(self) -> MagicMock:
return MagicMock(spec=dspy.LM)
@pytest.fixture
def adapter(self, mock_lm: MagicMock) -> DSPyProposerAdapter:
return DSPyProposerAdapter(lm=mock_lm)
@pytest.mark.asyncio
async def test_propose_returns_new_prompt(
self, adapter: DSPyProposerAdapter, mock_lm: MagicMock
) -> None:
adapter._proposer = MagicMock()
adapter._proposer.return_value = MagicMock(
new_instruction="Be concise and accurate."
)
current = Prompt(text="Answer questions.")
trajectories = [
Trajectory("in", "out", 0.3, "too verbose", "Answer questions.")
]
result = await adapter.propose(current, trajectories, "Q&A task")
assert isinstance(result, Prompt)
assert result.text == "Be concise and accurate."
@pytest.mark.asyncio
async def test_propose_uses_dspy_context(
self, adapter: DSPyProposerAdapter, mock_lm: MagicMock
) -> None:
adapter._proposer = MagicMock()
adapter._proposer.return_value = MagicMock(new_instruction="improved")
with patch("prometheus.infrastructure.proposer_adapter.dspy.context") as mock_ctx:
await adapter.propose(Prompt(text="t"), [], "task")
mock_ctx.assert_called_once_with(lm=mock_lm)
def test_format_failures_single_trajectory(self) -> None:
trajectories = [
Trajectory("What is AI?", "A type of robot.", 0.3, "Incomplete definition.", "prompt")
]
result = DSPyProposerAdapter._format_failures(trajectories)
assert "What is AI?" in result
assert "A type of robot." in result
assert "0.30" in result
assert "Incomplete definition." in result
assert "# Example 1" in result
def test_format_failures_multiple_trajectories(self) -> None:
trajectories = [
Trajectory("input1", "output1", 0.4, "bad", "prompt"),
Trajectory("input2", "output2", 0.2, "worse", "prompt"),
]
result = DSPyProposerAdapter._format_failures(trajectories)
assert "# Example 1" in result
assert "# Example 2" in result
assert "---" in result
assert "input1" in result
assert "input2" in result
def test_format_failures_empty_list(self) -> None:
result = DSPyProposerAdapter._format_failures([])
assert result == ""
# --- Synthetic Adapter ---
class TestDSPySyntheticAdapter:
"""Tests for DSPySyntheticAdapter.generate_inputs()."""
@pytest.fixture
def mock_lm(self) -> MagicMock:
return MagicMock(spec=dspy.LM)
@pytest.fixture
def adapter(self, mock_lm: MagicMock) -> DSPySyntheticAdapter:
return DSPySyntheticAdapter(lm=mock_lm)
def test_generate_inputs_returns_examples(
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
) -> None:
adapter._generator = MagicMock()
adapter._generator.return_value = MagicMock(
examples=["What is AI?", "Explain ML.", "What is NLP?"]
)
result = adapter.generate_inputs("AI task", 3)
assert len(result) == 3
assert all(isinstance(ex, SyntheticExample) for ex in result)
assert result[0].input_text == "What is AI?"
assert result[0].id == 0
assert result[1].id == 1
def test_generate_inputs_truncates_to_n(
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
) -> None:
adapter._generator = MagicMock()
adapter._generator.return_value = MagicMock(
examples=["q1", "q2", "q3", "q4", "q5"]
)
result = adapter.generate_inputs("task", 3)
assert len(result) == 3
def test_generate_inputs_passes_correct_args(
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
) -> None:
adapter._generator = MagicMock()
adapter._generator.return_value = MagicMock(examples=["q1"])
adapter.generate_inputs("my task", 5)
adapter._generator.assert_called_once_with(
task_description="my task",
n_examples=5,
)
def test_generate_inputs_empty_list(
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
) -> None:
adapter._generator = MagicMock()
adapter._generator.return_value = MagicMock(examples=[])
result = adapter.generate_inputs("task", 0)
assert result == []

View File

@@ -0,0 +1,333 @@
"""Unit tests for checkpoint & resume functionality."""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.evaluator import PromptEvaluator
from prometheus.application.evolution import EvolutionLoop
from prometheus.domain.entities import (
Candidate,
EvalResult,
OptimizationState,
Prompt,
SyntheticExample,
Trajectory,
)
from prometheus.infrastructure.checkpoint import JsonCheckpointPersistence
# ---------------------------------------------------------------------------
# JsonCheckpointPersistence — save/load round-trip
# ---------------------------------------------------------------------------
class TestJsonCheckpointPersistence:
def test_roundtrip_full_state(self, tmp_path: Path) -> None:
"""Saving and loading preserves all fields."""
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
state = OptimizationState(
iteration=7,
best_candidate=Candidate(
prompt=Prompt(text="best prompt", metadata={"source": "test"}),
best_score=0.92,
generation=5,
),
candidates=[
Candidate(prompt=Prompt(text="p1"), best_score=0.5, generation=0),
Candidate(prompt=Prompt(text="p2"), best_score=0.92, generation=5),
],
synthetic_pool=[
SyntheticExample(input_text="q1", category="cat_a", id=0),
SyntheticExample(input_text="q2", category="cat_b", id=1),
],
history=[{"iteration": 1, "event": "accepted", "old_score": 0.5, "new_score": 0.7}],
total_llm_calls=42,
)
ckpt.save(state)
assert ckpt.latest_exists()
loaded = ckpt.load()
assert loaded is not None
assert loaded.iteration == 7
assert loaded.total_llm_calls == 42
assert loaded.best_candidate is not None
assert loaded.best_candidate.prompt.text == "best prompt"
assert loaded.best_candidate.prompt.metadata == {"source": "test"}
assert loaded.best_candidate.best_score == 0.92
assert len(loaded.candidates) == 2
assert len(loaded.synthetic_pool) == 2
assert loaded.synthetic_pool[0].input_text == "q1"
assert loaded.synthetic_pool[1].category == "cat_b"
assert loaded.history[0]["event"] == "accepted"
def test_load_returns_none_when_no_checkpoint(self, tmp_path: Path) -> None:
"""Loading from empty dir returns None."""
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "nope")
assert ckpt.load() is None
assert not ckpt.latest_exists()
def test_creates_directory_on_save(self, tmp_path: Path) -> None:
"""Save creates the directory tree if it doesn't exist."""
deep_dir = tmp_path / "a" / "b" / "c"
ckpt = JsonCheckpointPersistence(checkpoint_dir=deep_dir)
state = OptimizationState(iteration=1)
ckpt.save(state)
assert (deep_dir / "latest.json").exists()
def test_overwrites_previous_checkpoint(self, tmp_path: Path) -> None:
"""Second save overwrites the first."""
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
ckpt.save(OptimizationState(iteration=1, total_llm_calls=10))
ckpt.save(OptimizationState(iteration=5, total_llm_calls=50))
loaded = ckpt.load()
assert loaded is not None
assert loaded.iteration == 5
assert loaded.total_llm_calls == 50
def test_json_is_human_readable(self, tmp_path: Path) -> None:
"""Checkpoint file is valid, pretty-printed JSON."""
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
state = OptimizationState(
iteration=3,
best_candidate=Candidate(prompt=Prompt(text="hello"), best_score=0.8),
)
ckpt.save(state)
raw = json.loads((tmp_path / "latest.json").read_text())
assert raw["schema_version"] == 1
assert raw["iteration"] == 3
assert raw["best_candidate"]["prompt_text"] == "hello"
# ---------------------------------------------------------------------------
# EvolutionLoop — checkpoint integration
# ---------------------------------------------------------------------------
class TestEvolutionCheckpoint:
@pytest.mark.asyncio
async def test_checkpoint_saved_on_interval(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
) -> None:
"""Checkpoint is saved every checkpoint_interval iterations."""
from unittest.mock import MagicMock
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
# All iterations accepted so checkpoint triggers
good_eval = EvalResult(
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
feedbacks=["ok"] * 5,
trajectories=[
Trajectory(f"input{i}", f"out{i}", s, "ok", "p")
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
],
)
better_eval = EvalResult(
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
feedbacks=["good"] * 5,
trajectories=[],
)
# initial_eval + 5 iterations (each needs old_eval + new_eval)
evaluator.evaluate = AsyncMock(
side_effect=[good_eval] # initial
+ [good_eval, better_eval] * 5 # 5 iterations
)
proposer = AsyncMock()
proposer.propose.return_value = Prompt(text="improved prompt")
checkpoint_port = MagicMock()
loop = EvolutionLoop(
evaluator=evaluator,
proposer=proposer,
bootstrap=bootstrap,
max_iterations=5,
minibatch_size=5,
checkpoint_port=checkpoint_port,
checkpoint_interval=2,
)
await loop.run(seed_prompt, synthetic_pool, task_description)
# Checkpoint at iterations 2, 4 (every 2nd)
save_calls = checkpoint_port.save.call_count
assert save_calls >= 2 # at least at iters 2 and 4
@pytest.mark.asyncio
async def test_no_checkpoint_without_port(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
) -> None:
"""No checkpointing happens when checkpoint_port is None (default)."""
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
perfect_eval = EvalResult(
scores=[1.0] * 5,
feedbacks=["perfect"] * 5,
trajectories=[
Trajectory(f"in{i}", f"out{i}", 1.0, "perfect", "p")
for i in range(5)
],
)
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
loop = EvolutionLoop(
evaluator=evaluator,
proposer=AsyncMock(),
bootstrap=bootstrap,
max_iterations=3,
minibatch_size=5,
checkpoint_port=None,
)
# Should run without error — no checkpoint port, no crash
await loop.run(seed_prompt, synthetic_pool, task_description)
@pytest.mark.asyncio
async def test_resume_skips_seed_evaluation(
self,
synthetic_pool: list[SyntheticExample],
task_description: str,
) -> None:
"""When initial_state is provided, seed eval is skipped and loop starts from saved iteration."""
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
proposer = AsyncMock()
proposer.propose.return_value = Prompt(text="new prompt")
# Only return evaluations for resumed iterations (1 iter: old_eval + new_eval)
old_eval = EvalResult(
scores=[0.5] * 5,
feedbacks=["ok"] * 5,
trajectories=[
Trajectory(f"in{i}", f"out{i}", 0.5, "ok", "p") for i in range(5)
],
)
new_eval = EvalResult(
scores=[0.8] * 5,
feedbacks=["good"] * 5,
trajectories=[],
)
evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
# Create a state simulating checkpoint at iteration 4
initial_state = OptimizationState(
iteration=4,
best_candidate=Candidate(
prompt=Prompt(text="checkpoint prompt"), best_score=2.5, generation=4
),
candidates=[Candidate(prompt=Prompt(text="checkpoint prompt"), best_score=2.5)],
total_llm_calls=40,
)
loop = EvolutionLoop(
evaluator=evaluator,
proposer=proposer,
bootstrap=bootstrap,
max_iterations=5, # only iteration 5 remains
minibatch_size=5,
)
state = await loop.run(
seed_prompt=Prompt(text="seed"),
synthetic_pool=synthetic_pool,
task_description=task_description,
initial_state=initial_state,
)
# Should have run only 1 iteration (iter 5)
assert state.iteration == 5
# total_llm_calls should include the 40 from checkpoint + new calls
assert state.total_llm_calls > 40
@pytest.mark.asyncio
async def test_full_save_and_resume_roundtrip(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
tmp_path: Path,
) -> None:
"""End-to-end: run a few iterations, checkpoint, resume, finish."""
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
old_eval = EvalResult(
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
feedbacks=["ok"] * 5,
trajectories=[
Trajectory(f"in{i}", f"out{i}", s, "ok", "p")
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
],
)
new_eval = EvalResult(
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
feedbacks=["good"] * 5,
trajectories=[],
)
evaluator.evaluate = AsyncMock(
side_effect=[old_eval, old_eval, new_eval, old_eval, new_eval]
)
proposer = AsyncMock()
proposer.propose.return_value = Prompt(text="improved prompt")
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
loop = EvolutionLoop(
evaluator=evaluator,
proposer=proposer,
bootstrap=bootstrap,
max_iterations=2,
minibatch_size=5,
checkpoint_port=ckpt,
checkpoint_interval=1,
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
assert state.iteration == 2
assert ckpt.latest_exists()
# Capture the checkpoint state *before* resume (state is mutated in-place)
loaded = ckpt.load()
assert loaded is not None
saved_llm_calls = loaded.total_llm_calls
saved_iteration = loaded.iteration
# Set up evaluator for resumed run (just 1 more iteration)
evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
proposer.propose.return_value = Prompt(text="even better prompt")
loop2 = EvolutionLoop(
evaluator=evaluator,
proposer=proposer,
bootstrap=bootstrap,
max_iterations=3,
minibatch_size=5,
checkpoint_port=ckpt,
checkpoint_interval=1,
)
resumed = await loop2.run(
seed_prompt, synthetic_pool, task_description,
initial_state=loaded,
)
assert resumed.iteration == 3
assert resumed.total_llm_calls > saved_llm_calls
assert resumed.iteration > saved_iteration

278
tests/unit/test_cli.py Normal file
View File

@@ -0,0 +1,278 @@
"""Tests for the CLI interface — prometheus optimize, version, etc.
Uses Typer's CliRunner for isolated command testing.
"""
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import yaml
from typer.testing import CliRunner
from prometheus.application.dto import OptimizationResult
from prometheus.cli.app import app
runner = CliRunner()
class TestCLIOptimize:
"""Tests for the `prometheus optimize` command."""
def _write_config(self, tmp_path: Path, **overrides: object) -> Path:
"""Write a minimal valid config YAML and return its path."""
data = {
"seed_prompt": "You are a helpful assistant.",
"task_description": "Answer factual questions accurately.",
}
data.update(overrides)
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(data, f)
return config_file
def test_optimize_with_valid_config(self, tmp_path: Path) -> None:
config_file = self._write_config(tmp_path)
output_file = tmp_path / "output.yaml"
mock_result = OptimizationResult(
optimized_prompt="Improved prompt",
initial_prompt="You are a helpful assistant.",
iterations_used=5,
total_llm_calls=50,
initial_score=0.3,
final_score=0.9,
improvement=0.6,
history=[],
)
mock_uc = AsyncMock()
mock_uc.execute.return_value = mock_result
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
mock_llm_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
mock_judge_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
mock_prop_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.dspy"):
result = runner.invoke(
app,
[
"optimize",
"-i",
str(config_file),
"-o",
str(output_file),
],
)
assert result.exit_code == 0
assert "Optimized Prompt" in result.output
def test_optimize_missing_input_file(self) -> None:
result = runner.invoke(
app,
["optimize", "-i", "/nonexistent/config.yaml"],
)
assert result.exit_code != 0
def test_optimize_with_verbose_flag(self, tmp_path: Path) -> None:
config_file = self._write_config(tmp_path)
output_file = tmp_path / "output.yaml"
mock_result = OptimizationResult(
optimized_prompt="Improved",
initial_prompt="test",
iterations_used=1,
total_llm_calls=10,
initial_score=0.3,
final_score=0.8,
improvement=0.5,
history=[],
)
mock_uc = AsyncMock()
mock_uc.execute.return_value = mock_result
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
mock_llm_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
mock_judge_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
mock_prop_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.dspy"):
result = runner.invoke(
app,
[
"optimize",
"-i",
str(config_file),
"-o",
str(output_file),
"-v",
],
)
assert result.exit_code == 0
def test_optimize_displays_metrics(self, tmp_path: Path) -> None:
config_file = self._write_config(tmp_path)
output_file = tmp_path / "output.yaml"
mock_result = OptimizationResult(
optimized_prompt="Better prompt",
initial_prompt="test",
iterations_used=3,
total_llm_calls=30,
initial_score=0.40,
final_score=0.85,
improvement=0.45,
history=[],
)
mock_uc = AsyncMock()
mock_uc.execute.return_value = mock_result
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
mock_llm_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
mock_judge_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
mock_prop_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.dspy"):
result = runner.invoke(
app,
[
"optimize",
"-i",
str(config_file),
"-o",
str(output_file),
],
)
assert result.exit_code == 0
assert "0.40" in result.output
assert "0.85" in result.output
assert "+0.45" in result.output
def test_optimize_with_max_concurrency_flag(self, tmp_path: Path) -> None:
config_file = self._write_config(tmp_path)
output_file = tmp_path / "output.yaml"
mock_result = OptimizationResult(
optimized_prompt="Better prompt",
initial_prompt="test",
iterations_used=1,
total_llm_calls=10,
initial_score=0.3,
final_score=0.8,
improvement=0.5,
history=[],
)
mock_uc = AsyncMock()
mock_uc.execute.return_value = mock_result
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
mock_llm_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
mock_judge_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
mock_prop_cls.return_value = MagicMock()
with patch("prometheus.cli.commands.optimize.dspy"):
result = runner.invoke(
app,
[
"optimize",
"-i",
str(config_file),
"-o",
str(output_file),
"--max-concurrency",
"10",
],
)
assert result.exit_code == 0
class TestCLIHelp:
"""Tests for CLI help and no-args behavior."""
def test_no_args_shows_help(self) -> None:
result = runner.invoke(app, [])
# Typer uses exit code 2 when no_args_is_help=True
assert result.exit_code in (0, 2)
assert "PROMETHEUS" in result.output or "Usage" in result.output
def test_optimize_help(self) -> None:
result = runner.invoke(app, ["optimize", "--help"])
assert result.exit_code == 0
assert "input" in result.output.lower() or "INPUT" in result.output
def test_version_help(self) -> None:
result = runner.invoke(app, ["version", "--help"])
assert result.exit_code == 0
def test_init_help(self) -> None:
result = runner.invoke(app, ["init", "--help"])
assert result.exit_code == 0
def test_list_help(self) -> None:
result = runner.invoke(app, ["list", "--help"])
assert result.exit_code == 0
class TestCLIVersion:
"""Tests for the `prometheus version` command."""
def test_version_prints_version(self) -> None:
result = runner.invoke(app, ["version"])
assert result.exit_code == 0
assert "PROMETHEUS" in result.output
assert "0.1.0" in result.output
class TestCLIList:
"""Tests for the `prometheus list` command."""
def test_list_no_runs(self, tmp_path: Path) -> None:
result = runner.invoke(app, ["list", "-d", str(tmp_path)])
assert result.exit_code == 0
assert "No optimization runs found" in result.output
def test_list_with_result(self, tmp_path: Path) -> None:
result_data = {
"optimized_prompt": "Better prompt for testing",
"initial_prompt": "test",
"iterations_used": 5,
"total_llm_calls": 50,
"initial_score": 0.30,
"final_score": 0.90,
"improvement": 0.60,
"history": [],
}
result_file = tmp_path / "output.yaml"
import yaml as _yaml
with open(result_file, "w") as f:
_yaml.dump(result_data, f)
result = runner.invoke(app, ["list", "-d", str(tmp_path)])
assert result.exit_code == 0
assert "0.30" in result.output
assert "0.90" in result.output
def test_list_nonexistent_directory(self) -> None:
result = runner.invoke(app, ["list", "-d", "/nonexistent/dir"])
assert result.exit_code == 1

332
tests/unit/test_config.py Normal file
View File

@@ -0,0 +1,332 @@
"""Unit tests for config and config loading via CLI.
Tests config validation scenarios: missing fields, wrong types, defaults.
"""
from __future__ import annotations
from pathlib import Path
import pytest
import yaml
from pydantic import ValidationError
from prometheus.application.dto import OptimizationConfig
from prometheus.infrastructure.file_io import YamlPersistence
class TestOptimizationConfig:
"""Tests for OptimizationConfig Pydantic model defaults."""
def test_default_values(self) -> None:
config = OptimizationConfig(
seed_prompt="test prompt",
task_description="test task",
)
assert config.task_model == "openai/gpt-4o-mini"
assert config.judge_model == "openai/gpt-4o"
assert config.proposer_model == "openai/gpt-4o"
assert config.synth_model == "openai/gpt-4o"
assert config.max_iterations == 30
assert config.n_synthetic_inputs == 20
assert config.minibatch_size == 5
assert config.perfect_score == 1.0
assert config.seed == 42
assert config.output_path == "output.yaml"
assert config.verbose is False
assert config.error_strategy == "retry"
assert config.max_retries == 3
assert config.max_concurrency == 5
def test_custom_values(self) -> None:
config = OptimizationConfig(
seed_prompt="custom prompt",
task_description="custom task",
max_iterations=100,
minibatch_size=10,
seed=123,
verbose=True,
)
assert config.max_iterations == 100
assert config.minibatch_size == 10
assert config.seed == 123
assert config.verbose is True
def test_roundtrip_to_dict(self) -> None:
config = OptimizationConfig(
seed_prompt="test",
task_description="task",
)
d = config.model_dump()
assert d["seed_prompt"] == "test"
assert d["task_description"] == "task"
assert "history" not in d # OptimizationResult has history, not config
def test_config_version_defaults(self) -> None:
config = OptimizationConfig(seed_prompt="a", task_description="b")
assert config.config_version == 1
def test_config_version_stamps_current(self) -> None:
config = OptimizationConfig(
seed_prompt="a", task_description="b", config_version=0,
)
# Migration always stamps current version
assert config.config_version == 1
class TestConfigLoading:
"""Tests for loading OptimizationConfig from YAML via YamlPersistence."""
def test_minimal_config_loads(self, tmp_path: Path) -> None:
persistence = YamlPersistence()
data = {
"seed_prompt": "You are helpful.",
"task_description": "Answer questions.",
}
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(data, f)
raw = persistence.read_config(str(config_file))
config = OptimizationConfig.model_validate(raw)
assert config.seed_prompt == "You are helpful."
assert config.task_description == "Answer questions."
assert config.max_iterations == 30 # default
def test_full_config_loads(self, tmp_path: Path) -> None:
persistence = YamlPersistence()
data = {
"seed_prompt": "You are helpful.",
"task_description": "Answer questions.",
"task_model": "openai/gpt-4o",
"judge_model": "openai/gpt-4o-mini",
"max_iterations": 50,
"n_synthetic_inputs": 30,
"minibatch_size": 8,
"seed": 99,
"verbose": True,
}
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(data, f)
raw = persistence.read_config(str(config_file))
config = OptimizationConfig.model_validate(raw)
assert config.task_model == "openai/gpt-4o"
assert config.max_iterations == 50
assert config.verbose is True
def test_missing_seed_prompt_raises_validation_error(self, tmp_path: Path) -> None:
persistence = YamlPersistence()
data = {"task_description": "Answer questions."}
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(data, f)
raw = persistence.read_config(str(config_file))
with pytest.raises(ValidationError, match="seed_prompt"):
OptimizationConfig.model_validate(raw)
def test_missing_task_description_raises_validation_error(self, tmp_path: Path) -> None:
persistence = YamlPersistence()
data = {"seed_prompt": "You are helpful."}
config_file = tmp_path / "config.yaml"
with open(config_file, "w") as f:
yaml.dump(data, f)
raw = persistence.read_config(str(config_file))
with pytest.raises(ValidationError, match="task_description"):
OptimizationConfig.model_validate(raw)
def test_empty_yaml_raises_on_required_fields(self, tmp_path: Path) -> None:
persistence = YamlPersistence()
config_file = tmp_path / "empty.yaml"
config_file.write_text("{}", encoding="utf-8")
raw = persistence.read_config(str(config_file))
with pytest.raises(ValidationError):
OptimizationConfig.model_validate(raw)
def test_partial_config_uses_defaults(self, tmp_path: Path) -> None:
persistence = YamlPersistence()
data = {
"seed_prompt": "test",
"task_description": "task",
"max_iterations": 10,
}
config_file = tmp_path / "partial.yaml"
with open(config_file, "w") as f:
yaml.dump(data, f)
raw = persistence.read_config(str(config_file))
config = OptimizationConfig.model_validate(raw)
assert config.max_iterations == 10
assert config.n_synthetic_inputs == 20 # default
assert config.minibatch_size == 5 # default
class TestConfigValidation:
"""Tests for Pydantic validation edge cases."""
def test_wrong_type_max_iterations(self) -> None:
with pytest.raises(ValidationError, match="max_iterations"):
OptimizationConfig(
seed_prompt="a",
task_description="b",
max_iterations="not_a_number", # type: ignore[arg-type]
)
def test_wrong_type_seed(self) -> None:
with pytest.raises(ValidationError, match="seed"):
OptimizationConfig(
seed_prompt="a",
task_description="b",
seed="not_a_number", # type: ignore[arg-type]
)
def test_negative_max_iterations(self) -> None:
with pytest.raises(ValidationError, match="greater than or equal to 1"):
OptimizationConfig(
seed_prompt="a", task_description="b", max_iterations=0,
)
def test_negative_seed(self) -> None:
with pytest.raises(ValidationError, match="greater than or equal to 0"):
OptimizationConfig(
seed_prompt="a", task_description="b", seed=-1,
)
def test_zero_minibatch_size(self) -> None:
with pytest.raises(ValidationError, match="greater than or equal to 1"):
OptimizationConfig(
seed_prompt="a", task_description="b", minibatch_size=0,
)
def test_zero_max_concurrency(self) -> None:
with pytest.raises(ValidationError, match="greater than or equal to 1"):
OptimizationConfig(
seed_prompt="a", task_description="b", max_concurrency=0,
)
def test_negative_max_retries(self) -> None:
with pytest.raises(ValidationError, match="greater than or equal to 0"):
OptimizationConfig(
seed_prompt="a", task_description="b", max_retries=-1,
)
def test_zero_retry_delay_base(self) -> None:
with pytest.raises(ValidationError, match="greater than 0"):
OptimizationConfig(
seed_prompt="a", task_description="b", retry_delay_base=0,
)
def test_negative_retry_delay_base(self) -> None:
with pytest.raises(ValidationError, match="greater than 0"):
OptimizationConfig(
seed_prompt="a", task_description="b", retry_delay_base=-0.5,
)
def test_zero_circuit_breaker_threshold(self) -> None:
with pytest.raises(ValidationError, match="greater than or equal to 1"):
OptimizationConfig(
seed_prompt="a", task_description="b", circuit_breaker_threshold=0,
)
def test_perfect_score_above_one(self) -> None:
with pytest.raises(ValidationError, match="less than or equal to 1"):
OptimizationConfig(
seed_prompt="a", task_description="b", perfect_score=1.5,
)
def test_perfect_score_negative(self) -> None:
with pytest.raises(ValidationError, match="greater than or equal to 0"):
OptimizationConfig(
seed_prompt="a", task_description="b", perfect_score=-0.1,
)
def test_invalid_error_strategy(self) -> None:
with pytest.raises(ValidationError, match="error_strategy must be one of"):
OptimizationConfig(
seed_prompt="a", task_description="b", error_strategy="invalid",
)
def test_valid_error_strategies(self) -> None:
for strategy in ("skip", "retry", "abort"):
config = OptimizationConfig(
seed_prompt="a", task_description="b", error_strategy=strategy,
)
assert config.error_strategy == strategy
def test_empty_seed_prompt(self) -> None:
with pytest.raises(ValidationError, match="seed_prompt"):
OptimizationConfig(seed_prompt="", task_description="b")
def test_empty_task_description(self) -> None:
with pytest.raises(ValidationError, match="task_description"):
OptimizationConfig(seed_prompt="a", task_description="")
def test_empty_model_string(self) -> None:
with pytest.raises(ValidationError, match="task_model"):
OptimizationConfig(
seed_prompt="a", task_description="b", task_model="",
)
def test_extra_fields_rejected(self) -> None:
with pytest.raises(ValidationError, match="extra"):
OptimizationConfig(
seed_prompt="a",
task_description="b",
nonexistent_field="value", # type: ignore[call-arg]
)
def test_boundary_values_accepted(self) -> None:
config = OptimizationConfig(
seed_prompt="a",
task_description="b",
max_iterations=1,
n_synthetic_inputs=1,
minibatch_size=1,
max_concurrency=1,
max_retries=0,
retry_delay_base=0.001,
circuit_breaker_threshold=1,
perfect_score=0.0,
seed=0,
)
assert config.max_iterations == 1
assert config.perfect_score == 0.0
class TestEvalConfigValidation:
"""Tests for ground-truth evaluation config fields."""
def test_eval_defaults(self) -> None:
config = OptimizationConfig(seed_prompt="a", task_description="b")
assert config.eval_dataset_path is None
assert config.eval_metric == "bleu"
def test_eval_dataset_path_set(self) -> None:
config = OptimizationConfig(
seed_prompt="a", task_description="b",
eval_dataset_path="data.csv",
)
assert config.eval_dataset_path == "data.csv"
def test_valid_eval_metrics(self) -> None:
for metric in ("exact", "bleu", "rouge_l", "cosine", "llm_judge"):
config = OptimizationConfig(
seed_prompt="a", task_description="b", eval_metric=metric,
)
assert config.eval_metric == metric
def test_invalid_eval_metric_raises(self) -> None:
with pytest.raises(ValidationError, match="eval_metric must be one of"):
OptimizationConfig(
seed_prompt="a", task_description="b",
eval_metric="invalid_metric",
)

View File

@@ -0,0 +1,86 @@
"""Tests for the ground-truth dataset loader."""
from __future__ import annotations
import json
import os
import tempfile
import pytest
from prometheus.domain.entities import GroundTruthExample
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
@pytest.fixture
def loader():
return FileDatasetLoader()
class TestCsvLoader:
def test_load_csv(self, loader, tmp_path):
csv_file = tmp_path / "test.csv"
csv_file.write_text("input,expected_output\nhello,world\nfoo,bar\n")
result = loader.load(str(csv_file))
assert len(result) == 2
assert result[0].input_text == "hello"
assert result[0].expected_output == "world"
assert result[1].input_text == "foo"
assert result[1].expected_output == "bar"
def test_load_csv_skips_empty_input(self, loader, tmp_path):
csv_file = tmp_path / "test.csv"
csv_file.write_text("input,expected_output\n,bar\nhello,world\n")
result = loader.load(str(csv_file))
assert len(result) == 1
assert result[0].input_text == "hello"
def test_load_csv_with_whitespace(self, loader, tmp_path):
csv_file = tmp_path / "test.csv"
csv_file.write_text("input,expected_output\n hello , world \n")
result = loader.load(str(csv_file))
assert result[0].input_text == "hello"
assert result[0].expected_output == "world"
def test_load_csv_empty_file(self, loader, tmp_path):
csv_file = tmp_path / "test.csv"
csv_file.write_text("input,expected_output\n")
result = loader.load(str(csv_file))
assert len(result) == 0
class TestJsonLoader:
def test_load_json(self, loader, tmp_path):
json_file = tmp_path / "test.json"
data = [
{"input": "hello", "expected_output": "world"},
{"input": "foo", "expected_output": "bar"},
]
json_file.write_text(json.dumps(data))
result = loader.load(str(json_file))
assert len(result) == 2
assert result[0].input_text == "hello"
assert result[0].expected_output == "world"
def test_load_json_skips_empty_input(self, loader, tmp_path):
json_file = tmp_path / "test.json"
data = [
{"input": "", "expected_output": "bar"},
{"input": "hello", "expected_output": "world"},
]
json_file.write_text(json.dumps(data))
result = loader.load(str(json_file))
assert len(result) == 1
def test_load_json_not_array_raises(self, loader, tmp_path):
json_file = tmp_path / "test.json"
json_file.write_text(json.dumps({"not": "an array"}))
with pytest.raises(ValueError, match="must be an array"):
loader.load(str(json_file))
class TestUnsupportedFormat:
def test_unsupported_extension_raises(self, loader, tmp_path):
txt_file = tmp_path / "test.txt"
txt_file.write_text("hello")
with pytest.raises(ValueError, match="Unsupported dataset format"):
loader.load(str(txt_file))

View File

@@ -0,0 +1,313 @@
"""Unit tests for error handling: retry, circuit breaker, per-call isolation."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.evaluator import PromptEvaluator
from prometheus.application.evolution import CircuitBreakerOpen, EvolutionLoop
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory
from prometheus.infrastructure.retry import is_transient_error, retry_with_backoff
# ---------------------------------------------------------------------------
# Retry utility
# ---------------------------------------------------------------------------
class TestIsTransientError:
def test_rate_limit_429(self):
assert is_transient_error(RuntimeError("HTTP 429: rate limit exceeded"))
def test_server_error_503(self):
assert is_transient_error(RuntimeError("503 Service Unavailable"))
def test_timeout(self):
assert is_transient_error(TimeoutError("request timed out"))
def test_connection_error(self):
assert is_transient_error(ConnectionError("connection refused"))
def test_non_transient(self):
assert not is_transient_error(ValueError("bad input"))
def test_os_error(self):
assert is_transient_error(OSError("network unreachable"))
class TestRetryWithBackoff:
def test_succeeds_on_first_try(self):
fn = MagicMock(return_value="ok")
result = retry_with_backoff(fn, max_retries=3, retry_delay_base=0)
assert result == "ok"
assert fn.call_count == 1
def test_retries_on_transient_then_succeeds(self):
fn = MagicMock(
side_effect=[
RuntimeError("429 rate limit"),
RuntimeError("429 rate limit"),
"ok",
]
)
with patch("prometheus.infrastructure.retry.time.sleep"):
result = retry_with_backoff(fn, max_retries=3, retry_delay_base=0)
assert result == "ok"
assert fn.call_count == 3
def test_raises_after_max_retries(self):
fn = MagicMock(side_effect=RuntimeError("503 overloaded"))
with patch("prometheus.infrastructure.retry.time.sleep"):
with pytest.raises(RuntimeError, match="503"):
retry_with_backoff(fn, max_retries=2, retry_delay_base=0)
assert fn.call_count == 3 # 1 initial + 2 retries
def test_non_transient_not_retried(self):
fn = MagicMock(side_effect=ValueError("bad"))
with pytest.raises(ValueError, match="bad"):
retry_with_backoff(fn, max_retries=3, retry_delay_base=0)
assert fn.call_count == 1
def test_exponential_backoff_delays(self):
fn = MagicMock(side_effect=[RuntimeError("timeout"), "ok"])
with patch("prometheus.infrastructure.retry.time.sleep") as mock_sleep:
retry_with_backoff(fn, max_retries=3, retry_delay_base=2.0)
mock_sleep.assert_called_once_with(2.0) # 2.0 * 2^0 = 2.0
# ---------------------------------------------------------------------------
# Circuit breaker (EvolutionLoop)
# ---------------------------------------------------------------------------
def _make_eval_result(scores, feedbacks=None):
"""Helper to create EvalResult with matching trajectories."""
feedbacks = feedbacks or ["ok"] * len(scores)
return EvalResult(
scores=scores,
feedbacks=feedbacks,
trajectories=[
Trajectory(f"input{i}", f"output{i}", s, f, "prompt")
for i, (s, f) in enumerate(zip(scores, feedbacks))
],
)
class TestCircuitBreaker:
@pytest.mark.asyncio
async def test_trips_on_consecutive_failures(self):
"""Loop stops when consecutive failures reach the threshold."""
initial_eval = _make_eval_result([0.3, 0.4])
evaluator = MagicMock()
call_count = 0
def _evaluate(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return initial_eval # seed eval succeeds
raise RuntimeError("LLM down")
evaluator.evaluate = AsyncMock(side_effect=_evaluate)
proposer = MagicMock()
proposer.propose = AsyncMock()
bootstrap = MagicMock(spec=SyntheticBootstrap)
loop = EvolutionLoop(
evaluator=evaluator,
proposer=proposer,
bootstrap=bootstrap,
max_iterations=10,
minibatch_size=2,
circuit_breaker_threshold=3,
error_strategy="skip",
)
state = await loop.run(
Prompt("test"),
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
"task",
)
error_events = [h for h in state.history if h.get("event") == "error"]
cb_events = [h for h in state.history if h.get("event") == "circuit_breaker"]
assert len(error_events) == 3
assert len(cb_events) == 1
assert state.iteration < 10 # stopped early
@pytest.mark.asyncio
async def test_abort_raises_on_first_error(self):
"""With error_strategy=abort, the first error raises immediately."""
initial_eval = _make_eval_result([0.3, 0.4])
evaluator = MagicMock()
call_count = 0
def _evaluate(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return initial_eval
raise RuntimeError("LLM down")
evaluator.evaluate = AsyncMock(side_effect=_evaluate)
proposer = MagicMock()
proposer.propose = AsyncMock()
bootstrap = MagicMock(spec=SyntheticBootstrap)
loop = EvolutionLoop(
evaluator=evaluator,
proposer=proposer,
bootstrap=bootstrap,
max_iterations=10,
minibatch_size=2,
circuit_breaker_threshold=3,
error_strategy="abort",
)
with pytest.raises(RuntimeError, match="LLM down"):
await loop.run(
Prompt("test"),
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
"task",
)
@pytest.mark.asyncio
async def test_resets_on_success(self):
"""Consecutive failure counter resets after a successful iteration."""
initial_eval = _make_eval_result([0.3, 0.4])
good_eval = _make_eval_result([0.8, 0.9])
evaluator = MagicMock()
call_count = 0
def _evaluate(*args, **kwargs):
nonlocal call_count
call_count += 1
# call 1: seed eval → succeed
# call 2: iter 1 current eval → fail
# call 3: iter 2 current eval → fail
# call 4: iter 3 current eval → succeed (returns initial_eval)
# call 5: iter 3 new eval → succeed (returns good_eval, accepted)
# call 6+: iter 4+ current eval → succeed
if call_count == 1:
return initial_eval
if call_count in (2, 3):
raise RuntimeError("timeout")
if call_count % 2 == 0:
return initial_eval # current eval
return good_eval # new eval
evaluator.evaluate = AsyncMock(side_effect=_evaluate)
proposer = MagicMock()
proposer.propose = AsyncMock(return_value=Prompt("better prompt"))
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = [
SyntheticExample(f"in{i}", id=i) for i in range(2)
]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=proposer,
bootstrap=bootstrap,
max_iterations=6,
minibatch_size=2,
circuit_breaker_threshold=3,
error_strategy="skip",
)
state = await loop.run(
Prompt("test"),
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
"task",
)
# Should NOT have tripped — 2 fails, then success reset the counter
cb_events = [h for h in state.history if h.get("event") == "circuit_breaker"]
assert len(cb_events) == 0
assert state.iteration == 6 # ran all iterations
# ---------------------------------------------------------------------------
# Per-call isolation (Evaluator)
# ---------------------------------------------------------------------------
class TestPerCallIsolation:
@pytest.mark.asyncio
async def test_evaluator_isolates_execution_failure(self):
"""A failing execution produces a sentinel output, not a crash."""
executor = MagicMock()
executor.execute = AsyncMock(side_effect=[
"good output",
RuntimeError("API error"),
"another good output",
])
judge = MagicMock()
judge.judge_batch = AsyncMock(return_value=[
(0.8, "good"),
(0.0, "[judge error]"),
(0.7, "ok"),
])
evaluator = PromptEvaluator(executor, judge)
result = await evaluator.evaluate(
Prompt("test"),
[
SyntheticExample("in0", id=0),
SyntheticExample("in1", id=1),
SyntheticExample("in2", id=2),
],
"task",
)
assert len(result.scores) == 3
assert result.scores[1] == 0.0 # failed item got zero score
assert "execution error" in result.trajectories[1].output_text
assert result.scores[0] == 0.8 # other items unaffected
@pytest.mark.asyncio
async def test_judge_adapter_isolates_single_failure(self):
"""DSPyJudgeAdapter returns sentinel for a failed item, not crash."""
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
adapter = DSPyJudgeAdapter.__new__(DSPyJudgeAdapter)
adapter._lm = MagicMock()
adapter._max_retries = 1
adapter._retry_delay_base = 0
adapter._semaphore = __import__("asyncio").Semaphore(5)
adapter._judge_criteria = ""
adapter._judge_dimensions = []
adapter._dimension_names = ""
adapter._weights = {}
adapter.call_count = 0
# Mock _judge to fail on first call, succeed on second
call_count = 0
class FakePred:
def __init__(self):
self.score = 0.9
self.feedback = "good"
def fake_judge(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise RuntimeError("judge failure")
return FakePred()
adapter._judge = fake_judge
with patch("prometheus.infrastructure.judge_adapter.dspy.context"):
with patch(
"prometheus.infrastructure.retry.asyncio.sleep",
new=AsyncMock(),
):
results = await adapter.judge_batch(
"task", [("input1", "output1"), ("input2", "output2")]
)
assert len(results) == 2
# First item failed even after retry → sentinel
assert results[0] == (0.0, "[judge error: judge failure]")
# Second item succeeded
assert results[1] == (0.9, "good")

View File

@@ -1,7 +1,7 @@
"""Unit tests for PromptEvaluator.evaluate().""" """Unit tests for PromptEvaluator.evaluate()."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@@ -14,22 +14,23 @@ class TestPromptEvaluatorEvaluate:
"""Tests for the evaluate() pipeline: execute → judge → trajectories.""" """Tests for the evaluate() pipeline: execute → judge → trajectories."""
@pytest.fixture @pytest.fixture
def executor(self) -> MagicMock: def executor(self) -> AsyncMock:
return MagicMock(spec=LLMPort) return AsyncMock(spec=LLMPort)
@pytest.fixture @pytest.fixture
def judge(self) -> MagicMock: def judge(self) -> AsyncMock:
return MagicMock(spec=JudgePort) return AsyncMock(spec=JudgePort)
@pytest.fixture @pytest.fixture
def evaluator(self, executor: MagicMock, judge: MagicMock) -> PromptEvaluator: def evaluator(self, executor: AsyncMock, judge: AsyncMock) -> PromptEvaluator:
return PromptEvaluator(executor=executor, judge=judge) return PromptEvaluator(executor=executor, judge=judge)
def test_happy_path_builds_correct_trajectories( @pytest.mark.asyncio
async def test_happy_path_builds_correct_trajectories(
self, self,
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
executor: MagicMock, executor: AsyncMock,
judge: MagicMock, judge: AsyncMock,
) -> None: ) -> None:
prompt = Prompt(text="Answer the question.") prompt = Prompt(text="Answer the question.")
examples = [ examples = [
@@ -42,7 +43,7 @@ class TestPromptEvaluatorEvaluate:
(0.8, "Mostly correct."), (0.8, "Mostly correct."),
] ]
result = evaluator.evaluate(prompt, examples, "math and geography") result = await evaluator.evaluate(prompt, examples, "math and geography")
assert isinstance(result, EvalResult) assert isinstance(result, EvalResult)
assert result.scores == [0.9, 0.8] assert result.scores == [0.9, 0.8]
@@ -55,14 +56,15 @@ class TestPromptEvaluatorEvaluate:
assert result.trajectories[0].prompt_used == "Answer the question." assert result.trajectories[0].prompt_used == "Answer the question."
assert result.trajectories[1].prompt_used == "Answer the question." assert result.trajectories[1].prompt_used == "Answer the question."
def test_empty_minibatch_returns_empty_result( @pytest.mark.asyncio
async def test_empty_minibatch_returns_empty_result(
self, self,
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
executor: MagicMock, executor: AsyncMock,
judge: MagicMock, judge: AsyncMock,
) -> None: ) -> None:
prompt = Prompt(text="test") prompt = Prompt(text="test")
result = evaluator.evaluate(prompt, [], "task") result = await evaluator.evaluate(prompt, [], "task")
assert result.scores == [] assert result.scores == []
assert result.feedbacks == [] assert result.feedbacks == []
@@ -71,41 +73,44 @@ class TestPromptEvaluatorEvaluate:
# judge_batch is called with empty pairs list # judge_batch is called with empty pairs list
judge.judge_batch.assert_called_once_with("task", []) judge.judge_batch.assert_called_once_with("task", [])
def test_executor_called_with_correct_prompt( @pytest.mark.asyncio
async def test_executor_called_with_correct_prompt(
self, self,
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
executor: MagicMock, executor: AsyncMock,
judge: MagicMock, judge: AsyncMock,
) -> None: ) -> None:
prompt = Prompt(text="Summarize this.") prompt = Prompt(text="Summarize this.")
examples = [SyntheticExample(input_text="Long text here", id=0)] examples = [SyntheticExample(input_text="Long text here", id=0)]
executor.execute.return_value = "Summary." executor.execute.return_value = "Summary."
judge.judge_batch.return_value = [(0.7, "Good summary.")] judge.judge_batch.return_value = [(0.7, "Good summary.")]
evaluator.evaluate(prompt, examples, "summarization") await evaluator.evaluate(prompt, examples, "summarization")
executor.execute.assert_called_once_with(prompt, "Long text here") executor.execute.assert_called_once_with(prompt, "Long text here")
def test_trajectories_prompt_used_matches_input_prompt( @pytest.mark.asyncio
async def test_trajectories_prompt_used_matches_input_prompt(
self, self,
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
executor: MagicMock, executor: AsyncMock,
judge: MagicMock, judge: AsyncMock,
) -> None: ) -> None:
prompt = Prompt(text="Translate to French.") prompt = Prompt(text="Translate to French.")
examples = [SyntheticExample(input_text="Hello", id=0)] examples = [SyntheticExample(input_text="Hello", id=0)]
executor.execute.return_value = "Bonjour" executor.execute.return_value = "Bonjour"
judge.judge_batch.return_value = [(1.0, "Perfect.")] judge.judge_batch.return_value = [(1.0, "Perfect.")]
result = evaluator.evaluate(prompt, examples, "translation") result = await evaluator.evaluate(prompt, examples, "translation")
assert result.trajectories[0].prompt_used == "Translate to French." assert result.trajectories[0].prompt_used == "Translate to French."
def test_scores_feedbacks_trajectories_lists_sized_correctly( @pytest.mark.asyncio
async def test_scores_feedbacks_trajectories_lists_sized_correctly(
self, self,
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
executor: MagicMock, executor: AsyncMock,
judge: MagicMock, judge: AsyncMock,
) -> None: ) -> None:
prompt = Prompt(text="test prompt") prompt = Prompt(text="test prompt")
examples = [SyntheticExample(input_text=f"q{i}", id=i) for i in range(4)] examples = [SyntheticExample(input_text=f"q{i}", id=i) for i in range(4)]
@@ -114,7 +119,7 @@ class TestPromptEvaluatorEvaluate:
(0.1 * i, f"fb{i}") for i in range(4) (0.1 * i, f"fb{i}") for i in range(4)
] ]
result = evaluator.evaluate(prompt, examples, "task") result = await evaluator.evaluate(prompt, examples, "task")
assert len(result.scores) == 4 assert len(result.scores) == 4
assert len(result.feedbacks) == 4 assert len(result.feedbacks) == 4

View File

@@ -1,51 +1,55 @@
"""Unit tests for the evolution loop — with full mocking.""" """Unit tests for the evolution loop — with full mocking."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prometheus.application.bootstrap import SyntheticBootstrap from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.evaluator import PromptEvaluator from prometheus.application.evaluator import PromptEvaluator
from prometheus.application.evolution import EvolutionLoop from prometheus.application.evolution import EvolutionLoop
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory from prometheus.domain.entities import (
Candidate,
EvalResult,
Prompt,
SyntheticExample,
Trajectory,
)
def _make_eval(scores: list[float], label: str = "ok") -> EvalResult:
"""Helper to build an EvalResult from a list of scores."""
return EvalResult(
scores=scores,
feedbacks=[label] * len(scores),
trajectories=[
Trajectory(f"input{i}", f"output{i}", s, label, "prompt")
for i, s in enumerate(scores)
],
)
class TestEvolutionLoop: class TestEvolutionLoop:
def test_accepts_improvement( """Tests for the original single-candidate hill-climbing mode (population_size=1)."""
@pytest.mark.asyncio
async def test_accepts_improvement(
self, self,
seed_prompt: Prompt, seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample], synthetic_pool: list[SyntheticExample],
task_description: str, task_description: str,
mock_llm_port: MagicMock, mock_llm_port: AsyncMock,
mock_judge_port: MagicMock, mock_judge_port: AsyncMock,
mock_proposer_port: MagicMock, mock_proposer_port: AsyncMock,
) -> None: ) -> None:
"""When the new prompt improves the score, the best candidate is updated.""" """When the new prompt improves the score, the best candidate is updated."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port) evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
bootstrap = MagicMock(spec=SyntheticBootstrap) bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5] bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
initial_eval = EvalResult( low_eval = _make_eval([0.3, 0.4, 0.3, 0.5, 0.2], "bad")
scores=[0.3, 0.4, 0.3, 0.5, 0.2], high_eval = _make_eval([0.8, 0.9, 0.7, 0.8, 0.9], "good")
feedbacks=["bad"] * 5, evaluator.evaluate = AsyncMock(side_effect=[low_eval, low_eval, high_eval])
trajectories=[
Trajectory(f"input{i}", f"output{i}", s, "bad", "prompt")
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
],
)
old_eval = EvalResult(
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
feedbacks=["bad"] * 5,
trajectories=[
Trajectory(f"input{i}", f"output{i}", s, "bad", "prompt")
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
],
)
new_eval = EvalResult(
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
feedbacks=["good"] * 5,
trajectories=[],
)
evaluator.evaluate = MagicMock(side_effect=[initial_eval, old_eval, new_eval])
loop = EvolutionLoop( loop = EvolutionLoop(
evaluator=evaluator, evaluator=evaluator,
@@ -54,48 +58,29 @@ class TestEvolutionLoop:
max_iterations=1, max_iterations=1,
minibatch_size=5, minibatch_size=5,
) )
with patch.object(loop, "_log"): state = await loop.run(seed_prompt, synthetic_pool, task_description)
state = loop.run(seed_prompt, synthetic_pool, task_description)
assert state.best_candidate is not None assert state.best_candidate is not None
assert state.best_candidate.best_score > 0 assert state.best_candidate.best_score > 0
def test_rejects_regression( @pytest.mark.asyncio
async def test_rejects_regression(
self, self,
seed_prompt: Prompt, seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample], synthetic_pool: list[SyntheticExample],
task_description: str, task_description: str,
mock_llm_port: MagicMock, mock_llm_port: AsyncMock,
mock_judge_port: MagicMock, mock_judge_port: AsyncMock,
mock_proposer_port: MagicMock, mock_proposer_port: AsyncMock,
) -> None: ) -> None:
"""When the new prompt degrades the score, the best candidate stays unchanged.""" """When the new prompt degrades the score, the best candidate stays unchanged."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port) evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
bootstrap = MagicMock(spec=SyntheticBootstrap) bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5] bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
initial_eval = EvalResult( high_eval = _make_eval([0.7, 0.8, 0.7, 0.8, 0.9], "ok")
scores=[0.7, 0.8, 0.7, 0.8, 0.9], low_eval = _make_eval([0.2, 0.1, 0.3, 0.2, 0.1], "bad")
feedbacks=["ok"] * 5, evaluator.evaluate = AsyncMock(side_effect=[high_eval, high_eval, low_eval])
trajectories=[
Trajectory(f"input{i}", f"output{i}", s, "ok", "prompt")
for i, s in enumerate([0.7, 0.8, 0.7, 0.8, 0.9])
],
)
old_eval = EvalResult(
scores=[0.7, 0.8, 0.7, 0.8, 0.9],
feedbacks=["ok"] * 5,
trajectories=[
Trajectory(f"input{i}", f"output{i}", s, "ok", "prompt")
for i, s in enumerate([0.7, 0.8, 0.7, 0.8, 0.9])
],
)
new_eval = EvalResult(
scores=[0.2, 0.1, 0.3, 0.2, 0.1],
feedbacks=["bad"] * 5,
trajectories=[],
)
evaluator.evaluate = MagicMock(side_effect=[initial_eval, old_eval, new_eval])
loop = EvolutionLoop( loop = EvolutionLoop(
evaluator=evaluator, evaluator=evaluator,
@@ -104,35 +89,28 @@ class TestEvolutionLoop:
max_iterations=1, max_iterations=1,
minibatch_size=5, minibatch_size=5,
) )
with patch.object(loop, "_log"): state = await loop.run(seed_prompt, synthetic_pool, task_description)
state = loop.run(seed_prompt, synthetic_pool, task_description)
assert state.best_candidate is not None assert state.best_candidate is not None
assert state.best_candidate.prompt.text == seed_prompt.text assert state.best_candidate.prompt.text == seed_prompt.text
def test_skips_perfect_scores( @pytest.mark.asyncio
async def test_skips_perfect_scores(
self, self,
seed_prompt: Prompt, seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample], synthetic_pool: list[SyntheticExample],
task_description: str, task_description: str,
mock_llm_port: MagicMock, mock_llm_port: AsyncMock,
mock_judge_port: MagicMock, mock_judge_port: AsyncMock,
mock_proposer_port: MagicMock, mock_proposer_port: AsyncMock,
) -> None: ) -> None:
"""When all scores are perfect, no proposition is made.""" """When all scores are perfect, no proposition is made."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port) evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
bootstrap = MagicMock(spec=SyntheticBootstrap) bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5] bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
perfect_eval = EvalResult( perfect_eval = _make_eval([1.0, 1.0, 1.0, 1.0, 1.0], "perfect")
scores=[1.0, 1.0, 1.0, 1.0, 1.0], evaluator.evaluate = AsyncMock(return_value=perfect_eval)
feedbacks=["perfect"] * 5,
trajectories=[
Trajectory(f"input{i}", f"output{i}", 1.0, "perfect", "prompt")
for i in range(5)
],
)
evaluator.evaluate = MagicMock(return_value=perfect_eval)
loop = EvolutionLoop( loop = EvolutionLoop(
evaluator=evaluator, evaluator=evaluator,
@@ -141,7 +119,226 @@ class TestEvolutionLoop:
max_iterations=3, max_iterations=3,
minibatch_size=5, minibatch_size=5,
) )
with patch.object(loop, "_log"): await loop.run(seed_prompt, synthetic_pool, task_description)
loop.run(seed_prompt, synthetic_pool, task_description)
mock_proposer_port.propose.assert_not_called() mock_proposer_port.propose.assert_not_called()
class TestPopulationEvolution:
"""Tests for population-based evolution (population_size > 1)."""
@pytest.mark.asyncio
async def test_population_initialization(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
mock_llm_port: AsyncMock,
mock_judge_port: AsyncMock,
mock_proposer_port: AsyncMock,
mock_mutation_port: AsyncMock,
) -> None:
"""Population is initialized with the right number of candidates."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
evaluator.evaluate = AsyncMock(
return_value=_make_eval([0.5] * 5, "ok")
)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer_port,
bootstrap=bootstrap,
max_iterations=0, # no iterations, just initialization
minibatch_size=5,
population_size=4,
mutation_port=mock_mutation_port,
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
# 1 seed + 3 mutations = 4 candidates
assert len(state.candidates) == 4
assert mock_mutation_port.mutate.call_count == 3
@pytest.mark.asyncio
async def test_population_initialization_uses_proposer_fallback(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
mock_llm_port: AsyncMock,
mock_judge_port: AsyncMock,
mock_proposer_port: AsyncMock,
) -> None:
"""When no mutation_port is provided, population init falls back to proposer."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
evaluator.evaluate = AsyncMock(
return_value=_make_eval([0.5] * 5, "ok")
)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer_port,
bootstrap=bootstrap,
max_iterations=0,
minibatch_size=5,
population_size=3,
# mutation_port intentionally omitted
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
assert len(state.candidates) == 3
assert mock_proposer_port.propose.call_count == 2 # 3-1 = 2 init mutations
@pytest.mark.asyncio
async def test_population_iteration_replaces_worst(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
mock_llm_port: AsyncMock,
mock_judge_port: AsyncMock,
mock_proposer_port: AsyncMock,
mock_crossover_port: AsyncMock,
mock_mutation_port: AsyncMock,
) -> None:
"""Crossover child replaces worst candidate when its fitness is higher."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
# Sequence:
# 1. Initial eval (seed)
# 2. Population init: 3 mutation calls use proposer.propose(), NOT evaluator.evaluate
# 3. Population iteration: crossover produces child → eval child
# Only 2 evaluator.evaluate calls total
seed_eval = _make_eval([0.5] * 5, "ok")
# Crossover child eval - high score to beat worst
child_eval = _make_eval([0.9, 0.9, 0.8, 0.9, 0.8], "great")
all_evals = [seed_eval, child_eval]
evaluator.evaluate = AsyncMock(side_effect=all_evals)
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer_port,
bootstrap=bootstrap,
max_iterations=1,
minibatch_size=5,
population_size=4,
crossover_rate=1.0,
crossover_port=mock_crossover_port,
mutation_rate=0.0, # disable post-crossover mutation for determinism
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
accepted_events = [h for h in state.history if h.get("event") == "pop_accepted"]
assert len(accepted_events) >= 1
@pytest.mark.asyncio
async def test_population_iteration_rejects_inferior_child(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
mock_llm_port: AsyncMock,
mock_judge_port: AsyncMock,
mock_proposer_port: AsyncMock,
mock_crossover_port: AsyncMock,
) -> None:
"""Inferior child is rejected and doesn't replace any candidate."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
seed_eval = _make_eval([0.8] * 5, "ok")
# Crossover produces very LOW-scoring child
child_eval = _make_eval([0.1] * 5, "terrible")
all_evals = [seed_eval, child_eval]
evaluator.evaluate = AsyncMock(side_effect=all_evals)
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer_port,
bootstrap=bootstrap,
max_iterations=1,
minibatch_size=5,
population_size=4,
crossover_rate=1.0,
crossover_port=mock_crossover_port,
mutation_rate=0.0,
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
rejected_events = [h for h in state.history if h.get("event") == "pop_rejected"]
assert len(rejected_events) >= 1
class TestDiversityScore:
"""Tests for the diversity/similarity scoring logic."""
def test_identical_prompts_have_high_similarity(self) -> None:
"""Identical prompts should have very high similarity."""
identical = Prompt(text="You are a helpful assistant. Answer the question.")
pop_a = Candidate(prompt=identical, best_score=4.0, generation=0)
pop_b = Candidate(
prompt=Prompt(text="Completely different prompt about data analysis."),
best_score=3.0,
generation=0,
)
sim_same = EvolutionLoop._compute_diversity_score(identical, [pop_a, pop_b])
# Average includes similarity to the different member, so ~0.5 not 0.9+
assert sim_same > 0.3
def test_different_prompts_have_lower_similarity(self) -> None:
"""Different prompts should have lower similarity than identical ones."""
prompt_a = Prompt(text="You are a helpful assistant. Answer the question.")
prompt_b = Prompt(text="Provide detailed analysis of complex data patterns with precision.")
pop_a = Candidate(prompt=prompt_a, best_score=4.0, generation=0)
pop_b = Candidate(prompt=prompt_b, best_score=3.0, generation=0)
sim_a = EvolutionLoop._compute_diversity_score(prompt_a, [pop_a, pop_b])
sim_b = EvolutionLoop._compute_diversity_score(prompt_b, [pop_a, pop_b])
# Both should be < 1.0 since they're different
assert sim_a < 1.0
assert sim_b < 1.0
def test_single_member_population_returns_1(self) -> None:
"""Single-member population always returns 1.0 (no penalty)."""
prompt = Prompt(text="Any prompt text here.")
pop = [Candidate(prompt=prompt, best_score=1.0, generation=0)]
sim = EvolutionLoop._compute_diversity_score(prompt, pop)
assert sim == 1.0
def test_empty_prompt_returns_zero(self) -> None:
"""Empty prompt text returns 0.0 when population has >1 member."""
prompt = Prompt(text="")
pop = [
Candidate(prompt=Prompt(text="some text"), best_score=1.0, generation=0),
Candidate(prompt=Prompt(text="other text"), best_score=2.0, generation=0),
]
sim = EvolutionLoop._compute_diversity_score(prompt, pop)
assert sim == 0.0
class TestPromptDiff:
"""Tests for the static _compute_prompt_diff helper."""
def test_identical_prompts(self) -> None:
result = EvolutionLoop._compute_prompt_diff("hello\nworld", "hello\nworld")
assert result["lines_added"] == 0
assert result["lines_removed"] == 0
assert result["chars_delta"] == 0
def test_added_lines(self) -> None:
result = EvolutionLoop._compute_prompt_diff("hello", "hello\nworld")
assert result["lines_added"] == 1
assert result["lines_removed"] == 0
assert result["chars_delta"] == 6 # "\nworld"
def test_removed_lines(self) -> None:
result = EvolutionLoop._compute_prompt_diff("hello\nworld", "hello")
assert result["lines_added"] == 0
assert result["lines_removed"] == 1

View File

@@ -0,0 +1,133 @@
"""Tests for GroundTruthEvaluator — execution + similarity comparison."""
from __future__ import annotations
from unittest.mock import AsyncMock
import pytest
from prometheus.application.ground_truth_evaluator import GroundTruthEvaluator
from prometheus.domain.entities import EvalResult, GroundTruthExample, Prompt
from prometheus.domain.ports import LLMPort, SimilarityPort
@pytest.fixture
def mock_executor() -> AsyncMock:
port = AsyncMock(spec=LLMPort)
port.execute.return_value = "Paris is the capital of France."
return port
@pytest.fixture
def mock_similarity() -> AsyncMock:
port = AsyncMock(spec=SimilarityPort)
port.compute.return_value = 0.85
return port
@pytest.fixture
def gt_dataset() -> list[GroundTruthExample]:
return [
GroundTruthExample(input_text="What is the capital of France?", expected_output="Paris", id=0),
GroundTruthExample(input_text="What is 2+2?", expected_output="4", id=1),
GroundTruthExample(input_text="What color is the sky?", expected_output="blue", id=2),
]
@pytest.fixture
def prompt() -> Prompt:
return Prompt(text="Answer the following question accurately.")
@pytest.mark.asyncio
class TestGroundTruthEvaluator:
async def test_evaluate_happy_path(self, mock_executor, mock_similarity, gt_dataset, prompt):
evaluator = GroundTruthEvaluator(
executor=mock_executor,
similarity=mock_similarity,
max_concurrency=2,
)
result = await evaluator.evaluate(prompt, gt_dataset)
assert isinstance(result, EvalResult)
assert len(result.scores) == 3
assert len(result.feedbacks) == 3
assert len(result.trajectories) == 3
assert all(s == 0.85 for s in result.scores)
assert result.mean_score == pytest.approx(0.85)
assert result.total_score == pytest.approx(2.55)
async def test_executor_called_for_each_input(self, mock_executor, mock_similarity, gt_dataset, prompt):
evaluator = GroundTruthEvaluator(
executor=mock_executor, similarity=mock_similarity,
)
await evaluator.evaluate(prompt, gt_dataset)
assert mock_executor.execute.call_count == 3
async def test_similarity_called_for_each_output(self, mock_executor, mock_similarity, gt_dataset, prompt):
evaluator = GroundTruthEvaluator(
executor=mock_executor, similarity=mock_similarity,
)
await evaluator.evaluate(prompt, gt_dataset)
assert mock_similarity.compute.call_count == 3
async def test_execution_error_produces_zero_score(self, mock_similarity, gt_dataset, prompt):
failing_executor = AsyncMock(spec=LLMPort)
failing_executor.execute.side_effect = RuntimeError("API timeout")
evaluator = GroundTruthEvaluator(
executor=failing_executor, similarity=mock_similarity,
)
result = await evaluator.evaluate(prompt, gt_dataset)
assert len(result.scores) == 3
# The similarity adapter is called with the error sentinel
assert all(isinstance(s, float) for s in result.scores)
assert all("[execution error:" in t.output_text for t in result.trajectories)
async def test_empty_dataset(self, mock_executor, mock_similarity, prompt):
evaluator = GroundTruthEvaluator(
executor=mock_executor, similarity=mock_similarity,
)
result = await evaluator.evaluate(prompt, [])
assert result.scores == []
assert result.mean_score == 0.0
assert result.total_score == 0.0
async def test_trajectory_contains_prompt_used(self, mock_executor, mock_similarity, gt_dataset, prompt):
evaluator = GroundTruthEvaluator(
executor=mock_executor, similarity=mock_similarity,
)
result = await evaluator.evaluate(prompt, gt_dataset)
for t in result.trajectories:
assert t.prompt_used == prompt.text
async def test_scores_clamped_to_unit_range(self, mock_executor, gt_dataset, prompt):
# Similarity returns a value > 1.0 (should be clamped)
over_similarity = AsyncMock(spec=SimilarityPort)
over_similarity.compute.return_value = 1.5
evaluator = GroundTruthEvaluator(
executor=mock_executor, similarity=over_similarity,
)
result = await evaluator.evaluate(prompt, gt_dataset)
assert all(0.0 <= s <= 1.0 for s in result.scores)
async def test_feedback_for_exact_match(self, mock_executor, gt_dataset, prompt):
exact_similarity = AsyncMock(spec=SimilarityPort)
exact_similarity.compute.return_value = 1.0
evaluator = GroundTruthEvaluator(
executor=mock_executor, similarity=exact_similarity,
)
result = await evaluator.evaluate(prompt, gt_dataset)
assert all("Exact match" in fb for fb in result.feedbacks)
async def test_feedback_for_poor_match(self, mock_executor, gt_dataset, prompt):
poor_similarity = AsyncMock(spec=SimilarityPort)
poor_similarity.compute.return_value = 0.1
evaluator = GroundTruthEvaluator(
executor=mock_executor, similarity=poor_similarity,
)
result = await evaluator.evaluate(prompt, gt_dataset)
assert all("Poor match" in fb for fb in result.feedbacks)

View File

@@ -0,0 +1,316 @@
"""Unit tests for hold-out validation and early stopping."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.evaluator import PromptEvaluator
from prometheus.application.evolution import EvolutionLoop
from prometheus.domain.entities import (
Candidate,
EvalResult,
Prompt,
SyntheticExample,
Trajectory,
)
def _make_eval(mean_score: float, n: int = 5) -> EvalResult:
"""Helper: create an EvalResult with a given mean score."""
scores = [mean_score] * n
return EvalResult(
scores=scores,
feedbacks=["feedback"] * n,
trajectories=[
Trajectory(f"input{i}", f"output{i}", mean_score, "feedback", "prompt")
for i in range(n)
],
)
class TestBootstrapSplit:
"""Tests for SyntheticBootstrap.split_pool."""
def test_split_produces_correct_sizes(self):
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(20)]
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
assert len(train) + len(val) == 20
assert len(val) == 6 # 20 * 0.3 = 6
assert len(train) == 14
def test_split_zero_fraction_returns_all_train(self):
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(10)]
train, val = SyntheticBootstrap.split_pool(pool, 0.0)
assert len(train) == 10
assert len(val) == 0
def test_split_single_element(self):
pool = [SyntheticExample(input_text="only", id=0)]
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
assert len(train) == 1
assert len(val) == 0
def test_split_deterministic_with_seed(self):
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(50)]
train1, val1 = SyntheticBootstrap.split_pool(pool, 0.3, rng=MagicMock(wraps=__import__("random").Random(42)))
train2, val2 = SyntheticBootstrap.split_pool(pool, 0.3, rng=MagicMock(wraps=__import__("random").Random(42)))
assert [ex.id for ex in train1] == [ex.id for ex in train2]
assert [ex.id for ex in val1] == [ex.id for ex in val2]
def test_split_no_overlap(self):
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(30)]
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
train_ids = {ex.id for ex in train}
val_ids = {ex.id for ex in val}
assert train_ids.isdisjoint(val_ids)
assert train_ids | val_ids == {ex.id for ex in pool}
class TestValidationEvaluation:
"""Tests for hold-out evaluation during evolution."""
@pytest.mark.asyncio
async def test_validation_pool_evaluated_after_each_iteration(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
mock_llm_port: AsyncMock,
mock_judge_port: AsyncMock,
mock_proposer_port: AsyncMock,
) -> None:
"""When a validation pool is provided, the best candidate is evaluated on it."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
# Initial eval (train) + validation eval + iteration train eval + new prompt eval + validation eval
train_eval = _make_eval(0.5)
val_eval = _make_eval(0.6)
new_eval = _make_eval(0.7)
val_eval_2 = _make_eval(0.65)
evaluator.evaluate = AsyncMock(
side_effect=[train_eval, val_eval, train_eval, new_eval, val_eval_2]
)
validation_pool = synthetic_pool[-6:]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer_port,
bootstrap=bootstrap,
max_iterations=1,
minibatch_size=5,
)
state = await loop.run(
seed_prompt, synthetic_pool, task_description,
validation_pool=validation_pool,
)
# Should have validation metrics in state
assert state.best_validation_score is not None
# History should contain validation_eval entries
val_events = [h for h in state.history if h["event"] == "validation_eval"]
assert len(val_events) >= 1
@pytest.mark.asyncio
async def test_no_validation_without_pool(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
mock_llm_port: AsyncMock,
mock_judge_port: AsyncMock,
mock_proposer_port: AsyncMock,
) -> None:
"""Without a validation pool, no validation is performed."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
train_eval = _make_eval(0.5)
old_eval = _make_eval(0.5)
new_eval = _make_eval(0.7)
evaluator.evaluate = AsyncMock(side_effect=[train_eval, old_eval, new_eval])
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer_port,
bootstrap=bootstrap,
max_iterations=1,
minibatch_size=5,
)
state = await loop.run(seed_prompt, synthetic_pool, task_description)
assert state.best_validation_score is None
assert not state.early_stopped
val_events = [h for h in state.history if h["event"] == "validation_eval"]
assert len(val_events) == 0
class TestEarlyStopping:
"""Tests for early stopping when validation score degrades."""
@pytest.mark.asyncio
async def test_early_stop_triggers_on_patience_exceeded(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
mock_llm_port: AsyncMock,
mock_judge_port: AsyncMock,
mock_proposer_port: AsyncMock,
) -> None:
"""Early stopping triggers when validation doesn't improve for K iterations."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
patience = 3
# Build eval sequence:
# 1. Initial train eval
# 2. Initial validation eval (0.5)
# Then for each of 3 iterations:
# - train eval (current best)
# - train eval (new prompt - accepted)
# - validation eval (degrading)
evals = [
_make_eval(0.5), # initial train
_make_eval(0.5), # initial validation
]
for i in range(patience):
evals.extend([
_make_eval(0.5 + i * 0.1), # current eval (train)
_make_eval(0.6 + i * 0.1), # new eval (train) - accepted
_make_eval(0.4), # validation eval (degrading)
])
evaluator.evaluate = AsyncMock(side_effect=evals)
validation_pool = synthetic_pool[-5:]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer_port,
bootstrap=bootstrap,
max_iterations=10, # would go further without early stop
minibatch_size=5,
early_stop_patience=patience,
)
state = await loop.run(
seed_prompt, synthetic_pool, task_description,
validation_pool=validation_pool,
)
assert state.early_stopped is True
assert state.iteration == patience
assert state.best_validation_score is not None
# Should have an early_stop event in history
early_stop_events = [h for h in state.history if h["event"] == "early_stop"]
assert len(early_stop_events) == 1
@pytest.mark.asyncio
async def test_early_stop_does_not_trigger_when_improving(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
mock_llm_port: AsyncMock,
mock_judge_port: AsyncMock,
mock_proposer_port: AsyncMock,
) -> None:
"""When validation keeps improving, early stopping does not trigger."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
evals = [
_make_eval(0.3), # initial train
_make_eval(0.3), # initial validation
]
# 3 iterations, each with improving validation
for i in range(3):
evals.extend([
_make_eval(0.3 + i * 0.1), # current train eval
_make_eval(0.4 + i * 0.1), # new train eval (accepted)
_make_eval(0.3 + (i + 1) * 0.1), # validation eval (improving)
])
evaluator.evaluate = AsyncMock(side_effect=evals)
validation_pool = synthetic_pool[-5:]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer_port,
bootstrap=bootstrap,
max_iterations=3,
minibatch_size=5,
early_stop_patience=5,
)
state = await loop.run(
seed_prompt, synthetic_pool, task_description,
validation_pool=validation_pool,
)
assert state.early_stopped is False
assert state.iteration == 3
assert state.best_validation_score is not None
@pytest.mark.asyncio
async def test_validation_patience_resets_on_improvement(
self,
seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample],
task_description: str,
mock_llm_port: AsyncMock,
mock_judge_port: AsyncMock,
mock_proposer_port: AsyncMock,
) -> None:
"""Patience counter resets when validation improves after degrading."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
evals = [
_make_eval(0.3), # initial train
_make_eval(0.3), # initial validation
# iter 1: degrade
_make_eval(0.3), # current train
_make_eval(0.5), # new train (accepted)
_make_eval(0.2), # validation degrade (patience=1)
# iter 2: degrade
_make_eval(0.5), # current train
_make_eval(0.6), # new train (accepted)
_make_eval(0.2), # validation degrade (patience=2)
# iter 3: improve! (resets patience)
_make_eval(0.6), # current train
_make_eval(0.7), # new train (accepted)
_make_eval(0.4), # validation improve (patience=0)
# iter 4: degrade again
_make_eval(0.7), # current train
_make_eval(0.8), # new train (accepted)
_make_eval(0.2), # validation degrade (patience=1)
]
evaluator.evaluate = AsyncMock(side_effect=evals)
validation_pool = synthetic_pool[-5:]
loop = EvolutionLoop(
evaluator=evaluator,
proposer=mock_proposer_port,
bootstrap=bootstrap,
max_iterations=4,
minibatch_size=5,
early_stop_patience=3,
)
state = await loop.run(
seed_prompt, synthetic_pool, task_description,
validation_pool=validation_pool,
)
assert state.early_stopped is False
assert state.iteration == 4

189
tests/unit/test_logging.py Normal file
View File

@@ -0,0 +1,189 @@
"""Unit tests for structured logging configuration."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from prometheus.cli.logging_setup import configure_logging, get_logger
class TestConfigureLogging:
def _count_handlers(self, name: str = "prometheus") -> int:
return len(logging.getLogger(name).handlers)
def test_default_creates_console_handler(self) -> None:
configure_logging(level=logging.INFO)
prom = logging.getLogger("prometheus")
assert len(prom.handlers) == 1
assert isinstance(prom.handlers[0], logging.StreamHandler)
prom.handlers.clear()
def test_json_format_produces_valid_json(self, capsys) -> None:
configure_logging(level=logging.INFO, log_format="json")
logger = get_logger("test_json")
logger.info("hello", extra={"structured": {"key": "value"}})
captured = capsys.readouterr()
# Output goes to stderr
line = captured.err.strip().split("\n")[-1]
data = json.loads(line)
assert data["message"] == "hello"
assert data["structured"]["key"] == "value"
assert data["level"] == "INFO"
assert "timestamp" in data
logging.getLogger("prometheus").handlers.clear()
def test_text_format_includes_structured_extras(self, capsys) -> None:
configure_logging(level=logging.INFO, log_format="text")
logger = get_logger("test_text")
logger.info("msg", extra={"structured": {"foo": "bar"}})
captured = capsys.readouterr()
assert "foo=bar" in captured.err
logging.getLogger("prometheus").handlers.clear()
def test_debug_level_shows_debug_messages(self, capsys) -> None:
configure_logging(level=logging.DEBUG)
logger = get_logger("test_debug")
logger.debug("debug msg")
captured = capsys.readouterr()
assert "debug msg" in captured.err
logging.getLogger("prometheus").handlers.clear()
def test_warning_level_hides_debug_messages(self, capsys) -> None:
configure_logging(level=logging.WARNING)
logger = get_logger("test_warn")
logger.debug("should not appear")
logger.info("also hidden")
captured = capsys.readouterr()
assert "should not appear" not in captured.err
assert "also hidden" not in captured.err
logging.getLogger("prometheus").handlers.clear()
def test_file_handler_writes_to_file(self, tmp_path: Path) -> None:
log_file = tmp_path / "test.log"
configure_logging(level=logging.INFO, log_file=str(log_file))
logger = get_logger("test_file")
logger.info("file message")
prom = logging.getLogger("prometheus")
# Flush handlers
for h in prom.handlers:
h.flush()
prom.handlers.clear()
content = log_file.read_text()
assert "file message" in content
def test_json_file_output(self, tmp_path: Path) -> None:
log_file = tmp_path / "test.json.log"
configure_logging(level=logging.INFO, log_format="json", log_file=str(log_file))
logger = get_logger("test_json_file")
logger.info("json file msg", extra={"structured": {"x": 1}})
prom = logging.getLogger("prometheus")
for h in prom.handlers:
h.flush()
prom.handlers.clear()
content = log_file.read_text().strip()
data = json.loads(content)
assert data["message"] == "json file msg"
assert data["structured"]["x"] == 1
def test_reconfigure_clears_old_handlers(self) -> None:
configure_logging(level=logging.INFO)
configure_logging(level=logging.DEBUG)
prom = logging.getLogger("prometheus")
assert len(prom.handlers) == 1
prom.handlers.clear()
def test_propagate_false_prevents_duplicate_output(self, capsys) -> None:
configure_logging(level=logging.INFO)
prom = logging.getLogger("prometheus")
assert prom.propagate is False
prom.handlers.clear()
class TestGetLogger:
def test_returns_child_of_prometheus(self) -> None:
logger = get_logger("mymodule")
assert logger.name == "prometheus.mymodule"
def test_inherits_level_from_parent(self) -> None:
configure_logging(level=logging.DEBUG)
logger = get_logger("child")
assert logger.getEffectiveLevel() <= logging.DEBUG
logging.getLogger("prometheus").handlers.clear()
class TestJsonFormatter:
def test_exception_included(self, capsys) -> None:
configure_logging(level=logging.ERROR, log_format="json")
logger = get_logger("test_exc")
try:
raise ValueError("boom")
except ValueError:
logger.error("failed", exc_info=True)
captured = capsys.readouterr()
line = captured.err.strip().split("\n")[-1]
data = json.loads(line)
assert "ValueError: boom" in data["exception"]
logging.getLogger("prometheus").handlers.clear()
class TestLoggingCLIIntegration:
"""Tests for CLI flags that configure logging."""
def test_verbose_flag_enables_info(self, tmp_path: Path) -> None:
"""Simulate what -v does — configure_logging at INFO level."""
configure_logging(level=logging.INFO)
logger = get_logger("evolution")
logger.info("test message")
prom = logging.getLogger("prometheus")
assert len(prom.handlers) == 1
prom.handlers.clear()
def test_debug_flag_enables_debug(self) -> None:
"""Simulate what --debug does — configure_logging at DEBUG level."""
configure_logging(level=logging.DEBUG)
logger = get_logger("evolution")
logger.debug("debug message")
prom = logging.getLogger("prometheus")
assert prom.level == logging.DEBUG
prom.handlers.clear()
def test_log_format_invalid_rejected(self) -> None:
"""Invalid log_format should be caught by OptimizationConfig validator."""
from pydantic import ValidationError
from prometheus.application.dto import OptimizationConfig
import pytest
with pytest.raises(ValidationError, match="log_format must be one of"):
OptimizationConfig(
seed_prompt="a",
task_description="b",
log_format="xml",
)
def test_log_format_text_and_json_accepted(self) -> None:
"""Both text and json log_format values should be valid."""
from prometheus.application.dto import OptimizationConfig
for fmt in ("text", "json"):
config = OptimizationConfig(
seed_prompt="a", task_description="b", log_format=fmt,
)
assert config.log_format == fmt

View File

@@ -0,0 +1,96 @@
"""Additional unit tests for scoring edge cases."""
from __future__ import annotations
import pytest
from prometheus.domain.entities import EvalResult, Trajectory
from prometheus.domain.scoring import normalize_score, should_accept
def _make_eval(scores: list[float]) -> EvalResult:
return EvalResult(
scores=scores,
feedbacks=[""] * len(scores),
trajectories=[
Trajectory(f"in{i}", f"out{i}", s, "", "p")
for i, s in enumerate(scores)
],
)
class TestShouldAcceptEdgeCases:
"""Extended edge-case tests for should_accept."""
def test_tiny_improvement_accepted(self) -> None:
old = _make_eval([0.5])
new = _make_eval([0.5001])
assert should_accept(old, new) is True
def test_tiny_improvement_below_threshold(self) -> None:
old = _make_eval([0.5])
new = _make_eval([0.5001])
assert should_accept(old, new, min_improvement=0.01) is False
def test_zero_scores_equal(self) -> None:
old = _make_eval([0.0, 0.0])
new = _make_eval([0.0, 0.0])
assert should_accept(old, new) is False
def test_negative_to_zero_not_accepted(self) -> None:
"""Scores should be [0,1] but test should_accept with edge values."""
old = _make_eval([-0.1])
new = _make_eval([0.0])
assert should_accept(old, new) is True
def test_large_improvement(self) -> None:
old = _make_eval([0.0, 0.0, 0.0])
new = _make_eval([1.0, 1.0, 1.0])
assert should_accept(old, new) is True
def test_single_score_improvement(self) -> None:
old = _make_eval([0.4])
new = _make_eval([0.5])
assert should_accept(old, new) is True
def test_min_improvement_exactly_met(self) -> None:
"""When improvement exactly equals min_improvement, still rejected (strict >)."""
old = _make_eval([0.5])
new = _make_eval([0.7])
assert should_accept(old, new, min_improvement=0.2) is False
def test_min_improvement_just_over(self) -> None:
old = _make_eval([0.5])
new = _make_eval([0.7001])
assert should_accept(old, new, min_improvement=0.2) is True
class TestNormalizeScoreEdgeCases:
"""Extended edge-case tests for normalize_score."""
def test_exact_bounds(self) -> None:
assert normalize_score(0.0) == 0.0
assert normalize_score(1.0) == 1.0
def test_very_large_value(self) -> None:
assert normalize_score(1e10) == 1.0
def test_very_negative_value(self) -> None:
assert normalize_score(-1e10) == 0.0
def test_custom_bounds_at_edges(self) -> None:
assert normalize_score(5.0, min_val=0.0, max_val=10.0) == 5.0
assert normalize_score(0.0, min_val=0.0, max_val=10.0) == 0.0
assert normalize_score(10.0, min_val=0.0, max_val=10.0) == 10.0
def test_negative_custom_range(self) -> None:
assert normalize_score(0.0, min_val=-5.0, max_val=5.0) == 0.0
assert normalize_score(-3.0, min_val=-5.0, max_val=5.0) == -3.0
assert normalize_score(-10.0, min_val=-5.0, max_val=5.0) == -5.0
def test_zero_span_range(self) -> None:
"""When min == max, clamps to min."""
assert normalize_score(5.0, min_val=5.0, max_val=5.0) == 5.0
assert normalize_score(0.0, min_val=5.0, max_val=5.0) == 5.0
def test_fractional_score(self) -> None:
assert normalize_score(0.3333) == pytest.approx(0.3333)

View File

@@ -0,0 +1,133 @@
"""Tests for similarity adapters — exact, BLEU, ROUGE-L, cosine."""
from __future__ import annotations
import pytest
from prometheus.infrastructure.similarity import (
BleuSimilarity,
CosineSimilarity,
ExactMatchSimilarity,
RougeLSimilarity,
create_similarity_adapter,
)
class TestExactMatchSimilarity:
def test_exact_match(self):
s = ExactMatchSimilarity()
assert s.compute("Hello World", "Hello World") == 1.0
def test_case_insensitive(self):
s = ExactMatchSimilarity()
assert s.compute("hello world", "HELLO WORLD") == 1.0
def test_whitespace_trimmed(self):
s = ExactMatchSimilarity()
assert s.compute(" hello ", "hello") == 1.0
def test_no_match(self):
s = ExactMatchSimilarity()
assert s.compute("hello", "world") == 0.0
def test_partial_no_match(self):
s = ExactMatchSimilarity()
assert s.compute("hello world", "hello") == 0.0
class TestBleuSimilarity:
def test_perfect_match(self):
s = BleuSimilarity()
assert s.compute("the cat sat on the mat", "the cat sat on the mat") == 1.0
def test_no_overlap(self):
s = BleuSimilarity()
assert s.compute("aaa bbb ccc", "ddd eee fff") == 0.0
def test_partial_overlap(self):
s = BleuSimilarity()
score = s.compute("the cat sat", "the cat")
assert 0.0 < score < 1.0
def test_empty_prediction(self):
s = BleuSimilarity()
assert s.compute("", "hello world") == 0.0
def test_empty_expected(self):
s = BleuSimilarity()
assert s.compute("hello world", "") == 0.0
def test_both_empty(self):
s = BleuSimilarity()
assert s.compute("", "") == 0.0
def test_shorter_prediction_gets_brevity_penalty(self):
s = BleuSimilarity()
short = s.compute("cat", "the cat sat on the mat")
full = s.compute("the cat sat on the mat", "the cat sat on the mat")
assert short < full
class TestRougeLSimilarity:
def test_perfect_match(self):
s = RougeLSimilarity()
assert s.compute("the cat sat", "the cat sat") == 1.0
def test_no_overlap(self):
s = RougeLSimilarity()
assert s.compute("aaa bbb", "ccc ddd") == 0.0
def test_partial_overlap(self):
s = RougeLSimilarity()
score = s.compute("the cat sat on the mat", "the cat on the rug")
assert 0.0 < score < 1.0
def test_empty_prediction(self):
s = RougeLSimilarity()
assert s.compute("", "hello") == 0.0
def test_subsequence(self):
s = RougeLSimilarity()
# "cat mat" is a subsequence of "the cat sat on the mat"
score = s.compute("cat mat", "the cat sat on the mat")
assert score > 0.0
class TestCosineSimilarity:
def test_identical_texts(self):
s = CosineSimilarity()
assert s.compute("hello world", "hello world") == pytest.approx(1.0)
def test_no_overlap(self):
s = CosineSimilarity()
assert s.compute("aaa bbb", "ccc ddd") == 0.0
def test_partial_overlap(self):
s = CosineSimilarity()
score = s.compute("hello world foo", "hello world bar")
assert 0.0 < score < 1.0
def test_empty_prediction(self):
s = CosineSimilarity()
assert s.compute("", "hello") == 0.0
class TestCreateSimilarityAdapter:
def test_create_exact(self):
adapter = create_similarity_adapter("exact")
assert isinstance(adapter, ExactMatchSimilarity)
def test_create_bleu(self):
adapter = create_similarity_adapter("bleu")
assert isinstance(adapter, BleuSimilarity)
def test_create_rouge_l(self):
adapter = create_similarity_adapter("rouge_l")
assert isinstance(adapter, RougeLSimilarity)
def test_create_cosine(self):
adapter = create_similarity_adapter("cosine")
assert isinstance(adapter, CosineSimilarity)
def test_unknown_metric_raises(self):
with pytest.raises(ValueError, match="Unknown eval metric"):
create_similarity_adapter("nonexistent")

View File

@@ -0,0 +1,233 @@
"""Unit tests for OptimizePromptUseCase — direct orchestration tests."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.dto import OptimizationConfig, OptimizationResult
from prometheus.application.evaluator import PromptEvaluator
from prometheus.application.evolution import EvolutionLoop
from prometheus.application.use_cases import OptimizePromptUseCase
from prometheus.domain.entities import (
Candidate,
EvalResult,
OptimizationState,
Prompt,
SyntheticExample,
Trajectory,
)
def _make_eval(scores: list[float]) -> EvalResult:
return EvalResult(
scores=scores,
feedbacks=["feedback"] * len(scores),
trajectories=[
Trajectory(f"in{i}", f"out{i}", s, "feedback", "prompt")
for i, s in enumerate(scores)
],
)
def _make_state(
iterations: int = 3,
initial_score: float = 0.3,
final_score: float = 0.8,
accepted: bool = True,
) -> OptimizationState:
seed = Candidate(prompt=Prompt(text="seed"), best_score=initial_score, generation=0)
best = Candidate(
prompt=Prompt(text="optimized" if accepted else "seed"),
best_score=final_score,
generation=iterations if accepted else 0,
)
history = []
for i in range(1, iterations + 1):
event = "accepted" if accepted else "rejected"
history.append({"iteration": i, "event": event, "old_score": 0.3, "new_score": 0.8})
return OptimizationState(
iteration=iterations,
best_candidate=best,
candidates=[seed, best] if accepted else [seed],
total_llm_calls=iterations * 11 + 10,
history=history,
)
class TestOptimizePromptUseCaseExecute:
"""Tests for the execute() orchestration method."""
@pytest.fixture
def mock_evaluator(self) -> MagicMock:
return MagicMock(spec=PromptEvaluator)
@pytest.fixture
def mock_proposer(self) -> MagicMock:
return MagicMock()
@pytest.fixture
def mock_bootstrap(self) -> MagicMock:
return MagicMock(spec=SyntheticBootstrap)
@pytest.fixture
def use_case(
self,
mock_evaluator: MagicMock,
mock_proposer: MagicMock,
mock_bootstrap: MagicMock,
) -> OptimizePromptUseCase:
return OptimizePromptUseCase(
evaluator=mock_evaluator,
proposer=mock_proposer,
bootstrap=mock_bootstrap,
)
@pytest.fixture
def config(self) -> OptimizationConfig:
return OptimizationConfig(
seed_prompt="Answer the question.",
task_description="Q&A task",
max_iterations=5,
n_synthetic_inputs=20,
minibatch_size=5,
seed=42,
)
@pytest.mark.asyncio
async def test_returns_optimization_result(
self,
use_case: OptimizePromptUseCase,
mock_bootstrap: MagicMock,
config: OptimizationConfig,
) -> None:
mock_bootstrap.run.return_value = [
SyntheticExample(input_text=f"q{i}", id=i) for i in range(20)
]
mock_state = _make_state(iterations=3, initial_score=0.3, final_score=0.9)
with patch.object(EvolutionLoop, "run", return_value=mock_state):
result = await use_case.execute(config)
assert isinstance(result, OptimizationResult)
assert result.initial_prompt == "Answer the question."
assert result.final_score == 0.9
assert result.improvement == pytest.approx(0.6)
@pytest.mark.asyncio
async def test_bootstrap_called_with_config_params(
self,
use_case: OptimizePromptUseCase,
mock_bootstrap: MagicMock,
config: OptimizationConfig,
) -> None:
mock_bootstrap.run.return_value = []
mock_state = _make_state()
with patch.object(EvolutionLoop, "run", return_value=mock_state):
await use_case.execute(config)
mock_bootstrap.run.assert_called_once_with(
task_description="Q&A task",
n_examples=20,
)
@pytest.mark.asyncio
async def test_evolution_loop_configured_from_config(
self,
use_case: OptimizePromptUseCase,
mock_bootstrap: MagicMock,
config: OptimizationConfig,
) -> None:
mock_bootstrap.run.return_value = []
mock_state = _make_state()
with patch.object(EvolutionLoop, "run", return_value=mock_state) as mock_run:
await use_case.execute(config)
# Verify the loop was instantiated with correct params
mock_run.assert_called_once()
call_args = mock_run.call_args
seed_prompt = call_args[0][0]
assert seed_prompt.text == "Answer the question."
synthetic_pool = call_args[0][1]
assert len(synthetic_pool) == 0 # bootstrap returned empty
assert call_args[0][2] == "Q&A task"
@pytest.mark.asyncio
async def test_total_llm_calls_includes_bootstrap_call(
self,
use_case: OptimizePromptUseCase,
mock_bootstrap: MagicMock,
config: OptimizationConfig,
) -> None:
mock_bootstrap.run.return_value = []
mock_state = _make_state(iterations=3)
# total_llm_calls from state + 1 for bootstrap
expected = mock_state.total_llm_calls + 1
with patch.object(EvolutionLoop, "run", return_value=mock_state):
result = await use_case.execute(config)
assert result.total_llm_calls == expected
@pytest.mark.asyncio
async def test_no_candidates_fallback(
self,
use_case: OptimizePromptUseCase,
mock_bootstrap: MagicMock,
config: OptimizationConfig,
) -> None:
mock_bootstrap.run.return_value = [
SyntheticExample(input_text=f"q{i}", id=i) for i in range(20)
]
mock_state = OptimizationState(
iteration=0,
best_candidate=None,
candidates=[],
total_llm_calls=0,
)
with patch.object(EvolutionLoop, "run", return_value=mock_state):
result = await use_case.execute(config)
assert result.optimized_prompt == "Answer the question."
assert result.initial_score == 0.0
assert result.final_score == 0.0
assert result.improvement == 0.0
@pytest.mark.asyncio
async def test_iterations_used_matches_state(
self,
use_case: OptimizePromptUseCase,
mock_bootstrap: MagicMock,
config: OptimizationConfig,
) -> None:
mock_bootstrap.run.return_value = []
mock_state = _make_state(iterations=7)
with patch.object(EvolutionLoop, "run", return_value=mock_state):
result = await use_case.execute(config)
assert result.iterations_used == 7
@pytest.mark.asyncio
async def test_history_passed_through(
self,
use_case: OptimizePromptUseCase,
mock_bootstrap: MagicMock,
config: OptimizationConfig,
) -> None:
mock_bootstrap.run.return_value = []
history = [
{"iteration": 1, "event": "accepted"},
{"iteration": 2, "event": "rejected"},
]
mock_state = _make_state()
mock_state.history = history
with patch.object(EvolutionLoop, "run", return_value=mock_state):
result = await use_case.execute(config)
assert result.history == history