feat: v0.2.0 sprint — ground truth eval, crossover/mutation, checkpointing, similarity guards, dataset loader, CLI commands, extended test coverage
Aggregates all v0.2.0 sprint work (GARAA-30 through GARAA-40) and fixes 2 integration tests that broke when the codebase went async (DSPyLLMAdapter and full pipeline tests now properly await coroutines). 277 tests pass (260 unit + 17 integration). Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
369
docs/FEATURE_ROADMAP.md
Normal file
369
docs/FEATURE_ROADMAP.md
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
# PROMETHEUS Feature Roadmap
|
||||||
|
|
||||||
|
> Complete codebase review — features needed for production-grade prompt optimization.
|
||||||
|
> Generated from v0.1.0 architecture review (2026-03-29).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Legend
|
||||||
|
|
||||||
|
| Marker | Meaning |
|
||||||
|
|--------|---------|
|
||||||
|
| **CLI** | Exposed as a CLI option/flag |
|
||||||
|
| **Config** | YAML config field |
|
||||||
|
| **Internal** | No user-facing surface, architectural improvement |
|
||||||
|
| **P1** | Critical / must-have for reliability |
|
||||||
|
| **P2** | High value, should-have |
|
||||||
|
| **P3** | Nice-to-have, deferred to later versions |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Multi-Model Routing (P1)
|
||||||
|
|
||||||
|
**Current state:** `OptimizationConfig` defines four model slots (`task_model`, `judge_model`, `proposer_model`, `synth_model`), but `cli/app.py` only configures a single global DSPy LM from `task_model`. All adapters silently use the same model regardless of config.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Each adapter (`DSPyLLMAdapter`, `DSPyJudgeAdapter`, `DSPyProposerAdapter`, `DSPySyntheticAdapter`) must instantiate its own `dspy.LM` from the corresponding config field.
|
||||||
|
- Support per-model `api_base` and `api_key_env` overrides (e.g., judge on GPT-4o, propose on a cheaper model).
|
||||||
|
|
||||||
|
**Surface:** Config (already partially defined) — `judge_model`, `proposer_model`, `synth_model` become functional. No new CLI flags needed; the YAML already has the fields.
|
||||||
|
|
||||||
|
**Scope:** Infrastructure layer (`llm_adapter.py`, `judge_adapter.py`, `proposer_adapter.py`, `synth_adapter.py`) + `cli/app.py` DI wiring.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Async / Parallel Execution (P1)
|
||||||
|
|
||||||
|
**Current state:** All LLM calls (execute, judge, propose) are sequential. A single iteration with `minibatch_size=5` makes ~11 sequential LLM calls. Wall-clock time scales linearly with minibatch size.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Parallelize execution of the prompt across a minibatch (`asyncio.gather` or `dspy.Parallel`).
|
||||||
|
- Parallelize judge calls within a batch.
|
||||||
|
- Keep the proposer sequential (single call per iteration).
|
||||||
|
|
||||||
|
**Surface:** Internal. Optionally exposed via `--max-concurrency` CLI flag and `max_concurrency` YAML field.
|
||||||
|
|
||||||
|
**Scope:** `evaluator.py`, `judge_adapter.py`, `llm_adapter.py`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Robust Error Handling & Retry (P1)
|
||||||
|
|
||||||
|
**Current state:** The evolution loop catches broad `Exception` per iteration and logs it, then continues. Individual LLM call failures (timeouts, rate limits, malformed responses) are not retried. DSPy module fallbacks only cover parsing, not network errors.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Retry with exponential backoff for transient errors (rate limits, timeouts, 5xx).
|
||||||
|
- Configurable `max_retries` and `retry_delay_base`.
|
||||||
|
- Circuit breaker: if N consecutive iterations fail, pause and alert.
|
||||||
|
- Per-call error isolation: one bad minibatch item shouldn't fail the whole evaluation.
|
||||||
|
|
||||||
|
**Surface:** `--max-retries` CLI flag, `max_retries` Config field. `--error-strategy` (skip | retry | abort) CLI flag.
|
||||||
|
|
||||||
|
**Scope:** Infrastructure adapters + evolution loop.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Checkpoint & Resume (P2)
|
||||||
|
|
||||||
|
**Current state:** If a long optimization run crashes or is interrupted, all progress is lost. There is no intermediate state persistence.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Save `OptimizationState` to disk every K iterations (or every accepted improvement).
|
||||||
|
- Resume from the latest checkpoint file on restart.
|
||||||
|
- Checkpoint includes: current best candidate, all candidates, iteration number, LLM call count, RNG seed state.
|
||||||
|
|
||||||
|
**Surface:** `--checkpoint-dir` CLI flag (default: `.prometheus/checkpoints/`). `--resume` CLI flag to resume from latest checkpoint. `checkpoint_interval` Config field.
|
||||||
|
|
||||||
|
**Scope:** New `CheckpointPort` in domain, `JsonCheckpointPersistence` in infrastructure, modifications to `EvolutionLoop.run()`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Population-Based Evolution (P2)
|
||||||
|
|
||||||
|
**Current state:** The evolution loop keeps only a single best candidate (hill climbing). No diversity, no crossover, no population dynamics. The `Candidate` entity has `generation` and `parent_id` fields that suggest population support was planned.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Maintain a population of K candidates (e.g., top-K by score or Pareto front).
|
||||||
|
- Crossover: combine instructions from two parent candidates.
|
||||||
|
- Mutation operators: paraphrase, constrain, generalize, specialize.
|
||||||
|
- Diversity maintenance: penalize candidates too similar to existing ones (cosine similarity or edit distance).
|
||||||
|
|
||||||
|
**Surface:** `--population-size` CLI flag, `population_size` Config field. `--crossover-rate`, `--mutation-rate` CLI flags.
|
||||||
|
|
||||||
|
**Scope:** `EvolutionLoop` refactor, new `CrossoverPort` and `MutationPort` in domain, new DSPy signatures for crossover/mutation in infrastructure.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Hold-Out Validation (P2)
|
||||||
|
|
||||||
|
**Current state:** The same synthetic inputs are used for both optimization and evaluation. No train/test split. Risk of overfitting to synthetic inputs.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Split synthetic pool into train (e.g., 70%) and validation (30%) sets.
|
||||||
|
- Evolution uses train minibatches for accept/reject decisions.
|
||||||
|
- After each iteration, evaluate the best candidate on the hold-out set.
|
||||||
|
- Report both train and validation scores in results.
|
||||||
|
- Optional early stopping if validation score degrades for K consecutive iterations.
|
||||||
|
|
||||||
|
**Surface:** `--validation-split` CLI flag (default: 0.3). `--early-stop-patience` CLI flag (default: 5). Config fields: `validation_split`, `early_stop_patience`.
|
||||||
|
|
||||||
|
**Scope:** `SyntheticBootstrap`, `EvolutionLoop`, `OptimizationResult` (add validation metrics).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Custom Judge Criteria (P2)
|
||||||
|
|
||||||
|
**Current state:** The judge uses a hardcoded rubric in `JudgeOutput` DSPy signature ("score 0.0-1.0" with generic quality assessment). Users cannot customize evaluation criteria.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Allow users to define custom judge rubrics, criteria, and scoring scales.
|
||||||
|
- Support multi-dimensional scoring (e.g., accuracy: 0-10, clarity: 0-10, safety: 0-10) with configurable weights.
|
||||||
|
- Allow `perfect_score` to reflect the custom scale.
|
||||||
|
|
||||||
|
**Surface:** `judge_criteria` YAML field (free text). `judge_dimensions` YAML field (list of `{name, weight, description}`). CLI: `--judge-criteria` for quick overrides.
|
||||||
|
|
||||||
|
**Scope:** `JudgeOutput` signature (dynamic instructions), `JudgePort`, `DSPyJudgeAdapter`, `scoring.py` (weighted aggregation).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Real-World Evaluation Harness (P2)
|
||||||
|
|
||||||
|
**Current state:** The system only evaluates against synthetic inputs. There is no way to test optimized prompts against real inputs with known-good outputs.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Accept an optional evaluation dataset (CSV/JSON with `input` and `expected_output` columns).
|
||||||
|
- When provided, use exact/semantic similarity matching against expected outputs instead of (or in addition to) LLM-as-Judge.
|
||||||
|
- Report metrics: accuracy, BLEU, ROUGE, or embedding cosine similarity vs expected.
|
||||||
|
|
||||||
|
**Surface:** `--eval-dataset` CLI flag. `eval_dataset_path` Config field. `--eval-metric` CLI flag (exact | semantic | llm_judge).
|
||||||
|
|
||||||
|
**Scope:** New `GroundTruthEvaluator` in application, new `SimilarityPort` in domain, dataset loader in infrastructure.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Logging & Observability (P2)
|
||||||
|
|
||||||
|
**Current state:** Verbose mode (`-v`) configures Python's `logging` module but no handler is attached (Bug #4 in TEST_REPORT.md). No structured logging, no tracing.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Proper structured logging with configurable levels (DEBUG, INFO, WARNING, ERROR).
|
||||||
|
- JSON-formatted log output for machine parsing.
|
||||||
|
- Per-iteration trace: minibatch sample IDs, execution outputs, judge scores, proposer prompt diff.
|
||||||
|
- Optional OpenTelemetry export for distributed tracing.
|
||||||
|
|
||||||
|
**Surface:** `-v` / `--verbose` enables INFO level. `--debug` enables DEBUG level. `--log-format` (text | json). `--log-file` for file output. Config fields: `log_level`, `log_format`, `log_file`.
|
||||||
|
|
||||||
|
**Scope:** `cli/app.py` (logging setup), `evolution.py` (structured traces), new `TracingPort` in domain.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. CLI Improvements (P2)
|
||||||
|
|
||||||
|
**Current state:** Single `optimize` command. Known Typer 0.24 bug absorbing subcommands (Bug #1). No `version`, `init`, or `list-results` commands.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Fix Typer subcommand routing.
|
||||||
|
- `prometheus version` — show version.
|
||||||
|
- `prometheus init` — scaffold a config YAML interactively.
|
||||||
|
- `prometheus list` — list past optimization runs.
|
||||||
|
- `prometheus diff` — compare two result files (before/after prompt diff, score improvement).
|
||||||
|
- `prometheus eval` — evaluate a prompt against a dataset without optimization.
|
||||||
|
|
||||||
|
**Surface:** CLI subcommands.
|
||||||
|
|
||||||
|
**Scope:** `cli/app.py` restructured into `cli/commands/` with one module per command.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 11. Input Validation & Schema Enforcement (P2)
|
||||||
|
|
||||||
|
**Current state:** Config YAML is parsed as a raw dict with no schema validation. Missing or wrong-type fields cause cryptic errors deep in the pipeline.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Validate input YAML against a Pydantic schema (leveraging the existing `pydantic` dependency).
|
||||||
|
- Provide clear, actionable error messages for missing/invalid fields.
|
||||||
|
- Support config migration/upgrade from older versions.
|
||||||
|
|
||||||
|
**Surface:** Internal. Errors surface as clear CLI messages.
|
||||||
|
|
||||||
|
**Scope:** `OptimizationConfig` converted to Pydantic model with validators, `cli/app.py` validation step before pipeline execution.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 12. Adaptive Minibatch Sizing (P3)
|
||||||
|
|
||||||
|
**Current state:** Minibatch size is static throughout the run. Small batches are noisy; large batches are expensive.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Start with a small minibatch for quick early iterations.
|
||||||
|
- Increase minibatch size as the prompt improves (higher confidence needed for marginal gains).
|
||||||
|
- Shrink if too many evaluations fail (cost optimization).
|
||||||
|
|
||||||
|
**Surface:** `--adaptive-minibatch` CLI flag (boolean toggle). `minibatch_size` becomes `minibatch_size_min` and `minibatch_size_max` in config.
|
||||||
|
|
||||||
|
**Scope:** `EvolutionLoop`, `SyntheticBootstrap`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 13. Prompt Diversity Tracking (P3)
|
||||||
|
|
||||||
|
**Current state:** No visibility into how much the prompt is actually changing between iterations. A "successful" optimization might just rephrase without structural change.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Compute edit distance (Levenshtein) or embedding cosine similarity between consecutive prompts.
|
||||||
|
- Report diversity metrics in the result.
|
||||||
|
- Flag stagnation (N iterations with <epsilon change).
|
||||||
|
|
||||||
|
**Surface:** Internal. Reported in `OptimizationResult.history` entries.
|
||||||
|
|
||||||
|
**Scope:** `EvolutionLoop`, `OptimizationResult` (add diversity field per history entry).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 14. Temperature & Sampling Control (P3)
|
||||||
|
|
||||||
|
**Current state:** No way to control LLM temperature, top_p, or other sampling parameters for any of the four model slots. DSPy defaults apply.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Per-model-slot temperature and sampling parameters.
|
||||||
|
- Higher temperature for proposer (creativity), lower for judge (consistency).
|
||||||
|
|
||||||
|
**Surface:** `task_temperature`, `judge_temperature`, `proposer_temperature`, `synth_temperature` Config fields. `--temperature` CLI flag for global override.
|
||||||
|
|
||||||
|
**Scope:** `cli/app.py` (DSPy LM configuration), infrastructure adapters.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 15. Cost Estimation & Budget Caps (P3)
|
||||||
|
|
||||||
|
**Current state:** `total_llm_calls` is tracked (inaccurately). No cost estimation, no budget caps.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Estimate cost per run based on model pricing and approximate token counts.
|
||||||
|
- Allow users to set a budget cap (`--max-cost-usd`).
|
||||||
|
- Report estimated cost in the result.
|
||||||
|
|
||||||
|
**Surface:** `--max-cost-usd` CLI flag. `max_cost_usd` Config field. Cost breakdown in result output.
|
||||||
|
|
||||||
|
**Scope:** `cli/app.py`, `OptimizationResult` (add cost fields), token counting in adapters.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 16. Multi-Objective Optimization (P3)
|
||||||
|
|
||||||
|
**Current state:** Single scalar score from the judge. The `Prompt` entity comment mentions "Pareto tracking" but it's not implemented.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Optimize for multiple objectives simultaneously (quality, latency, token efficiency, safety).
|
||||||
|
- Maintain a Pareto front of non-dominated candidates.
|
||||||
|
- Allow users to set objective weights or constraints.
|
||||||
|
|
||||||
|
**Surface:** `objectives` Config field (list of `{name, weight, judge_criteria}`). CLI: `--objective` repeatable flag.
|
||||||
|
|
||||||
|
**Scope:** `EvolutionLoop` (Pareto front), `scoring.py` (multi-objective acceptance), `OptimizationResult` (Pareto set).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 17. Export Optimized Prompt (P3)
|
||||||
|
|
||||||
|
**Current state:** The optimized prompt is embedded in the YAML result file. No easy way to extract it for use.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- `prometheus export` command to extract the optimized prompt as plain text.
|
||||||
|
- Support multiple export formats: plain text, Markdown, JSON, LangChain template, DSPy module.
|
||||||
|
- Copy to clipboard option.
|
||||||
|
|
||||||
|
**Surface:** `prometheus export --format <txt|md|json|langchain|dspy>` CLI subcommand. `--clipboard` flag.
|
||||||
|
|
||||||
|
**Scope:** New `cli/commands/export.py`, format renderers in infrastructure.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 18. Config Profiles / Presets (P3)
|
||||||
|
|
||||||
|
**Current state:** Every run requires a full config YAML. Common patterns (fast iterate, thorough optimize, cheap run) are not captured.
|
||||||
|
|
||||||
|
**Feature:**
|
||||||
|
- Named profiles: `fast`, `thorough`, `economy`, `research`.
|
||||||
|
- Profile overrides individual config fields.
|
||||||
|
- User-defined profiles stored in `~/.prometheus/profiles/`.
|
||||||
|
|
||||||
|
**Surface:** `--profile` CLI flag. `prometheus profile list` / `prometheus profile create` subcommands.
|
||||||
|
|
||||||
|
**Scope:** `cli/app.py`, new `ProfileManager` in application.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Summary Table
|
||||||
|
|
||||||
|
| # | Feature | Priority | CLI Surface | Config Surface | Estimated Scope |
|
||||||
|
|---|---------|----------|-------------|----------------|-----------------|
|
||||||
|
| 1 | Multi-Model Routing | P1 | Existing | Existing | Small |
|
||||||
|
| 2 | Async / Parallel Execution | P1 | `--max-concurrency` | `max_concurrency` | Medium |
|
||||||
|
| 3 | Error Handling & Retry | P1 | `--max-retries`, `--error-strategy` | `max_retries`, `error_strategy` | Medium |
|
||||||
|
| 4 | Checkpoint & Resume | P2 | `--checkpoint-dir`, `--resume` | `checkpoint_interval` | Medium |
|
||||||
|
| 5 | Population-Based Evolution | P2 | `--population-size`, `--crossover-rate` | `population_size`, `crossover_rate` | Large |
|
||||||
|
| 6 | Hold-Out Validation | P2 | `--validation-split`, `--early-stop-patience` | `validation_split`, `early_stop_patience` | Medium |
|
||||||
|
| 7 | Custom Judge Criteria | P2 | `--judge-criteria` | `judge_criteria`, `judge_dimensions` | Medium |
|
||||||
|
| 8 | Real-World Eval Harness | P2 | `--eval-dataset`, `--eval-metric` | `eval_dataset_path` | Large |
|
||||||
|
| 9 | Logging & Observability | P2 | `--debug`, `--log-format`, `--log-file` | `log_level`, `log_format` | Medium |
|
||||||
|
| 10 | CLI Improvements | P2 | Subcommands | — | Medium |
|
||||||
|
| 11 | Input Validation | P2 | — (error messages) | — | Small |
|
||||||
|
| 12 | Adaptive Minibatch | P3 | `--adaptive-minibatch` | `minibatch_size_min/max` | Small |
|
||||||
|
| 13 | Prompt Diversity Tracking | P3 | — | — | Small |
|
||||||
|
| 14 | Temperature & Sampling | P3 | `--temperature` | `*_temperature` | Small |
|
||||||
|
| 15 | Cost Estimation | P3 | `--max-cost-usd` | `max_cost_usd` | Small |
|
||||||
|
| 16 | Multi-Objective Optimization | P3 | `--objective` | `objectives` | Large |
|
||||||
|
| 17 | Export Optimized Prompt | P3 | `prometheus export` | — | Small |
|
||||||
|
| 18 | Config Profiles / Presets | P3 | `--profile` | — | Small |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Known Bugs (from TEST_REPORT.md and code review)
|
||||||
|
|
||||||
|
| # | Bug | Severity | File |
|
||||||
|
|---|-----|----------|------|
|
||||||
|
| 1 | Multi-model config not wired — all adapters use single global LM | HIGH | `cli/app.py`, all adapters |
|
||||||
|
| 2 | `DSPyLLMAdapter` accepts `model` param but never uses it | HIGH | `infrastructure/llm_adapter.py` |
|
||||||
|
| 3 | CLI subcommand `optimize` absorbed by Typer 0.24 | HIGH | `cli/app.py` |
|
||||||
|
| 4 | Verbose logging produces no output — no handler configured | MEDIUM | `cli/app.py` |
|
||||||
|
| 5 | `total_llm_calls` counter is inaccurate | LOW | `application/use_cases.py`, `evolution.py` |
|
||||||
|
| 6 | `normalize_score()` is dead code — never called | LOW | `domain/scoring.py` |
|
||||||
|
| 7 | `AppSettings` is never imported or used | LOW | `config.py` |
|
||||||
|
| 8 | No LLM error handling in evolution loop | MEDIUM | `evolution.py` |
|
||||||
|
| 9 | Unpinned dependencies (dspy, typer) | LOW | `pyproject.toml` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Test Coverage Gaps
|
||||||
|
|
||||||
|
| Area | Current | Needed |
|
||||||
|
|------|---------|--------|
|
||||||
|
| CLI commands | 0 tests | Unit + integration for each subcommand |
|
||||||
|
| Config validation | 0 tests | Schema validation, missing fields, type errors |
|
||||||
|
| Evolution loop | 3 tests (single iteration each) | Multi-iteration, mixed accept/reject, failure recovery |
|
||||||
|
| Integration pipeline | 1 test (happy path only) | Error paths, mixed results, real adapters |
|
||||||
|
| Adapter coverage | 1 adapter tested | All 4 adapters + error scenarios |
|
||||||
|
| Use case orchestration | 1 indirect test | Direct unit tests for `OptimizePromptUseCase` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Recommended Implementation Order
|
||||||
|
|
||||||
|
### Phase 1 — Production Reliability (P1)
|
||||||
|
1. Fix multi-model routing (#1) — highest impact, smallest scope
|
||||||
|
2. Add error handling & retry (#3) — essential for production runs
|
||||||
|
3. Implement async/parallel execution (#2) — biggest wall-clock improvement
|
||||||
|
|
||||||
|
### Phase 2 — Optimization Quality (P2)
|
||||||
|
4. Input validation (#11) — small scope, high reliability gain
|
||||||
|
5. Logging & observability (#9) — enables debugging long runs
|
||||||
|
6. CLI improvements (#10) — fix Typer bug, add basic commands
|
||||||
|
7. Hold-out validation (#6) — prevents overfitting
|
||||||
|
8. Checkpoint & resume (#4) — essential for long runs
|
||||||
|
9. Custom judge criteria (#7) — enables domain-specific optimization
|
||||||
|
|
||||||
|
### Phase 3 — Advanced Features (P3)
|
||||||
|
10. Population-based evolution (#5)
|
||||||
|
11. Real-world eval harness (#8)
|
||||||
|
12. Remaining P3 features as demand dictates
|
||||||
@@ -5,12 +5,12 @@ description = "Prompt evolution without reference data"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"dspy>=2.6,<3.0",
|
"dspy==2.6.27",
|
||||||
"typer>=0.15,<0.20",
|
"typer==0.19.2",
|
||||||
"pydantic>=2.10",
|
"pydantic==2.12.5",
|
||||||
"pydantic-settings>=2.7",
|
"pydantic-settings==2.13.1",
|
||||||
"pyyaml>=6.0",
|
"pyyaml==6.0.3",
|
||||||
"rich>=13.9",
|
"rich==14.3.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -46,6 +46,6 @@ module = ["dspy", "dspy.*"]
|
|||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = ["prometheus.infrastructure.*", "prometheus.cli.app"]
|
module = ["prometheus.infrastructure.*", "prometheus.cli.app", "prometheus.cli.commands.*"]
|
||||||
disable_error_code = ["misc", "import-untyped"]
|
disable_error_code = ["misc", "import-untyped"]
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,24 @@ class SyntheticBootstrap:
|
|||||||
self._generator = generator
|
self._generator = generator
|
||||||
self._rng = random.Random(seed)
|
self._rng = random.Random(seed)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def split_pool(
|
||||||
|
pool: list[SyntheticExample],
|
||||||
|
validation_fraction: float,
|
||||||
|
rng: random.Random | None = None,
|
||||||
|
) -> tuple[list[SyntheticExample], list[SyntheticExample]]:
|
||||||
|
"""Split *pool* into (train, validation) sets.
|
||||||
|
|
||||||
|
Returns (pool, []) when *validation_fraction* is 0.
|
||||||
|
"""
|
||||||
|
if validation_fraction <= 0.0 or len(pool) < 2:
|
||||||
|
return pool, []
|
||||||
|
n_val = max(1, int(len(pool) * validation_fraction))
|
||||||
|
shuffled = list(pool)
|
||||||
|
_rng = rng or random.Random(42)
|
||||||
|
_rng.shuffle(shuffled)
|
||||||
|
return shuffled[:-n_val], shuffled[-n_val:]
|
||||||
|
|
||||||
def run(self, task_description: str, n_examples: int) -> list[SyntheticExample]:
|
def run(self, task_description: str, n_examples: int) -> list[SyntheticExample]:
|
||||||
"""Generate the synthetic pool in a single call.
|
"""Generate the synthetic pool in a single call.
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from prometheus.domain.entities import (
|
|||||||
Trajectory,
|
Trajectory,
|
||||||
)
|
)
|
||||||
from prometheus.domain.ports import JudgePort, LLMPort
|
from prometheus.domain.ports import JudgePort, LLMPort
|
||||||
|
from prometheus.domain.scoring import normalize_score
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -72,6 +73,7 @@ class PromptEvaluator:
|
|||||||
trajectories: list[Trajectory] = []
|
trajectories: list[Trajectory] = []
|
||||||
for i, (example, output) in enumerate(zip(minibatch, outputs)):
|
for i, (example, output) in enumerate(zip(minibatch, outputs)):
|
||||||
score, feedback = judge_results[i]
|
score, feedback = judge_results[i]
|
||||||
|
score = normalize_score(score)
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
feedbacks.append(feedback)
|
feedbacks.append(feedback)
|
||||||
trajectories.append(
|
trajectories.append(
|
||||||
|
|||||||
@@ -2,24 +2,33 @@
|
|||||||
Evolution loop — core PROMETHEUS engine.
|
Evolution loop — core PROMETHEUS engine.
|
||||||
|
|
||||||
Orchestrates the select → evaluate → propose → accept cycle.
|
Orchestrates the select → evaluate → propose → accept cycle.
|
||||||
Equivalent to GEPAEngine.run(), adapted to work without a valset.
|
Supports two modes:
|
||||||
|
- Single-candidate hill climbing (population_size=1, backward compat)
|
||||||
|
- Population-based evolution with crossover & mutation (population_size>1)
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||||
from prometheus.application.evaluator import PromptEvaluator
|
from prometheus.application.evaluator import PromptEvaluator
|
||||||
|
from prometheus.cli.logging_setup import get_logger
|
||||||
from prometheus.domain.entities import (
|
from prometheus.domain.entities import (
|
||||||
Candidate,
|
Candidate,
|
||||||
OptimizationState,
|
OptimizationState,
|
||||||
Prompt,
|
Prompt,
|
||||||
SyntheticExample,
|
SyntheticExample,
|
||||||
)
|
)
|
||||||
from prometheus.domain.ports import ProposerPort
|
from prometheus.domain.ports import (
|
||||||
|
CheckpointPort,
|
||||||
|
CrossoverPort,
|
||||||
|
MutationPort,
|
||||||
|
ProposerPort,
|
||||||
|
)
|
||||||
from prometheus.domain.scoring import should_accept
|
from prometheus.domain.scoring import should_accept
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = get_logger("evolution")
|
||||||
|
|
||||||
|
|
||||||
class CircuitBreakerOpen(Exception):
|
class CircuitBreakerOpen(Exception):
|
||||||
@@ -30,9 +39,9 @@ class EvolutionLoop:
|
|||||||
"""Main evolution loop.
|
"""Main evolution loop.
|
||||||
|
|
||||||
Design:
|
Design:
|
||||||
- Keeps only the best candidate (no full population).
|
- population_size=1: classic single-candidate hill climbing (backward compat).
|
||||||
- Simplifies vs GEPA (no Pareto, no merge).
|
- population_size>1: population-based evolution with crossover, mutation,
|
||||||
- Population support deferred to v2.
|
and diversity maintenance.
|
||||||
|
|
||||||
Error handling:
|
Error handling:
|
||||||
- Transient errors are retried by adapters.
|
- Transient errors are retried by adapters.
|
||||||
@@ -51,6 +60,17 @@ class EvolutionLoop:
|
|||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
circuit_breaker_threshold: int = 5,
|
circuit_breaker_threshold: int = 5,
|
||||||
error_strategy: str = "retry",
|
error_strategy: str = "retry",
|
||||||
|
checkpoint_port: CheckpointPort | None = None,
|
||||||
|
checkpoint_interval: int = 5,
|
||||||
|
# --- Population-based evolution params ---
|
||||||
|
population_size: int = 1,
|
||||||
|
crossover_rate: float = 0.5,
|
||||||
|
mutation_rate: float = 0.3,
|
||||||
|
diversity_penalty: float = 0.1,
|
||||||
|
crossover_port: CrossoverPort | None = None,
|
||||||
|
mutation_port: MutationPort | None = None,
|
||||||
|
# --- Hold-out validation params ---
|
||||||
|
early_stop_patience: int = 5,
|
||||||
):
|
):
|
||||||
self._evaluator = evaluator
|
self._evaluator = evaluator
|
||||||
self._proposer = proposer
|
self._proposer = proposer
|
||||||
@@ -61,18 +81,44 @@ class EvolutionLoop:
|
|||||||
self._verbose = verbose
|
self._verbose = verbose
|
||||||
self._circuit_breaker_threshold = circuit_breaker_threshold
|
self._circuit_breaker_threshold = circuit_breaker_threshold
|
||||||
self._error_strategy = error_strategy
|
self._error_strategy = error_strategy
|
||||||
|
self._checkpoint_port = checkpoint_port
|
||||||
|
self._checkpoint_interval = checkpoint_interval
|
||||||
|
self._population_size = population_size
|
||||||
|
self._crossover_rate = crossover_rate
|
||||||
|
self._mutation_rate = mutation_rate
|
||||||
|
self._diversity_penalty = diversity_penalty
|
||||||
|
self._crossover_port = crossover_port
|
||||||
|
self._mutation_port = mutation_port
|
||||||
|
self._early_stop_patience = early_stop_patience
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
seed_prompt: Prompt,
|
seed_prompt: Prompt,
|
||||||
synthetic_pool: list[SyntheticExample],
|
synthetic_pool: list[SyntheticExample],
|
||||||
task_description: str,
|
task_description: str,
|
||||||
|
initial_state: OptimizationState | None = None,
|
||||||
|
validation_pool: list[SyntheticExample] | None = None,
|
||||||
) -> OptimizationState:
|
) -> OptimizationState:
|
||||||
"""Execute the complete evolution loop."""
|
"""Execute the complete evolution loop.
|
||||||
state = OptimizationState()
|
|
||||||
|
If *initial_state* is provided (from a checkpoint), resume from that
|
||||||
|
point — skipping the seed evaluation and continuing at the saved iteration.
|
||||||
|
|
||||||
|
If *validation_pool* is provided (non-empty), the best candidate is
|
||||||
|
evaluated on the hold-out set after each iteration and early stopping
|
||||||
|
is applied when validation score degrades for ``early_stop_patience``
|
||||||
|
consecutive iterations.
|
||||||
|
"""
|
||||||
|
state = initial_state or OptimizationState()
|
||||||
consecutive_failures = 0
|
consecutive_failures = 0
|
||||||
|
|
||||||
# Evaluate the seed
|
# Hold-out validation tracking
|
||||||
|
has_validation = bool(validation_pool)
|
||||||
|
best_validation_score: float = -1.0
|
||||||
|
validation_patience_counter: int = 0
|
||||||
|
|
||||||
|
# Only evaluate the seed when starting fresh (no checkpoint resume)
|
||||||
|
if initial_state is None:
|
||||||
initial_batch = self._bootstrap.sample_minibatch(
|
initial_batch = self._bootstrap.sample_minibatch(
|
||||||
synthetic_pool, self._minibatch_size
|
synthetic_pool, self._minibatch_size
|
||||||
)
|
)
|
||||||
@@ -81,31 +127,162 @@ class EvolutionLoop:
|
|||||||
)
|
)
|
||||||
state.total_llm_calls += 2 * self._minibatch_size # N executions + N judge calls
|
state.total_llm_calls += 2 * self._minibatch_size # N executions + N judge calls
|
||||||
|
|
||||||
best_candidate = Candidate(
|
seed_candidate = Candidate(
|
||||||
prompt=seed_prompt,
|
prompt=seed_prompt,
|
||||||
best_score=initial_eval.total_score,
|
best_score=initial_eval.total_score,
|
||||||
generation=0,
|
generation=0,
|
||||||
)
|
)
|
||||||
state.best_candidate = best_candidate
|
state.best_candidate = seed_candidate
|
||||||
state.candidates.append(best_candidate)
|
state.candidates.append(seed_candidate)
|
||||||
self._log(f"Initial score: {initial_eval.total_score:.2f}")
|
logger.info(
|
||||||
|
"Initial evaluation complete",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "initial_eval",
|
||||||
|
"score": round(initial_eval.total_score, 4),
|
||||||
|
"minibatch_size": self._minibatch_size,
|
||||||
|
"sample_ids": [ex.id for ex in initial_batch],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate seed on validation set
|
||||||
|
if has_validation and state.best_candidate is not None:
|
||||||
|
val_eval = await self._evaluator.evaluate(
|
||||||
|
state.best_candidate.prompt, validation_pool, task_description
|
||||||
|
)
|
||||||
|
state.total_llm_calls += 2 * len(validation_pool)
|
||||||
|
best_validation_score = val_eval.mean_score
|
||||||
|
logger.info(
|
||||||
|
"Initial validation evaluation",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "validation_eval",
|
||||||
|
"iteration": 0,
|
||||||
|
"validation_score": round(best_validation_score, 4),
|
||||||
|
"validation_pool_size": len(validation_pool),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Population initialization: seed the population with mutations
|
||||||
|
if self._population_size > 1:
|
||||||
|
await self._initialize_population(
|
||||||
|
state, seed_prompt, seed_candidate, task_description
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Resuming from checkpoint",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "resume",
|
||||||
|
"iteration": state.iteration,
|
||||||
|
"total_llm_calls": state.total_llm_calls,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Restore validation tracking from state history
|
||||||
|
if has_validation:
|
||||||
|
for entry in reversed(state.history):
|
||||||
|
if entry.get("event") == "validation_eval":
|
||||||
|
best_validation_score = entry.get("best_validation_score", -1.0)
|
||||||
|
validation_patience_counter = entry.get("validation_patience", 0)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Determine starting iteration
|
||||||
|
start_iteration = state.iteration + 1
|
||||||
|
|
||||||
# Main loop
|
# Main loop
|
||||||
for i in range(1, self._max_iterations + 1):
|
for i in range(start_iteration, self._max_iterations + 1):
|
||||||
state.iteration = i
|
state.iteration = i
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._run_iteration(
|
if self._population_size > 1 and len(state.candidates) > 1:
|
||||||
i, state, best_candidate, synthetic_pool, task_description
|
await self._run_population_iteration(
|
||||||
|
i, state, synthetic_pool, task_description
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self._run_single_iteration(
|
||||||
|
i, state, synthetic_pool, task_description
|
||||||
)
|
)
|
||||||
# Update best_candidate from state after successful iteration
|
|
||||||
best_candidate = state.best_candidate # type: ignore[assignment]
|
|
||||||
consecutive_failures = 0
|
consecutive_failures = 0
|
||||||
|
|
||||||
|
# Hold-out validation: evaluate best candidate on validation set
|
||||||
|
if has_validation and state.best_candidate is not None:
|
||||||
|
val_eval = await self._evaluator.evaluate(
|
||||||
|
state.best_candidate.prompt, validation_pool, task_description
|
||||||
|
)
|
||||||
|
state.total_llm_calls += 2 * len(validation_pool)
|
||||||
|
current_val_score = val_eval.mean_score
|
||||||
|
|
||||||
|
if current_val_score > best_validation_score:
|
||||||
|
best_validation_score = current_val_score
|
||||||
|
validation_patience_counter = 0
|
||||||
|
else:
|
||||||
|
validation_patience_counter += 1
|
||||||
|
|
||||||
|
state.history.append({
|
||||||
|
"iteration": i,
|
||||||
|
"event": "validation_eval",
|
||||||
|
"validation_score": round(current_val_score, 4),
|
||||||
|
"best_validation_score": round(best_validation_score, 4),
|
||||||
|
"validation_patience": validation_patience_counter,
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Validation evaluation",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "validation_eval",
|
||||||
|
"iteration": i,
|
||||||
|
"validation_score": round(current_val_score, 4),
|
||||||
|
"best_validation_score": round(best_validation_score, 4),
|
||||||
|
"patience": f"{validation_patience_counter}/{self._early_stop_patience}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if validation_patience_counter >= self._early_stop_patience:
|
||||||
|
logger.warning(
|
||||||
|
"Early stopping triggered — validation score did not improve for %d iterations",
|
||||||
|
self._early_stop_patience,
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "early_stop",
|
||||||
|
"iteration": i,
|
||||||
|
"best_validation_score": round(best_validation_score, 4),
|
||||||
|
"patience": self._early_stop_patience,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
state.history.append({
|
||||||
|
"iteration": i,
|
||||||
|
"event": "early_stop",
|
||||||
|
"best_validation_score": round(best_validation_score, 4),
|
||||||
|
"patience": self._early_stop_patience,
|
||||||
|
})
|
||||||
|
state.best_validation_score = best_validation_score
|
||||||
|
state.early_stopped = True
|
||||||
|
if self._checkpoint_port is not None:
|
||||||
|
self._checkpoint_port.save(state)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Checkpoint on accepted improvement (detected via state change)
|
||||||
|
self._maybe_checkpoint(state)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
consecutive_failures += 1
|
consecutive_failures += 1
|
||||||
self._log(
|
logger.error(
|
||||||
f"Iter {i}: ERROR ({consecutive_failures} consecutive) — {exc}"
|
"Iteration error",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "iteration_error",
|
||||||
|
"iteration": i,
|
||||||
|
"consecutive_failures": consecutive_failures,
|
||||||
|
"error": str(exc),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
exc_info=True,
|
||||||
)
|
)
|
||||||
state.history.append(
|
state.history.append(
|
||||||
{
|
{
|
||||||
@@ -118,9 +295,16 @@ class EvolutionLoop:
|
|||||||
|
|
||||||
# Check circuit breaker
|
# Check circuit breaker
|
||||||
if consecutive_failures >= self._circuit_breaker_threshold:
|
if consecutive_failures >= self._circuit_breaker_threshold:
|
||||||
self._log(
|
logger.warning(
|
||||||
f"Circuit breaker tripped after {consecutive_failures} "
|
"Circuit breaker tripped",
|
||||||
f"consecutive failures."
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "circuit_breaker",
|
||||||
|
"iteration": i,
|
||||||
|
"consecutive_failures": consecutive_failures,
|
||||||
|
"error_strategy": self._error_strategy,
|
||||||
|
},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
state.history.append(
|
state.history.append(
|
||||||
{
|
{
|
||||||
@@ -134,7 +318,9 @@ class EvolutionLoop:
|
|||||||
f"Circuit breaker tripped after "
|
f"Circuit breaker tripped after "
|
||||||
f"{consecutive_failures} consecutive failures"
|
f"{consecutive_failures} consecutive failures"
|
||||||
) from exc
|
) from exc
|
||||||
# skip / retry strategies: stop the loop gracefully
|
# skip / retry strategies: save checkpoint, then stop the loop gracefully
|
||||||
|
if self._checkpoint_port is not None:
|
||||||
|
self._checkpoint_port.save(state)
|
||||||
break
|
break
|
||||||
|
|
||||||
if self._error_strategy == "abort":
|
if self._error_strategy == "abort":
|
||||||
@@ -142,21 +328,77 @@ class EvolutionLoop:
|
|||||||
# skip / retry: continue to next iteration
|
# skip / retry: continue to next iteration
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Store final validation metadata on state
|
||||||
|
if has_validation:
|
||||||
|
state.best_validation_score = best_validation_score
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
async def _run_iteration(
|
# ------------------------------------------------------------------
|
||||||
|
# Population initialization
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _initialize_population(
|
||||||
|
self,
|
||||||
|
state: OptimizationState,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
seed_candidate: Candidate,
|
||||||
|
task_description: str,
|
||||||
|
) -> None:
|
||||||
|
"""Fill the population with mutated variants of the seed prompt."""
|
||||||
|
n_needed = self._population_size - 1
|
||||||
|
mutation_types = ["paraphrase", "constrain", "generalize", "specialize"]
|
||||||
|
|
||||||
|
for idx in range(n_needed):
|
||||||
|
mutation_type = mutation_types[idx % len(mutation_types)]
|
||||||
|
|
||||||
|
if self._mutation_port is not None:
|
||||||
|
new_prompt = await self._mutation_port.mutate(
|
||||||
|
seed_prompt, task_description, mutation_type
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback: use proposer for reflective mutation
|
||||||
|
new_prompt = await self._proposer.propose(
|
||||||
|
seed_prompt, [], task_description
|
||||||
|
)
|
||||||
|
state.total_llm_calls += 1
|
||||||
|
new_candidate = Candidate(
|
||||||
|
prompt=new_prompt,
|
||||||
|
best_score=seed_candidate.best_score, # estimate until evaluated
|
||||||
|
generation=0,
|
||||||
|
parent_id=id(seed_candidate),
|
||||||
|
)
|
||||||
|
state.candidates.append(new_candidate)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Population initialized",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "population_init",
|
||||||
|
"population_size": len(state.candidates),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Single-candidate iteration (original hill-climbing)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _run_single_iteration(
|
||||||
self,
|
self,
|
||||||
i: int,
|
i: int,
|
||||||
state: OptimizationState,
|
state: OptimizationState,
|
||||||
best_candidate: Candidate,
|
|
||||||
synthetic_pool: list[SyntheticExample],
|
synthetic_pool: list[SyntheticExample],
|
||||||
task_description: str,
|
task_description: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute a single iteration. Mutates *state* in-place."""
|
"""Execute a single-candidate iteration. Mutates *state* in-place."""
|
||||||
|
best_candidate = state.best_candidate # type: ignore[assignment]
|
||||||
|
|
||||||
# 1. Sample a fresh minibatch
|
# 1. Sample a fresh minibatch
|
||||||
batch = self._bootstrap.sample_minibatch(
|
batch = self._bootstrap.sample_minibatch(
|
||||||
synthetic_pool, self._minibatch_size
|
synthetic_pool, self._minibatch_size
|
||||||
)
|
)
|
||||||
|
sample_ids = [ex.id for ex in batch]
|
||||||
|
|
||||||
# 2. Evaluate the current candidate
|
# 2. Evaluate the current candidate
|
||||||
current_eval = await self._evaluator.evaluate(
|
current_eval = await self._evaluator.evaluate(
|
||||||
@@ -164,9 +406,31 @@ class EvolutionLoop:
|
|||||||
)
|
)
|
||||||
state.total_llm_calls += 2 * self._minibatch_size
|
state.total_llm_calls += 2 * self._minibatch_size
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Iteration minibatch evaluated",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "minibatch_eval",
|
||||||
|
"iteration": i,
|
||||||
|
"sample_ids": sample_ids,
|
||||||
|
"scores": [round(s, 4) for s in current_eval.scores],
|
||||||
|
"total_score": round(current_eval.total_score, 4),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# 3. Skip if perfect
|
# 3. Skip if perfect
|
||||||
if all(s >= self._perfect_score for s in current_eval.scores):
|
if all(s >= self._perfect_score for s in current_eval.scores):
|
||||||
self._log(f"Iter {i}: All scores perfect, skipping.")
|
logger.info(
|
||||||
|
"Iteration skipped — all scores perfect",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "skip_perfect",
|
||||||
|
"iteration": i,
|
||||||
|
"total_score": round(current_eval.total_score, 4),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
state.history.append(
|
state.history.append(
|
||||||
{
|
{
|
||||||
"iteration": i,
|
"iteration": i,
|
||||||
@@ -177,12 +441,26 @@ class EvolutionLoop:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 4. Propose a new prompt (reflective mutation) — sequential
|
# 4. Propose a new prompt (reflective mutation) — sequential
|
||||||
|
state.total_llm_calls += 1
|
||||||
new_prompt = await self._proposer.propose(
|
new_prompt = await self._proposer.propose(
|
||||||
best_candidate.prompt,
|
best_candidate.prompt,
|
||||||
current_eval.trajectories,
|
current_eval.trajectories,
|
||||||
task_description,
|
task_description,
|
||||||
)
|
)
|
||||||
state.total_llm_calls += 1 # 1 proposition call
|
|
||||||
|
prompt_diff = self._compute_prompt_diff(
|
||||||
|
best_candidate.prompt.text, new_prompt.text
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"Proposed new prompt",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "proposer_output",
|
||||||
|
"iteration": i,
|
||||||
|
"prompt_diff": prompt_diff,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# 5. Evaluate the new prompt on the same minibatch
|
# 5. Evaluate the new prompt on the same minibatch
|
||||||
new_eval = await self._evaluator.evaluate(
|
new_eval = await self._evaluator.evaluate(
|
||||||
@@ -200,9 +478,22 @@ class EvolutionLoop:
|
|||||||
)
|
)
|
||||||
state.best_candidate = new_candidate
|
state.best_candidate = new_candidate
|
||||||
state.candidates.append(new_candidate)
|
state.candidates.append(new_candidate)
|
||||||
self._log(
|
logger.info(
|
||||||
f"Iter {i}: ACCEPTED "
|
"Iteration accepted",
|
||||||
f"({current_eval.total_score:.2f} -> {new_eval.total_score:.2f})"
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "accepted",
|
||||||
|
"iteration": i,
|
||||||
|
"old_score": round(current_eval.total_score, 4),
|
||||||
|
"new_score": round(new_eval.total_score, 4),
|
||||||
|
"improvement": round(
|
||||||
|
new_eval.total_score - current_eval.total_score, 4
|
||||||
|
),
|
||||||
|
"sample_ids": sample_ids,
|
||||||
|
"new_scores": [round(s, 4) for s in new_eval.scores],
|
||||||
|
"prompt_diff": prompt_diff,
|
||||||
|
},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
state.history.append(
|
state.history.append(
|
||||||
{
|
{
|
||||||
@@ -215,9 +506,19 @@ class EvolutionLoop:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._log(
|
logger.info(
|
||||||
f"Iter {i}: REJECTED "
|
"Iteration rejected",
|
||||||
f"({new_eval.total_score:.2f} <= {current_eval.total_score:.2f})"
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "rejected",
|
||||||
|
"iteration": i,
|
||||||
|
"old_score": round(current_eval.total_score, 4),
|
||||||
|
"new_score": round(new_eval.total_score, 4),
|
||||||
|
"sample_ids": sample_ids,
|
||||||
|
"new_scores": [round(s, 4) for s in new_eval.scores],
|
||||||
|
"prompt_diff": prompt_diff,
|
||||||
|
},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
state.history.append(
|
state.history.append(
|
||||||
{
|
{
|
||||||
@@ -228,6 +529,213 @@ class EvolutionLoop:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _log(self, msg: str) -> None:
|
# ------------------------------------------------------------------
|
||||||
if self._verbose:
|
# Population-based iteration
|
||||||
logger.info("[PROMETHEUS] %s", msg)
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _run_population_iteration(
|
||||||
|
self,
|
||||||
|
i: int,
|
||||||
|
state: OptimizationState,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a population-based iteration. Mutates *state* in-place."""
|
||||||
|
population = state.candidates
|
||||||
|
|
||||||
|
# 1. Sample a fresh minibatch
|
||||||
|
batch = self._bootstrap.sample_minibatch(
|
||||||
|
synthetic_pool, self._minibatch_size
|
||||||
|
)
|
||||||
|
sample_ids = [ex.id for ex in batch]
|
||||||
|
|
||||||
|
# 2. Select two parents via tournament selection
|
||||||
|
parent_a = self._tournament_select(population)
|
||||||
|
parent_b = self._tournament_select(population)
|
||||||
|
|
||||||
|
# 3. Generate child: crossover or reflective mutation
|
||||||
|
use_crossover = (
|
||||||
|
random.random() < self._crossover_rate
|
||||||
|
and self._crossover_port is not None
|
||||||
|
)
|
||||||
|
if use_crossover:
|
||||||
|
state.total_llm_calls += 1
|
||||||
|
child_prompt = await self._crossover_port.crossover(
|
||||||
|
parent_a.prompt, parent_b.prompt, task_description
|
||||||
|
)
|
||||||
|
origin = "crossover"
|
||||||
|
else:
|
||||||
|
# Reflective mutation: evaluate a parent, propose improvement
|
||||||
|
state.total_llm_calls += 2 * self._minibatch_size
|
||||||
|
parent_eval = await self._evaluator.evaluate(
|
||||||
|
parent_a.prompt, batch, task_description
|
||||||
|
)
|
||||||
|
state.total_llm_calls += 1
|
||||||
|
child_prompt = await self._proposer.propose(
|
||||||
|
parent_a.prompt,
|
||||||
|
parent_eval.trajectories,
|
||||||
|
task_description,
|
||||||
|
)
|
||||||
|
origin = "reflective"
|
||||||
|
|
||||||
|
# 4. Optional mutation
|
||||||
|
if random.random() < self._mutation_rate and self._mutation_port is not None:
|
||||||
|
mutation_type = random.choice(
|
||||||
|
["paraphrase", "constrain", "generalize", "specialize"]
|
||||||
|
)
|
||||||
|
state.total_llm_calls += 1
|
||||||
|
child_prompt = await self._mutation_port.mutate(
|
||||||
|
child_prompt, task_description, mutation_type
|
||||||
|
)
|
||||||
|
origin += f"+mutation({mutation_type})"
|
||||||
|
|
||||||
|
# 5. Evaluate the child
|
||||||
|
state.total_llm_calls += 2 * self._minibatch_size
|
||||||
|
child_eval = await self._evaluator.evaluate(
|
||||||
|
child_prompt, batch, task_description
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Compute fitness with diversity penalty
|
||||||
|
child_score = child_eval.total_score
|
||||||
|
diversity_sim = self._compute_diversity_score(
|
||||||
|
child_prompt, population
|
||||||
|
)
|
||||||
|
child_fitness = child_score - self._diversity_penalty * (1.0 - diversity_sim)
|
||||||
|
|
||||||
|
# 7. Find worst candidate and replace if child is better
|
||||||
|
worst_idx = min(
|
||||||
|
range(len(population)),
|
||||||
|
key=lambda idx: population[idx].best_score,
|
||||||
|
)
|
||||||
|
worst_fitness = (
|
||||||
|
population[worst_idx].best_score
|
||||||
|
- self._diversity_penalty * (1.0 - self._compute_diversity_score(
|
||||||
|
population[worst_idx].prompt, population
|
||||||
|
))
|
||||||
|
)
|
||||||
|
|
||||||
|
accepted = child_fitness > worst_fitness
|
||||||
|
|
||||||
|
if accepted:
|
||||||
|
new_candidate = Candidate(
|
||||||
|
prompt=child_prompt,
|
||||||
|
best_score=child_score,
|
||||||
|
generation=i,
|
||||||
|
parent_id=id(parent_a),
|
||||||
|
)
|
||||||
|
population[worst_idx] = new_candidate
|
||||||
|
|
||||||
|
# Update best if this child is the new best
|
||||||
|
if child_score > (state.best_candidate.best_score if state.best_candidate else 0):
|
||||||
|
state.best_candidate = new_candidate
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Population iteration accepted",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "pop_accepted",
|
||||||
|
"iteration": i,
|
||||||
|
"origin": origin,
|
||||||
|
"child_score": round(child_score, 4),
|
||||||
|
"child_fitness": round(child_fitness, 4),
|
||||||
|
"diversity_sim": round(diversity_sim, 4),
|
||||||
|
"replaced_idx": worst_idx,
|
||||||
|
"sample_ids": sample_ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
state.history.append(
|
||||||
|
{
|
||||||
|
"iteration": i,
|
||||||
|
"event": "pop_accepted",
|
||||||
|
"origin": origin,
|
||||||
|
"child_score": child_score,
|
||||||
|
"child_fitness": child_fitness,
|
||||||
|
"diversity_sim": diversity_sim,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Population iteration rejected",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "pop_rejected",
|
||||||
|
"iteration": i,
|
||||||
|
"origin": origin,
|
||||||
|
"child_score": round(child_score, 4),
|
||||||
|
"child_fitness": round(child_fitness, 4),
|
||||||
|
"worst_fitness": round(worst_fitness, 4),
|
||||||
|
"sample_ids": sample_ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
state.history.append(
|
||||||
|
{
|
||||||
|
"iteration": i,
|
||||||
|
"event": "pop_rejected",
|
||||||
|
"origin": origin,
|
||||||
|
"child_score": child_score,
|
||||||
|
"child_fitness": child_fitness,
|
||||||
|
"worst_fitness": worst_fitness,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Selection and diversity helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _tournament_select(
|
||||||
|
self,
|
||||||
|
population: list[Candidate],
|
||||||
|
tournament_size: int = 3,
|
||||||
|
) -> Candidate:
|
||||||
|
"""Tournament selection: pick the best from a random subset."""
|
||||||
|
k = min(tournament_size, len(population))
|
||||||
|
contestants = random.sample(population, k)
|
||||||
|
return max(contestants, key=lambda c: c.best_score)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_diversity_score(
|
||||||
|
prompt: Prompt,
|
||||||
|
population: list[Candidate],
|
||||||
|
) -> float:
|
||||||
|
"""Compute the average Jaccard similarity between *prompt* and all
|
||||||
|
population members. Returns 1.0 when population has only one member
|
||||||
|
(no diversity penalty)."""
|
||||||
|
if len(population) <= 1:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
prompt_words = set(prompt.text.lower().split())
|
||||||
|
if not prompt_words:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
similarities: list[float] = []
|
||||||
|
for candidate in population:
|
||||||
|
other_words = set(candidate.prompt.text.lower().split())
|
||||||
|
if not other_words:
|
||||||
|
continue
|
||||||
|
intersection = prompt_words & other_words
|
||||||
|
union = prompt_words | other_words
|
||||||
|
sim = len(intersection) / len(union) if union else 0.0
|
||||||
|
similarities.append(sim)
|
||||||
|
|
||||||
|
# Average similarity (lower = more diverse)
|
||||||
|
return sum(similarities) / len(similarities) if similarities else 0.0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_prompt_diff(old: str, new: str) -> dict[str, int]:
|
||||||
|
"""Compute a simple diff summary between two prompts."""
|
||||||
|
old_lines = set(old.splitlines())
|
||||||
|
new_lines = set(new.splitlines())
|
||||||
|
return {
|
||||||
|
"lines_added": len(new_lines - old_lines),
|
||||||
|
"lines_removed": len(old_lines - new_lines),
|
||||||
|
"chars_delta": len(new) - len(old),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _maybe_checkpoint(self, state: OptimizationState) -> None:
|
||||||
|
"""Save a checkpoint if the interval is met or on accepted improvements."""
|
||||||
|
if self._checkpoint_port is None:
|
||||||
|
return
|
||||||
|
if state.iteration % self._checkpoint_interval == 0:
|
||||||
|
self._checkpoint_port.save(state)
|
||||||
|
|||||||
116
src/prometheus/application/ground_truth_evaluator.py
Normal file
116
src/prometheus/application/ground_truth_evaluator.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""
|
||||||
|
Ground-truth evaluator — execution + similarity comparison.
|
||||||
|
|
||||||
|
Produces a quality signal *with* ground truth by comparing model outputs
|
||||||
|
against expected outputs using a configurable similarity metric.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from prometheus.domain.entities import (
|
||||||
|
EvalResult,
|
||||||
|
GroundTruthExample,
|
||||||
|
Prompt,
|
||||||
|
Trajectory,
|
||||||
|
)
|
||||||
|
from prometheus.domain.ports import LLMPort, SimilarityPort
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GroundTruthEvaluator:
|
||||||
|
"""Evaluates a prompt against a ground-truth dataset.
|
||||||
|
|
||||||
|
Pipeline: execute → compare with similarity metric → build trajectories.
|
||||||
|
Unlike PromptEvaluator (which uses LLM-as-Judge), this compares outputs
|
||||||
|
directly against known-good expected outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
executor: LLMPort,
|
||||||
|
similarity: SimilarityPort,
|
||||||
|
max_concurrency: int = 5,
|
||||||
|
):
|
||||||
|
self._executor = executor
|
||||||
|
self._similarity = similarity
|
||||||
|
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||||
|
|
||||||
|
async def evaluate(
|
||||||
|
self,
|
||||||
|
prompt: Prompt,
|
||||||
|
dataset: list[GroundTruthExample],
|
||||||
|
) -> EvalResult:
|
||||||
|
"""Evaluate the prompt on the ground-truth dataset.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
1. Execute the prompt on each input (parallel, bounded)
|
||||||
|
2. Compare each output against expected using similarity metric
|
||||||
|
3. Build trajectories with feedback
|
||||||
|
"""
|
||||||
|
# Step 1: Parallel execution (per-item isolation)
|
||||||
|
output_coros = [
|
||||||
|
self._execute_single(prompt, example) for example in dataset
|
||||||
|
]
|
||||||
|
outputs = await asyncio.gather(*output_coros)
|
||||||
|
|
||||||
|
# Step 2: Compute similarity scores
|
||||||
|
scores: list[float] = []
|
||||||
|
feedbacks: list[str] = []
|
||||||
|
trajectories: list[Trajectory] = []
|
||||||
|
|
||||||
|
for example, output in zip(dataset, outputs):
|
||||||
|
score = self._similarity.compute(output, example.expected_output)
|
||||||
|
score = max(0.0, min(1.0, score)) # normalize to [0, 1]
|
||||||
|
scores.append(score)
|
||||||
|
feedback = self._build_feedback(output, example.expected_output, score)
|
||||||
|
feedbacks.append(feedback)
|
||||||
|
trajectories.append(
|
||||||
|
Trajectory(
|
||||||
|
input_text=example.input_text,
|
||||||
|
output_text=output,
|
||||||
|
score=score,
|
||||||
|
feedback=feedback,
|
||||||
|
prompt_used=prompt.text,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Ground-truth evaluation complete: %d items, mean_score=%.4f",
|
||||||
|
len(dataset),
|
||||||
|
sum(scores) / len(scores) if scores else 0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return EvalResult(
|
||||||
|
scores=scores,
|
||||||
|
feedbacks=feedbacks,
|
||||||
|
trajectories=trajectories,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _execute_single(
|
||||||
|
self, prompt: Prompt, example: GroundTruthExample
|
||||||
|
) -> str:
|
||||||
|
async with self._semaphore:
|
||||||
|
try:
|
||||||
|
return await self._executor.execute(prompt, example.input_text)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Execution failed for input '%s…': %s",
|
||||||
|
example.input_text[:40],
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return f"[execution error: {exc}]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_feedback(output: str, expected: str, score: float) -> str:
|
||||||
|
"""Build human-readable feedback for a ground-truth comparison."""
|
||||||
|
if score >= 0.99:
|
||||||
|
return "Exact match."
|
||||||
|
elif score >= 0.7:
|
||||||
|
return f"Close match (score={score:.2f}). Expected: {expected[:100]}"
|
||||||
|
elif score >= 0.3:
|
||||||
|
return f"Partial match (score={score:.2f}). Expected: {expected[:100]}"
|
||||||
|
else:
|
||||||
|
return f"Poor match (score={score:.2f}). Expected: {expected[:100]}"
|
||||||
@@ -10,8 +10,16 @@ from prometheus.application.bootstrap import SyntheticBootstrap
|
|||||||
from prometheus.application.dto import OptimizationConfig, OptimizationResult
|
from prometheus.application.dto import OptimizationConfig, OptimizationResult
|
||||||
from prometheus.application.evaluator import PromptEvaluator
|
from prometheus.application.evaluator import PromptEvaluator
|
||||||
from prometheus.application.evolution import EvolutionLoop
|
from prometheus.application.evolution import EvolutionLoop
|
||||||
|
from prometheus.cli.logging_setup import get_logger
|
||||||
from prometheus.domain.entities import Prompt
|
from prometheus.domain.entities import Prompt
|
||||||
from prometheus.domain.ports import ProposerPort
|
from prometheus.domain.ports import (
|
||||||
|
CheckpointPort,
|
||||||
|
CrossoverPort,
|
||||||
|
MutationPort,
|
||||||
|
ProposerPort,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger("use_cases")
|
||||||
|
|
||||||
|
|
||||||
class OptimizePromptUseCase:
|
class OptimizePromptUseCase:
|
||||||
@@ -25,24 +33,60 @@ class OptimizePromptUseCase:
|
|||||||
evaluator: PromptEvaluator,
|
evaluator: PromptEvaluator,
|
||||||
proposer: ProposerPort,
|
proposer: ProposerPort,
|
||||||
bootstrap: SyntheticBootstrap,
|
bootstrap: SyntheticBootstrap,
|
||||||
|
checkpoint_port: CheckpointPort | None = None,
|
||||||
|
crossover_port: CrossoverPort | None = None,
|
||||||
|
mutation_port: MutationPort | None = None,
|
||||||
):
|
):
|
||||||
self._evaluator = evaluator
|
self._evaluator = evaluator
|
||||||
self._proposer = proposer
|
self._proposer = proposer
|
||||||
self._bootstrap = bootstrap
|
self._bootstrap = bootstrap
|
||||||
|
self._checkpoint_port = checkpoint_port
|
||||||
|
self._crossover_port = crossover_port
|
||||||
|
self._mutation_port = mutation_port
|
||||||
|
|
||||||
async def execute(self, config: OptimizationConfig) -> OptimizationResult:
|
async def execute(self, config: OptimizationConfig) -> OptimizationResult:
|
||||||
"""Full pipeline:
|
"""Full pipeline:
|
||||||
1. Bootstrap → generate synthetic inputs
|
1. Bootstrap → generate synthetic inputs
|
||||||
2. Evolution → optimization loop
|
2. Evolution → optimization loop (with optional checkpoint resume)
|
||||||
3. Return result
|
3. Return result
|
||||||
"""
|
"""
|
||||||
# Phase 0: Bootstrap
|
# Phase 0: Bootstrap (skip synthetic generation on resume if pool was saved)
|
||||||
|
initial_state = None
|
||||||
|
if config.resume and self._checkpoint_port is not None:
|
||||||
|
initial_state = self._checkpoint_port.load()
|
||||||
|
if initial_state is not None and initial_state.synthetic_pool:
|
||||||
|
synthetic_pool = initial_state.synthetic_pool
|
||||||
|
logger.info(
|
||||||
|
"Resumed checkpoint includes %d synthetic inputs — skipping bootstrap",
|
||||||
|
extra={"structured": {"event": "resume_skip_bootstrap", "pool_size": len(synthetic_pool)}},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
synthetic_pool = self._bootstrap.run(
|
||||||
|
task_description=config.task_description,
|
||||||
|
n_examples=config.n_synthetic_inputs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
synthetic_pool = self._bootstrap.run(
|
synthetic_pool = self._bootstrap.run(
|
||||||
task_description=config.task_description,
|
task_description=config.task_description,
|
||||||
n_examples=config.n_synthetic_inputs,
|
n_examples=config.n_synthetic_inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Phase 1: Evolution
|
# Split into train / validation if configured
|
||||||
|
validation_pool: list = []
|
||||||
|
if config.validation_split > 0:
|
||||||
|
synthetic_pool, validation_pool = SyntheticBootstrap.split_pool(
|
||||||
|
synthetic_pool, config.validation_split,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Split synthetic pool: %d train, %d validation (%.0f%% hold-out)",
|
||||||
|
len(synthetic_pool), len(validation_pool),
|
||||||
|
config.validation_split * 100,
|
||||||
|
extra={"structured": {
|
||||||
|
"event": "pool_split",
|
||||||
|
"train_size": len(synthetic_pool),
|
||||||
|
"val_size": len(validation_pool),
|
||||||
|
}},
|
||||||
|
)
|
||||||
loop = EvolutionLoop(
|
loop = EvolutionLoop(
|
||||||
evaluator=self._evaluator,
|
evaluator=self._evaluator,
|
||||||
proposer=self._proposer,
|
proposer=self._proposer,
|
||||||
@@ -53,9 +97,22 @@ class OptimizePromptUseCase:
|
|||||||
verbose=config.verbose,
|
verbose=config.verbose,
|
||||||
circuit_breaker_threshold=config.circuit_breaker_threshold,
|
circuit_breaker_threshold=config.circuit_breaker_threshold,
|
||||||
error_strategy=config.error_strategy,
|
error_strategy=config.error_strategy,
|
||||||
|
checkpoint_port=self._checkpoint_port,
|
||||||
|
checkpoint_interval=config.checkpoint_interval,
|
||||||
|
population_size=config.population_size,
|
||||||
|
crossover_rate=config.crossover_rate,
|
||||||
|
mutation_rate=config.mutation_rate,
|
||||||
|
diversity_penalty=config.diversity_penalty,
|
||||||
|
crossover_port=self._crossover_port,
|
||||||
|
mutation_port=self._mutation_port,
|
||||||
|
early_stop_patience=config.early_stop_patience,
|
||||||
)
|
)
|
||||||
seed_prompt = Prompt(text=config.seed_prompt)
|
seed_prompt = Prompt(text=config.seed_prompt)
|
||||||
state = await loop.run(seed_prompt, synthetic_pool, config.task_description)
|
state = await loop.run(
|
||||||
|
seed_prompt, synthetic_pool, config.task_description,
|
||||||
|
initial_state=initial_state,
|
||||||
|
validation_pool=validation_pool or None,
|
||||||
|
)
|
||||||
|
|
||||||
# Phase 2: Result
|
# Phase 2: Result
|
||||||
initial_score = (
|
initial_score = (
|
||||||
@@ -71,9 +128,12 @@ class OptimizePromptUseCase:
|
|||||||
),
|
),
|
||||||
initial_prompt=config.seed_prompt,
|
initial_prompt=config.seed_prompt,
|
||||||
iterations_used=state.iteration,
|
iterations_used=state.iteration,
|
||||||
total_llm_calls=state.total_llm_calls + 1, # +1 for bootstrap
|
total_llm_calls=state.total_llm_calls + 1, # +1 for bootstrap synthesis call
|
||||||
initial_score=initial_score,
|
initial_score=initial_score,
|
||||||
final_score=final_score,
|
final_score=final_score,
|
||||||
improvement=final_score - initial_score,
|
improvement=final_score - initial_score,
|
||||||
history=state.history,
|
history=state.history,
|
||||||
|
final_validation_score=state.best_validation_score,
|
||||||
|
best_validation_score=state.best_validation_score,
|
||||||
|
early_stopped=state.early_stopped,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,31 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
CLI — user entry point.
|
CLI — user entry point.
|
||||||
|
|
||||||
Typer interface with -i (input) and -o (output) options.
|
Registers all subcommands and delegates to cli/commands/.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from dataclasses import asdict
|
|
||||||
|
|
||||||
import dspy
|
|
||||||
import typer
|
import typer
|
||||||
from pydantic import ValidationError
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
|
|
||||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
from prometheus.cli.commands import init, list_runs, optimize, version
|
||||||
from prometheus.application.dto import OptimizationConfig, OptimizationResult
|
|
||||||
from prometheus.application.evaluator import PromptEvaluator
|
|
||||||
from prometheus.application.use_cases import OptimizePromptUseCase
|
|
||||||
from prometheus.infrastructure.file_io import YamlPersistence
|
|
||||||
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
|
||||||
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
|
||||||
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
|
||||||
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
|
||||||
|
|
||||||
app = typer.Typer(
|
app = typer.Typer(
|
||||||
name="prometheus",
|
name="prometheus",
|
||||||
@@ -33,205 +15,12 @@ app = typer.Typer(
|
|||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
console = Console()
|
# Register all subcommands — having multiple commands fixes the
|
||||||
|
# Typer 0.24+ bug where a single-command app absorbs the subcommand.
|
||||||
|
optimize.register(app)
|
||||||
@app.command()
|
version.register(app)
|
||||||
def optimize(
|
init.register(app)
|
||||||
input: str = typer.Option(
|
list_runs.register(app)
|
||||||
...,
|
|
||||||
"-i",
|
|
||||||
"--input",
|
|
||||||
help="Path to input YAML config file.",
|
|
||||||
exists=True,
|
|
||||||
readable=True,
|
|
||||||
),
|
|
||||||
output: str = typer.Option(
|
|
||||||
"output.yaml",
|
|
||||||
"-o",
|
|
||||||
"--output",
|
|
||||||
help="Path to output YAML result file.",
|
|
||||||
),
|
|
||||||
verbose: bool = typer.Option(
|
|
||||||
False,
|
|
||||||
"-v",
|
|
||||||
"--verbose",
|
|
||||||
help="Print detailed progress.",
|
|
||||||
),
|
|
||||||
max_retries: int = typer.Option(
|
|
||||||
3,
|
|
||||||
"--max-retries",
|
|
||||||
help="Max retry attempts for transient LLM errors (429, timeout, 5xx).",
|
|
||||||
),
|
|
||||||
error_strategy: str = typer.Option(
|
|
||||||
"retry",
|
|
||||||
"--error-strategy",
|
|
||||||
help="How to handle errors: skip | retry | abort.",
|
|
||||||
),
|
|
||||||
max_concurrency: int = typer.Option(
|
|
||||||
5,
|
|
||||||
"--max-concurrency",
|
|
||||||
help="Max parallel LLM calls for minibatch execution and judging.",
|
|
||||||
),
|
|
||||||
) -> None:
|
|
||||||
"""Optimize a prompt without any reference data.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
prometheus optimize -i config.yaml -o result.yaml
|
|
||||||
"""
|
|
||||||
asyncio.run(
|
|
||||||
_async_optimize(input, output, verbose, max_retries, error_strategy, max_concurrency)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _async_optimize(
|
|
||||||
input: str,
|
|
||||||
output: str,
|
|
||||||
verbose: bool,
|
|
||||||
max_retries: int,
|
|
||||||
error_strategy: str,
|
|
||||||
max_concurrency: int,
|
|
||||||
) -> None:
|
|
||||||
# Configure verbose logging
|
|
||||||
if verbose:
|
|
||||||
logging.basicConfig(level=logging.INFO, format="[PROMETHEUS] %(message)s")
|
|
||||||
|
|
||||||
console.print(
|
|
||||||
Panel.fit(
|
|
||||||
"PROMETHEUS — Prompt Evolution Engine",
|
|
||||||
subtitle="No reference data required",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. Load & validate config
|
|
||||||
persistence = YamlPersistence()
|
|
||||||
raw_config = persistence.read_config(input)
|
|
||||||
|
|
||||||
# CLI flags override config file values
|
|
||||||
raw_config.setdefault("max_retries", max_retries)
|
|
||||||
raw_config.setdefault("error_strategy", error_strategy)
|
|
||||||
raw_config.setdefault("max_concurrency", max_concurrency)
|
|
||||||
raw_config["output_path"] = output
|
|
||||||
raw_config["verbose"] = verbose
|
|
||||||
|
|
||||||
try:
|
|
||||||
config = OptimizationConfig.model_validate(raw_config)
|
|
||||||
except ValidationError as exc:
|
|
||||||
console.print("[bold red]Configuration error:[/bold red]\n")
|
|
||||||
for err in exc.errors():
|
|
||||||
loc = " → ".join(str(l) for l in err["loc"])
|
|
||||||
console.print(f" [red]• {loc}: {err['msg']}[/red]")
|
|
||||||
raise typer.Exit(code=1) from exc
|
|
||||||
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
|
|
||||||
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
|
|
||||||
|
|
||||||
# 2. Create per-model DSPy LM instances
|
|
||||||
def _model_lm_kwargs(
|
|
||||||
model_api_base: str | None,
|
|
||||||
model_api_key_env: str | None,
|
|
||||||
) -> dict:
|
|
||||||
"""Build kwargs for dspy.LM, using per-model overrides with global fallback."""
|
|
||||||
kwargs: dict = {}
|
|
||||||
api_base = model_api_base or config.api_base
|
|
||||||
api_key_env = model_api_key_env or config.api_key_env
|
|
||||||
if api_base:
|
|
||||||
kwargs["api_base"] = api_base
|
|
||||||
if api_key_env:
|
|
||||||
kwargs["api_key"] = os.environ.get(api_key_env, "")
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
task_lm = dspy.LM(
|
|
||||||
config.task_model,
|
|
||||||
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env),
|
|
||||||
)
|
|
||||||
judge_lm = dspy.LM(
|
|
||||||
config.judge_model,
|
|
||||||
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env),
|
|
||||||
)
|
|
||||||
proposer_lm = dspy.LM(
|
|
||||||
config.proposer_model,
|
|
||||||
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env),
|
|
||||||
)
|
|
||||||
synth_lm = dspy.LM(
|
|
||||||
config.synth_model,
|
|
||||||
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Build adapters (Dependency Injection — each gets its own LM + retry config)
|
|
||||||
synth_adapter = DSPySyntheticAdapter(lm=synth_lm)
|
|
||||||
llm_adapter = DSPyLLMAdapter(
|
|
||||||
lm=task_lm,
|
|
||||||
max_retries=config.max_retries,
|
|
||||||
retry_delay_base=config.retry_delay_base,
|
|
||||||
)
|
|
||||||
judge_adapter = DSPyJudgeAdapter(
|
|
||||||
lm=judge_lm,
|
|
||||||
max_retries=config.max_retries,
|
|
||||||
retry_delay_base=config.retry_delay_base,
|
|
||||||
max_concurrency=config.max_concurrency,
|
|
||||||
)
|
|
||||||
proposer_adapter = DSPyProposerAdapter(
|
|
||||||
lm=proposer_lm,
|
|
||||||
max_retries=config.max_retries,
|
|
||||||
retry_delay_base=config.retry_delay_base,
|
|
||||||
)
|
|
||||||
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
|
||||||
evaluator = PromptEvaluator(
|
|
||||||
executor=llm_adapter,
|
|
||||||
judge=judge_adapter,
|
|
||||||
max_concurrency=config.max_concurrency,
|
|
||||||
)
|
|
||||||
use_case = OptimizePromptUseCase(
|
|
||||||
evaluator=evaluator,
|
|
||||||
proposer=proposer_adapter,
|
|
||||||
bootstrap=bootstrap,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Execute
|
|
||||||
with console.status("[bold green]Evolving prompt..."):
|
|
||||||
result = await use_case.execute(config)
|
|
||||||
|
|
||||||
# 5. Display results
|
|
||||||
_display_result(result)
|
|
||||||
|
|
||||||
# 6. Save
|
|
||||||
_save_result(persistence, output, result)
|
|
||||||
console.print(f"\n[green]Results saved to {output}[/green]")
|
|
||||||
|
|
||||||
|
|
||||||
def _display_result(result: OptimizationResult) -> None:
|
|
||||||
"""Display a Rich summary in the terminal."""
|
|
||||||
console.print()
|
|
||||||
console.print(
|
|
||||||
Panel(
|
|
||||||
f"[bold green]Optimized Prompt[/bold green]\n\n{result.optimized_prompt}",
|
|
||||||
title="Result",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
table = Table(title="Metrics")
|
|
||||||
table.add_column("Metric", style="cyan")
|
|
||||||
table.add_column("Value", style="bold")
|
|
||||||
table.add_row("Initial Score", f"{result.initial_score:.2f}")
|
|
||||||
table.add_row("Final Score", f"{result.final_score:.2f}")
|
|
||||||
table.add_row("Improvement", f"{result.improvement:+.2f}")
|
|
||||||
table.add_row("Iterations", str(result.iterations_used))
|
|
||||||
table.add_row("LLM Calls", str(result.total_llm_calls))
|
|
||||||
console.print(table)
|
|
||||||
|
|
||||||
|
|
||||||
def _save_result(
|
|
||||||
persistence: YamlPersistence,
|
|
||||||
path: str,
|
|
||||||
result: OptimizationResult,
|
|
||||||
) -> None:
|
|
||||||
"""Save the result as YAML."""
|
|
||||||
persistence.write_result(path, asdict(result))
|
|
||||||
|
|
||||||
|
|
||||||
@app.command(hidden=True)
|
|
||||||
def _help() -> None:
|
|
||||||
"""Internal placeholder to force multi-command Typer behavior."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
1
src/prometheus/cli/commands/__init__.py
Normal file
1
src/prometheus/cli/commands/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""CLI command modules."""
|
||||||
97
src/prometheus/cli/commands/init.py
Normal file
97
src/prometheus/cli/commands/init.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""prometheus init — scaffold a config YAML interactively."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
|
import yaml
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
_TEMPLATE = """\
|
||||||
|
# PROMETHEUS configuration
|
||||||
|
# Generated by `prometheus init`
|
||||||
|
|
||||||
|
# --- Required ---
|
||||||
|
seed_prompt: {seed_prompt}
|
||||||
|
task_description: {task_description}
|
||||||
|
|
||||||
|
# --- Models ---
|
||||||
|
task_model: {task_model}
|
||||||
|
judge_model: {judge_model}
|
||||||
|
proposer_model: {proposer_model}
|
||||||
|
synth_model: {synth_model}
|
||||||
|
|
||||||
|
# --- Global API settings (optional) ---
|
||||||
|
# api_base: https://api.openai.com/v1
|
||||||
|
# api_key_env: OPENAI_API_KEY
|
||||||
|
|
||||||
|
# --- Evolution parameters ---
|
||||||
|
max_iterations: 30
|
||||||
|
n_synthetic_inputs: 20
|
||||||
|
minibatch_size: 5
|
||||||
|
perfect_score: 1.0
|
||||||
|
seed: 42
|
||||||
|
|
||||||
|
# --- Concurrency ---
|
||||||
|
max_concurrency: 5
|
||||||
|
|
||||||
|
# --- Error handling ---
|
||||||
|
max_retries: 3
|
||||||
|
retry_delay_base: 1.0
|
||||||
|
circuit_breaker_threshold: 5
|
||||||
|
error_strategy: retry
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def register(app: typer.Typer) -> None:
|
||||||
|
"""Register the init command on the Typer app."""
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def init(
|
||||||
|
output: str = typer.Option(
|
||||||
|
"config.yaml",
|
||||||
|
"-o",
|
||||||
|
"--output",
|
||||||
|
help="Path for the generated config file.",
|
||||||
|
),
|
||||||
|
) -> None:
|
||||||
|
"""Interactively scaffold a PROMETHEUS config YAML.
|
||||||
|
|
||||||
|
Prompts for required fields and writes a ready-to-edit config file.
|
||||||
|
"""
|
||||||
|
target = Path(output)
|
||||||
|
|
||||||
|
if target.exists() and not typer.confirm(
|
||||||
|
f"{output} already exists. Overwrite?", default=False
|
||||||
|
):
|
||||||
|
raise typer.Exit(code=0)
|
||||||
|
|
||||||
|
seed_prompt: str = typer.prompt("Seed prompt")
|
||||||
|
task_description: str = typer.prompt("Task description")
|
||||||
|
task_model: str = typer.prompt("Task model", default="openai/gpt-4o-mini")
|
||||||
|
judge_model: str = typer.prompt("Judge model", default="openai/gpt-4o")
|
||||||
|
proposer_model: str = typer.prompt("Proposer model", default="openai/gpt-4o")
|
||||||
|
synth_model: str = typer.prompt("Synth model", default="openai/gpt-4o")
|
||||||
|
|
||||||
|
content = _TEMPLATE.format(
|
||||||
|
seed_prompt=_yaml_string(seed_prompt),
|
||||||
|
task_description=_yaml_string(task_description),
|
||||||
|
task_model=task_model,
|
||||||
|
judge_model=judge_model,
|
||||||
|
proposer_model=proposer_model,
|
||||||
|
synth_model=synth_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
target.write_text(content, encoding="utf-8")
|
||||||
|
console.print(f"[green]Config written to {output}[/green]")
|
||||||
|
console.print("[dim]Edit it as needed, then run: prometheus optimize -i config.yaml[/dim]")
|
||||||
|
|
||||||
|
|
||||||
|
def _yaml_string(value: str) -> str:
|
||||||
|
"""Quote a string for YAML if it contains special characters."""
|
||||||
|
if any(ch in value for ch in (":", "#", "'", '"', "\n", "{", "}", "[", "]", ",")):
|
||||||
|
escaped = value.replace("'", "''")
|
||||||
|
return f"'{escaped}'"
|
||||||
|
return value
|
||||||
101
src/prometheus/cli/commands/list_runs.py
Normal file
101
src/prometheus/cli/commands/list_runs.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""prometheus list — list past optimization runs."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import glob as globmod
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
|
import yaml
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
_DEFAULT_PATTERNS = ("output.yaml", "results/*.yaml", "*.result.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
def register(app: typer.Typer) -> None:
|
||||||
|
"""Register the list command on the Typer app."""
|
||||||
|
|
||||||
|
@app.command("list")
|
||||||
|
def list_runs(
|
||||||
|
directory: str = typer.Option(
|
||||||
|
".",
|
||||||
|
"-d",
|
||||||
|
"--directory",
|
||||||
|
help="Directory to scan for result YAML files.",
|
||||||
|
),
|
||||||
|
) -> None:
|
||||||
|
"""List past optimization runs found in result YAML files.
|
||||||
|
|
||||||
|
Scans the given directory for YAML files that look like PROMETHEUS
|
||||||
|
output (they contain 'optimized_prompt' and 'final_score' keys) and
|
||||||
|
displays a summary table.
|
||||||
|
"""
|
||||||
|
base = Path(directory)
|
||||||
|
if not base.is_dir():
|
||||||
|
console.print(f"[red]Directory not found: {directory}[/red]")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
runs: list[dict] = []
|
||||||
|
|
||||||
|
for pattern in _DEFAULT_PATTERNS:
|
||||||
|
for path_str in globmod.glob(str(base / pattern), recursive=False):
|
||||||
|
_try_read_run(path_str, runs)
|
||||||
|
|
||||||
|
# Also try nested directories one level deep
|
||||||
|
for path_str in globmod.glob(str(base / "**/*.yaml"), recursive=True):
|
||||||
|
_try_read_run(path_str, runs)
|
||||||
|
|
||||||
|
if not runs:
|
||||||
|
console.print("[dim]No optimization runs found.[/dim]")
|
||||||
|
raise typer.Exit(code=0)
|
||||||
|
|
||||||
|
# Deduplicate by path
|
||||||
|
seen: set[str] = set()
|
||||||
|
unique_runs: list[dict] = []
|
||||||
|
for run in runs:
|
||||||
|
if run["path"] not in seen:
|
||||||
|
seen.add(run["path"])
|
||||||
|
unique_runs.append(run)
|
||||||
|
|
||||||
|
table = Table(title="PROMETHEUS Runs")
|
||||||
|
table.add_column("File", style="cyan")
|
||||||
|
table.add_column("Initial", justify="right")
|
||||||
|
table.add_column("Final", justify="right", style="green")
|
||||||
|
table.add_column("Delta", justify="right")
|
||||||
|
table.add_column("Iters", justify="right")
|
||||||
|
table.add_column("Prompt (first 60 chars)", style="dim")
|
||||||
|
|
||||||
|
for run in sorted(unique_runs, key=lambda r: r["path"]):
|
||||||
|
table.add_row(
|
||||||
|
run["path"],
|
||||||
|
f"{run['initial_score']:.2f}",
|
||||||
|
f"{run['final_score']:.2f}",
|
||||||
|
f"{run['improvement']:+.2f}",
|
||||||
|
str(run["iterations"]),
|
||||||
|
run["prompt_preview"],
|
||||||
|
)
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
def _try_read_run(path_str: str, runs: list[dict]) -> None:
|
||||||
|
"""Try to parse a YAML file as a PROMETHEUS result and append metadata."""
|
||||||
|
try:
|
||||||
|
with open(path_str, encoding="utf-8") as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return
|
||||||
|
if "optimized_prompt" not in data or "final_score" not in data:
|
||||||
|
return
|
||||||
|
runs.append({
|
||||||
|
"path": path_str,
|
||||||
|
"initial_score": float(data.get("initial_score", 0.0)),
|
||||||
|
"final_score": float(data.get("final_score", 0.0)),
|
||||||
|
"improvement": float(data.get("improvement", 0.0)),
|
||||||
|
"iterations": int(data.get("iterations_used", 0)),
|
||||||
|
"prompt_preview": str(data.get("optimized_prompt", ""))[:60],
|
||||||
|
})
|
||||||
|
except (OSError, yaml.YAMLError, ValueError, TypeError):
|
||||||
|
pass
|
||||||
@@ -27,8 +27,6 @@ from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
|||||||
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
||||||
from prometheus.infrastructure.crossover_adapter import DSPyCrossoverAdapter
|
from prometheus.infrastructure.crossover_adapter import DSPyCrossoverAdapter
|
||||||
from prometheus.infrastructure.mutation_adapter import DSPyMutationAdapter
|
from prometheus.infrastructure.mutation_adapter import DSPyMutationAdapter
|
||||||
from prometheus.infrastructure.crossover_adapter import DSPyCrossoverAdapter
|
|
||||||
from prometheus.infrastructure.mutation_adapter import DSPyMutationAdapter
|
|
||||||
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
||||||
from prometheus.infrastructure.similarity import create_similarity_adapter
|
from prometheus.infrastructure.similarity import create_similarity_adapter
|
||||||
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
||||||
@@ -337,6 +335,29 @@ async def _async_optimize(
|
|||||||
with console.status("[bold green]Evolving prompt..."):
|
with console.status("[bold green]Evolving prompt..."):
|
||||||
result = await use_case.execute(config)
|
result = await use_case.execute(config)
|
||||||
|
|
||||||
|
# 4b. Compute actual LLM call count from adapter counters
|
||||||
|
actual_llm_calls = (
|
||||||
|
llm_adapter.call_count
|
||||||
|
+ judge_adapter.call_count
|
||||||
|
+ proposer_adapter.call_count
|
||||||
|
+ synth_adapter.call_count
|
||||||
|
+ (crossover_adapter.call_count if crossover_adapter else 0)
|
||||||
|
+ (mutation_adapter.call_count if mutation_adapter else 0)
|
||||||
|
)
|
||||||
|
result = OptimizationResult(
|
||||||
|
optimized_prompt=result.optimized_prompt,
|
||||||
|
initial_prompt=result.initial_prompt,
|
||||||
|
iterations_used=result.iterations_used,
|
||||||
|
total_llm_calls=actual_llm_calls,
|
||||||
|
initial_score=result.initial_score,
|
||||||
|
final_score=result.final_score,
|
||||||
|
improvement=result.improvement,
|
||||||
|
history=result.history,
|
||||||
|
final_validation_score=result.final_validation_score,
|
||||||
|
best_validation_score=result.best_validation_score,
|
||||||
|
early_stopped=result.early_stopped,
|
||||||
|
)
|
||||||
|
|
||||||
# 5. Display results
|
# 5. Display results
|
||||||
_display_result(result)
|
_display_result(result)
|
||||||
|
|
||||||
|
|||||||
18
src/prometheus/cli/commands/version.py
Normal file
18
src/prometheus/cli/commands/version.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""prometheus version — print the current version."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
from prometheus import __version__
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
def register(app: typer.Typer) -> None:
|
||||||
|
"""Register the version command on the Typer app."""
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def version() -> None:
|
||||||
|
"""Print the PROMETHEUS version."""
|
||||||
|
console.print(f"PROMETHEUS {__version__}")
|
||||||
96
src/prometheus/cli/logging_setup.py
Normal file
96
src/prometheus/cli/logging_setup.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""
|
||||||
|
Structured logging configuration for PROMETHEUS.
|
||||||
|
|
||||||
|
Supports text (human-readable) and JSON (machine-parseable) output,
|
||||||
|
configurable log levels, and optional file output.
|
||||||
|
|
||||||
|
Fixes Bug #4: verbose mode now reliably produces output by configuring
|
||||||
|
handlers explicitly instead of relying on ``logging.basicConfig``.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import TextIO
|
||||||
|
|
||||||
|
|
||||||
|
_PROMETHEUS_LOGGER = "prometheus"
|
||||||
|
|
||||||
|
|
||||||
|
class _JsonFormatter(logging.Formatter):
|
||||||
|
"""Emit one JSON object per log line."""
|
||||||
|
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
message = record.getMessage()
|
||||||
|
payload: dict = {
|
||||||
|
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
||||||
|
"level": record.levelname,
|
||||||
|
"logger": record.name,
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
# Merge any extra structured fields the caller attached.
|
||||||
|
if hasattr(record, "structured"):
|
||||||
|
payload["structured"] = record.structured # type: ignore[attr-defined]
|
||||||
|
if record.exc_info and record.exc_info[1] is not None:
|
||||||
|
payload["exception"] = self.formatException(record.exc_info)
|
||||||
|
return json.dumps(payload, default=str)
|
||||||
|
|
||||||
|
|
||||||
|
class _TextFormatter(logging.Formatter):
|
||||||
|
"""Human-readable format with structured extras appended."""
|
||||||
|
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
base = super().format(record)
|
||||||
|
if hasattr(record, "structured") and record.structured:
|
||||||
|
extras = " ".join(f"{k}={v}" for k, v in record.structured.items())
|
||||||
|
base = f"{base} {extras}"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging(
|
||||||
|
*,
|
||||||
|
level: int = logging.WARNING,
|
||||||
|
log_format: str = "text",
|
||||||
|
log_file: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Configure the prometheus root logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Logging level (e.g. logging.DEBUG, logging.INFO).
|
||||||
|
log_format: ``"text"`` for human-readable or ``"json"`` for
|
||||||
|
machine-parseable output.
|
||||||
|
log_file: Optional path to also write logs to a file.
|
||||||
|
"""
|
||||||
|
prom_logger = logging.getLogger(_PROMETHEUS_LOGGER)
|
||||||
|
prom_logger.setLevel(level)
|
||||||
|
# Remove any stale handlers so re-configuration is idempotent.
|
||||||
|
prom_logger.handlers.clear()
|
||||||
|
|
||||||
|
if log_format == "json":
|
||||||
|
fmt = _JsonFormatter()
|
||||||
|
else:
|
||||||
|
fmt = _TextFormatter(
|
||||||
|
fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||||
|
datefmt="%H:%M:%S",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Console handler (stderr so it doesn't mix with Rich stdout)
|
||||||
|
console_handler = logging.StreamHandler(sys.stderr)
|
||||||
|
console_handler.setFormatter(fmt)
|
||||||
|
prom_logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
# Optional file handler
|
||||||
|
if log_file:
|
||||||
|
file_handler = logging.FileHandler(log_file)
|
||||||
|
file_handler.setFormatter(fmt)
|
||||||
|
prom_logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
# Prevent propagation to root logger to avoid duplicate output
|
||||||
|
prom_logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str) -> logging.Logger:
|
||||||
|
"""Return a child logger under the prometheus namespace."""
|
||||||
|
return logging.getLogger(f"{_PROMETHEUS_LOGGER}.{name}")
|
||||||
@@ -31,6 +31,15 @@ class SyntheticExample:
|
|||||||
id: int = 0
|
id: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GroundTruthExample:
|
||||||
|
"""A ground-truth evaluation example with a known-good expected output."""
|
||||||
|
|
||||||
|
input_text: str
|
||||||
|
expected_output: str
|
||||||
|
id: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Trajectory:
|
class Trajectory:
|
||||||
"""Execution trace of a prompt on an input.
|
"""Execution trace of a prompt on an input.
|
||||||
@@ -85,3 +94,6 @@ class OptimizationState:
|
|||||||
synthetic_pool: list[SyntheticExample] = field(default_factory=list)
|
synthetic_pool: list[SyntheticExample] = field(default_factory=list)
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
total_llm_calls: int = 0
|
total_llm_calls: int = 0
|
||||||
|
# Hold-out validation
|
||||||
|
best_validation_score: float | None = None
|
||||||
|
early_stopped: bool = False
|
||||||
|
|||||||
@@ -8,7 +8,14 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory
|
from prometheus.domain.entities import (
|
||||||
|
Candidate,
|
||||||
|
GroundTruthExample,
|
||||||
|
OptimizationState,
|
||||||
|
Prompt,
|
||||||
|
SyntheticExample,
|
||||||
|
Trajectory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMPort(ABC):
|
class LLMPort(ABC):
|
||||||
@@ -73,6 +80,34 @@ class SyntheticGeneratorPort(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class CrossoverPort(ABC):
|
||||||
|
"""Port for crossover — combining instructions from two parent candidates."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def crossover(
|
||||||
|
self,
|
||||||
|
parent_a: Prompt,
|
||||||
|
parent_b: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
) -> Prompt:
|
||||||
|
"""Combine instructions from two parents into a child prompt."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class MutationPort(ABC):
|
||||||
|
"""Port for mutating a prompt — paraphrase, constrain, generalize, specialize."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def mutate(
|
||||||
|
self,
|
||||||
|
prompt: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
mutation_type: str = "paraphrase",
|
||||||
|
) -> Prompt:
|
||||||
|
"""Apply a mutation to the prompt."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class PersistencePort(ABC):
|
class PersistencePort(ABC):
|
||||||
"""Port for reading/writing files."""
|
"""Port for reading/writing files."""
|
||||||
|
|
||||||
@@ -83,3 +118,49 @@ class PersistencePort(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def write_result(self, path: str, data: dict[str, Any]) -> None:
|
def write_result(self, path: str, data: dict[str, Any]) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class SimilarityPort(ABC):
|
||||||
|
"""Port for computing similarity between a prediction and expected output.
|
||||||
|
|
||||||
|
Infrastructure provides concrete metrics (exact match, BLEU, ROUGE, cosine).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute(self, prediction: str, expected: str) -> float:
|
||||||
|
"""Compute similarity score in [0, 1]. 1.0 = perfect match."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetLoaderPort(ABC):
|
||||||
|
"""Port for loading ground-truth evaluation datasets."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self, path: str) -> list[GroundTruthExample]:
|
||||||
|
"""Load a dataset from a CSV or JSON file.
|
||||||
|
|
||||||
|
Each row must have 'input' and 'expected_output' fields.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointPort(ABC):
|
||||||
|
"""Port for saving and loading optimization checkpoints.
|
||||||
|
|
||||||
|
Enables resuming long-running optimizations after interruption.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(self, state: OptimizationState) -> None:
|
||||||
|
"""Persist the current optimization state to disk."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self) -> OptimizationState | None:
|
||||||
|
"""Load the latest checkpoint. Returns None if no checkpoint exists."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def latest_exists(self) -> bool:
|
||||||
|
"""Check if a checkpoint file is available for resuming."""
|
||||||
|
...
|
||||||
|
|||||||
149
src/prometheus/infrastructure/checkpoint.py
Normal file
149
src/prometheus/infrastructure/checkpoint.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""
|
||||||
|
JSON checkpoint persistence — save/load optimization state to disk.
|
||||||
|
|
||||||
|
Implements the CheckpointPort with JSON for human-readable, versionable snapshots.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import asdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from prometheus.cli.logging_setup import get_logger
|
||||||
|
from prometheus.domain.entities import (
|
||||||
|
Candidate,
|
||||||
|
OptimizationState,
|
||||||
|
Prompt,
|
||||||
|
SyntheticExample,
|
||||||
|
)
|
||||||
|
from prometheus.domain.ports import CheckpointPort
|
||||||
|
|
||||||
|
logger = get_logger("checkpoint")
|
||||||
|
|
||||||
|
_CHECKPOINT_FILE = "latest.json"
|
||||||
|
|
||||||
|
|
||||||
|
class JsonCheckpointPersistence(CheckpointPort):
|
||||||
|
"""Saves optimization state as JSON to a configurable directory."""
|
||||||
|
|
||||||
|
def __init__(self, checkpoint_dir: str | Path = ".prometheus/checkpoints") -> None:
|
||||||
|
self._dir = Path(checkpoint_dir)
|
||||||
|
|
||||||
|
def save(self, state: OptimizationState) -> None:
|
||||||
|
"""Persist the current optimization state to disk."""
|
||||||
|
self._dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
path = self._dir / _CHECKPOINT_FILE
|
||||||
|
data = _serialize_state(state)
|
||||||
|
path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||||
|
logger.info(
|
||||||
|
"Checkpoint saved",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "checkpoint_saved",
|
||||||
|
"path": str(path),
|
||||||
|
"iteration": state.iteration,
|
||||||
|
"total_llm_calls": state.total_llm_calls,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def load(self) -> OptimizationState | None:
|
||||||
|
"""Load the latest checkpoint. Returns None if no checkpoint exists."""
|
||||||
|
path = self._dir / _CHECKPOINT_FILE
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
raw = json.loads(path.read_text(encoding="utf-8"))
|
||||||
|
state = _deserialize_state(raw)
|
||||||
|
logger.info(
|
||||||
|
"Checkpoint loaded",
|
||||||
|
extra={
|
||||||
|
"structured": {
|
||||||
|
"event": "checkpoint_loaded",
|
||||||
|
"path": str(path),
|
||||||
|
"iteration": state.iteration,
|
||||||
|
"total_llm_calls": state.total_llm_calls,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
def latest_exists(self) -> bool:
|
||||||
|
"""Check if a checkpoint file is available for resuming."""
|
||||||
|
return (self._dir / _CHECKPOINT_FILE).exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Serialization helpers — keep the JSON format stable and self-describing.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_SCHEMA_VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_state(state: OptimizationState) -> dict:
|
||||||
|
"""Convert OptimizationState to a JSON-safe dict."""
|
||||||
|
return {
|
||||||
|
"schema_version": _SCHEMA_VERSION,
|
||||||
|
"iteration": state.iteration,
|
||||||
|
"best_candidate": _serialize_candidate(state.best_candidate),
|
||||||
|
"candidates": [_serialize_candidate(c) for c in state.candidates],
|
||||||
|
"synthetic_pool": [
|
||||||
|
{"input_text": ex.input_text, "category": ex.category, "id": ex.id}
|
||||||
|
for ex in state.synthetic_pool
|
||||||
|
],
|
||||||
|
"history": state.history,
|
||||||
|
"total_llm_calls": state.total_llm_calls,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_candidate(candidate: Candidate | None) -> dict | None:
|
||||||
|
if candidate is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"prompt_text": candidate.prompt.text,
|
||||||
|
"prompt_metadata": candidate.prompt.metadata,
|
||||||
|
"best_score": candidate.best_score,
|
||||||
|
"generation": candidate.generation,
|
||||||
|
"parent_id": candidate.parent_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _deserialize_state(data: dict) -> OptimizationState:
|
||||||
|
"""Reconstruct OptimizationState from a checkpoint dict."""
|
||||||
|
version = data.get("schema_version", 1)
|
||||||
|
# Future migration hooks go here: if version < 2: ...
|
||||||
|
|
||||||
|
best_raw = data.get("best_candidate")
|
||||||
|
best_candidate = _deserialize_candidate(best_raw)
|
||||||
|
|
||||||
|
candidates = [_deserialize_candidate(c) for c in data.get("candidates", [])]
|
||||||
|
|
||||||
|
synthetic_pool = [
|
||||||
|
SyntheticExample(
|
||||||
|
input_text=ex["input_text"],
|
||||||
|
category=ex.get("category", "default"),
|
||||||
|
id=ex.get("id", 0),
|
||||||
|
)
|
||||||
|
for ex in data.get("synthetic_pool", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
state = OptimizationState(
|
||||||
|
iteration=data.get("iteration", 0),
|
||||||
|
best_candidate=best_candidate,
|
||||||
|
candidates=candidates,
|
||||||
|
synthetic_pool=synthetic_pool,
|
||||||
|
history=data.get("history", []),
|
||||||
|
total_llm_calls=data.get("total_llm_calls", 0),
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def _deserialize_candidate(raw: dict | None) -> Candidate | None:
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
return Candidate(
|
||||||
|
prompt=Prompt(text=raw["prompt_text"], metadata=raw.get("prompt_metadata", {})),
|
||||||
|
best_score=raw.get("best_score", 0.0),
|
||||||
|
generation=raw.get("generation", 0),
|
||||||
|
parent_id=raw.get("parent_id"),
|
||||||
|
)
|
||||||
63
src/prometheus/infrastructure/crossover_adapter.py
Normal file
63
src/prometheus/infrastructure/crossover_adapter.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
Adapter: Instruction Crossover via DSPy.
|
||||||
|
|
||||||
|
Implements CrossoverPort — combines two parent prompts into a child.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
|
||||||
|
from prometheus.domain.entities import Prompt
|
||||||
|
from prometheus.domain.ports import CrossoverPort
|
||||||
|
from prometheus.infrastructure.dspy_modules import InstructionCrossover
|
||||||
|
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||||
|
|
||||||
|
|
||||||
|
class DSPyCrossoverAdapter(CrossoverPort):
|
||||||
|
"""Uses DSPy to combine two parent instructions into a child."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lm: dspy.LM,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay_base: float = 1.0,
|
||||||
|
) -> None:
|
||||||
|
self._lm = lm
|
||||||
|
self._crossover = InstructionCrossover()
|
||||||
|
self._max_retries = max_retries
|
||||||
|
self._retry_delay_base = retry_delay_base
|
||||||
|
self.call_count: int = 0
|
||||||
|
|
||||||
|
async def crossover(
|
||||||
|
self,
|
||||||
|
parent_a: Prompt,
|
||||||
|
parent_b: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
) -> Prompt:
|
||||||
|
async def _call() -> Prompt:
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
self._sync_crossover, parent_a, parent_b, task_description,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await async_retry_with_backoff(
|
||||||
|
_call,
|
||||||
|
max_retries=self._max_retries,
|
||||||
|
retry_delay_base=self._retry_delay_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sync_crossover(
|
||||||
|
self,
|
||||||
|
parent_a: Prompt,
|
||||||
|
parent_b: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
) -> Prompt:
|
||||||
|
with dspy.context(lm=self._lm):
|
||||||
|
pred = self._crossover(
|
||||||
|
parent_a=parent_a.text,
|
||||||
|
parent_b=parent_b.text,
|
||||||
|
task_description=task_description,
|
||||||
|
)
|
||||||
|
self.call_count += 1
|
||||||
|
return Prompt(text=pred.child_instruction)
|
||||||
75
src/prometheus/infrastructure/dataset_loader.py
Normal file
75
src/prometheus/infrastructure/dataset_loader.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""Dataset loader — loads ground-truth CSV/JSON datasets."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from prometheus.domain.entities import GroundTruthExample
|
||||||
|
from prometheus.domain.ports import DatasetLoaderPort
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FileDatasetLoader(DatasetLoaderPort):
|
||||||
|
"""Loads evaluation datasets from CSV or JSON files.
|
||||||
|
|
||||||
|
CSV files must have 'input' and 'expected_output' columns.
|
||||||
|
JSON files must be an array of objects with 'input' and 'expected_output' keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def load(self, path: str) -> list[GroundTruthExample]:
|
||||||
|
suffix = Path(path).suffix.lower()
|
||||||
|
if suffix == ".csv":
|
||||||
|
return self._load_csv(path)
|
||||||
|
elif suffix in (".json", ".jsonl"):
|
||||||
|
return self._load_json(path)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported dataset format '{suffix}'. Use .csv, .json, or .jsonl."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_csv(self, path: str) -> list[GroundTruthExample]:
|
||||||
|
examples: list[GroundTruthExample] = []
|
||||||
|
with open(path, newline="", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for i, row in enumerate(reader):
|
||||||
|
input_text = row.get("input", "").strip()
|
||||||
|
expected = row.get("expected_output", "").strip()
|
||||||
|
if not input_text:
|
||||||
|
logger.warning("Skipping CSV row %d: empty 'input' field", i + 1)
|
||||||
|
continue
|
||||||
|
examples.append(
|
||||||
|
GroundTruthExample(
|
||||||
|
input_text=input_text,
|
||||||
|
expected_output=expected,
|
||||||
|
id=i,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info("Loaded %d examples from CSV: %s", len(examples), path)
|
||||||
|
return examples
|
||||||
|
|
||||||
|
def _load_json(self, path: str) -> list[GroundTruthExample]:
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
if not isinstance(data, list):
|
||||||
|
raise ValueError("JSON dataset must be an array of objects.")
|
||||||
|
examples: list[GroundTruthExample] = []
|
||||||
|
for i, item in enumerate(data):
|
||||||
|
input_text = item.get("input", "").strip() if isinstance(item, dict) else ""
|
||||||
|
expected = (
|
||||||
|
item.get("expected_output", "").strip() if isinstance(item, dict) else ""
|
||||||
|
)
|
||||||
|
if not input_text:
|
||||||
|
logger.warning("Skipping JSON item %d: empty 'input' field", i)
|
||||||
|
continue
|
||||||
|
examples.append(
|
||||||
|
GroundTruthExample(
|
||||||
|
input_text=input_text,
|
||||||
|
expected_output=expected,
|
||||||
|
id=i,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info("Loaded %d examples from JSON: %s", len(examples), path)
|
||||||
|
return examples
|
||||||
@@ -59,6 +59,7 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
if self._judge_dimensions
|
if self._judge_dimensions
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
self.call_count: int = 0
|
||||||
|
|
||||||
async def judge_batch(
|
async def judge_batch(
|
||||||
self,
|
self,
|
||||||
@@ -104,13 +105,15 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
|
|
||||||
def _sync_judge(self, task_description: str, input_text: str, output_text: str):
|
def _sync_judge(self, task_description: str, input_text: str, output_text: str):
|
||||||
with dspy.context(lm=self._lm):
|
with dspy.context(lm=self._lm):
|
||||||
return self._judge(
|
result = self._judge(
|
||||||
task_description=task_description,
|
task_description=task_description,
|
||||||
input_text=input_text,
|
input_text=input_text,
|
||||||
output_text=output_text,
|
output_text=output_text,
|
||||||
judge_criteria=self._judge_criteria,
|
judge_criteria=self._judge_criteria,
|
||||||
dimension_names=self._dimension_names,
|
dimension_names=self._dimension_names,
|
||||||
)
|
)
|
||||||
|
self.call_count += 1
|
||||||
|
return result
|
||||||
|
|
||||||
def _aggregate_result(self, pred: Any) -> tuple[float, str]:
|
def _aggregate_result(self, pred: Any) -> tuple[float, str]:
|
||||||
"""Compute weighted aggregate score from dimension scores if available."""
|
"""Compute weighted aggregate score from dimension scores if available."""
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class DSPyLLMAdapter(LLMPort):
|
|||||||
self._predictor = dspy.Predict(self._ExecuteSignature)
|
self._predictor = dspy.Predict(self._ExecuteSignature)
|
||||||
self._max_retries = max_retries
|
self._max_retries = max_retries
|
||||||
self._retry_delay_base = retry_delay_base
|
self._retry_delay_base = retry_delay_base
|
||||||
|
self.call_count: int = 0
|
||||||
|
|
||||||
async def execute(self, prompt: Prompt, input_text: str) -> str:
|
async def execute(self, prompt: Prompt, input_text: str) -> str:
|
||||||
async def _call() -> str:
|
async def _call() -> str:
|
||||||
@@ -52,4 +53,5 @@ class DSPyLLMAdapter(LLMPort):
|
|||||||
instruction=prompt.text,
|
instruction=prompt.text,
|
||||||
input_text=input_text,
|
input_text=input_text,
|
||||||
)
|
)
|
||||||
|
self.call_count += 1
|
||||||
return str(result.output)
|
return str(result.output)
|
||||||
|
|||||||
70
src/prometheus/infrastructure/mutation_adapter.py
Normal file
70
src/prometheus/infrastructure/mutation_adapter.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""
|
||||||
|
Adapter: Instruction Mutation via DSPy.
|
||||||
|
|
||||||
|
Implements MutationPort — applies typed mutations (paraphrase, constrain,
|
||||||
|
generalize, specialize) to a prompt.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
|
||||||
|
from prometheus.domain.entities import Prompt
|
||||||
|
from prometheus.domain.ports import MutationPort
|
||||||
|
from prometheus.infrastructure.dspy_modules import InstructionMutator
|
||||||
|
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||||
|
|
||||||
|
_MUTATION_TYPES = ("paraphrase", "constrain", "generalize", "specialize")
|
||||||
|
|
||||||
|
|
||||||
|
class DSPyMutationAdapter(MutationPort):
|
||||||
|
"""Uses DSPy to apply typed mutations to an instruction."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lm: dspy.LM,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay_base: float = 1.0,
|
||||||
|
) -> None:
|
||||||
|
self._lm = lm
|
||||||
|
self._mutator = InstructionMutator()
|
||||||
|
self._max_retries = max_retries
|
||||||
|
self._retry_delay_base = retry_delay_base
|
||||||
|
self.call_count: int = 0
|
||||||
|
|
||||||
|
async def mutate(
|
||||||
|
self,
|
||||||
|
prompt: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
mutation_type: str = "paraphrase",
|
||||||
|
) -> Prompt:
|
||||||
|
if mutation_type not in _MUTATION_TYPES:
|
||||||
|
mutation_type = random.choice(_MUTATION_TYPES)
|
||||||
|
|
||||||
|
async def _call() -> Prompt:
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
self._sync_mutate, prompt, task_description, mutation_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await async_retry_with_backoff(
|
||||||
|
_call,
|
||||||
|
max_retries=self._max_retries,
|
||||||
|
retry_delay_base=self._retry_delay_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sync_mutate(
|
||||||
|
self,
|
||||||
|
prompt: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
mutation_type: str,
|
||||||
|
) -> Prompt:
|
||||||
|
with dspy.context(lm=self._lm):
|
||||||
|
pred = self._mutator(
|
||||||
|
current_instruction=prompt.text,
|
||||||
|
task_description=task_description,
|
||||||
|
mutation_type=mutation_type,
|
||||||
|
)
|
||||||
|
self.call_count += 1
|
||||||
|
return Prompt(text=pred.mutated_instruction)
|
||||||
@@ -29,6 +29,7 @@ class DSPyProposerAdapter(ProposerPort):
|
|||||||
self._proposer = InstructionProposer()
|
self._proposer = InstructionProposer()
|
||||||
self._max_retries = max_retries
|
self._max_retries = max_retries
|
||||||
self._retry_delay_base = retry_delay_base
|
self._retry_delay_base = retry_delay_base
|
||||||
|
self.call_count: int = 0
|
||||||
|
|
||||||
async def propose(
|
async def propose(
|
||||||
self,
|
self,
|
||||||
@@ -56,6 +57,7 @@ class DSPyProposerAdapter(ProposerPort):
|
|||||||
task_description=task_description,
|
task_description=task_description,
|
||||||
failure_examples=failure_examples,
|
failure_examples=failure_examples,
|
||||||
)
|
)
|
||||||
|
self.call_count += 1
|
||||||
return Prompt(text=pred.new_instruction)
|
return Prompt(text=pred.new_instruction)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
153
src/prometheus/infrastructure/similarity.py
Normal file
153
src/prometheus/infrastructure/similarity.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""Similarity adapters — concrete metrics for comparing prediction vs expected."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
from prometheus.domain.ports import SimilarityPort
|
||||||
|
|
||||||
|
|
||||||
|
class ExactMatchSimilarity(SimilarityPort):
|
||||||
|
"""Case-insensitive exact string match. Returns 1.0 or 0.0."""
|
||||||
|
|
||||||
|
def compute(self, prediction: str, expected: str) -> float:
|
||||||
|
return 1.0 if prediction.strip().lower() == expected.strip().lower() else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class BleuSimilarity(SimilarityPort):
|
||||||
|
"""BLEU-style n-gram precision (up to 4-grams).
|
||||||
|
|
||||||
|
Simplified implementation using sentence-level BLEU with brevity penalty.
|
||||||
|
Returns a score in [0, 1].
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_n: int = 4):
|
||||||
|
self._max_n = max_n
|
||||||
|
|
||||||
|
def compute(self, prediction: str, expected: str) -> float:
|
||||||
|
pred_tokens = _tokenize(prediction)
|
||||||
|
ref_tokens = _tokenize(expected)
|
||||||
|
if not pred_tokens or not ref_tokens:
|
||||||
|
return 0.0 if not ref_tokens else 0.0
|
||||||
|
|
||||||
|
# Modified precision for each n-gram
|
||||||
|
precisions: list[float] = []
|
||||||
|
for n in range(1, self._max_n + 1):
|
||||||
|
pred_ngrams = _ngrams(pred_tokens, n)
|
||||||
|
ref_ngrams = _ngrams(ref_tokens, n)
|
||||||
|
if not pred_ngrams:
|
||||||
|
break
|
||||||
|
clipped = sum(min(pred_ngrams[ng], ref_ngrams.get(ng, 0)) for ng in pred_ngrams)
|
||||||
|
total = sum(pred_ngrams.values())
|
||||||
|
precisions.append(clipped / total if total > 0 else 0.0)
|
||||||
|
|
||||||
|
if not precisions:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Geometric mean of precisions
|
||||||
|
log_avg = sum(math.log(p) for p in precisions if p > 0)
|
||||||
|
n_nonzero = sum(1 for p in precisions if p > 0)
|
||||||
|
if n_nonzero == 0:
|
||||||
|
return 0.0
|
||||||
|
geo_mean = math.exp(log_avg / n_nonzero)
|
||||||
|
|
||||||
|
# Brevity penalty
|
||||||
|
bp = 1.0
|
||||||
|
if len(pred_tokens) < len(ref_tokens):
|
||||||
|
bp = math.exp(1 - len(ref_tokens) / len(pred_tokens))
|
||||||
|
|
||||||
|
return min(bp * geo_mean, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class RougeLSimilarity(SimilarityPort):
|
||||||
|
"""ROUGE-L using Longest Common Subsequence.
|
||||||
|
|
||||||
|
Returns F1 score combining precision and recall in [0, 1].
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute(self, prediction: str, expected: str) -> float:
|
||||||
|
pred_tokens = _tokenize(prediction)
|
||||||
|
ref_tokens = _tokenize(expected)
|
||||||
|
if not pred_tokens or not ref_tokens:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
lcs_len = _lcs_length(pred_tokens, ref_tokens)
|
||||||
|
precision = lcs_len / len(pred_tokens)
|
||||||
|
recall = lcs_len / len(ref_tokens)
|
||||||
|
if precision + recall == 0:
|
||||||
|
return 0.0
|
||||||
|
f1 = 2 * precision * recall / (precision + recall)
|
||||||
|
return f1
|
||||||
|
|
||||||
|
|
||||||
|
class CosineSimilarity(SimilarityPort):
|
||||||
|
"""TF-IDF cosine similarity between bag-of-words vectors.
|
||||||
|
|
||||||
|
Lightweight semantic similarity without external embedding models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute(self, prediction: str, expected: str) -> float:
|
||||||
|
pred_tokens = _tokenize(prediction)
|
||||||
|
ref_tokens = _tokenize(expected)
|
||||||
|
if not pred_tokens or not ref_tokens:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
pred_counts = Counter(pred_tokens)
|
||||||
|
ref_counts = Counter(ref_tokens)
|
||||||
|
all_tokens = set(pred_counts) | set(ref_counts)
|
||||||
|
|
||||||
|
dot = sum(pred_counts.get(t, 0) * ref_counts.get(t, 0) for t in all_tokens)
|
||||||
|
norm_pred = math.sqrt(sum(v * v for v in pred_counts.values()))
|
||||||
|
norm_ref = math.sqrt(sum(v * v for v in ref_counts.values()))
|
||||||
|
|
||||||
|
if norm_pred == 0 or norm_ref == 0:
|
||||||
|
return 0.0
|
||||||
|
return dot / (norm_pred * norm_ref)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helpers ---
|
||||||
|
|
||||||
|
|
||||||
|
def _tokenize(text: str) -> list[str]:
|
||||||
|
"""Simple whitespace + punctuation tokenizer."""
|
||||||
|
return re.findall(r"\w+", text.lower())
|
||||||
|
|
||||||
|
|
||||||
|
def _ngrams(tokens: list[str], n: int) -> Counter:
|
||||||
|
"""Count n-grams in a token list."""
|
||||||
|
return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))
|
||||||
|
|
||||||
|
|
||||||
|
def _lcs_length(a: list[str], b: list[str]) -> int:
|
||||||
|
"""Compute length of the Longest Common Subsequence."""
|
||||||
|
m, n = len(a), len(b)
|
||||||
|
prev = [0] * (n + 1)
|
||||||
|
for i in range(1, m + 1):
|
||||||
|
curr = [0] * (n + 1)
|
||||||
|
for j in range(1, n + 1):
|
||||||
|
if a[i - 1] == b[j - 1]:
|
||||||
|
curr[j] = prev[j - 1] + 1
|
||||||
|
else:
|
||||||
|
curr[j] = max(prev[j], curr[j - 1])
|
||||||
|
prev = curr
|
||||||
|
return prev[n]
|
||||||
|
|
||||||
|
|
||||||
|
def create_similarity_adapter(metric: str) -> SimilarityPort:
|
||||||
|
"""Factory: create a SimilarityPort by metric name.
|
||||||
|
|
||||||
|
Supported metrics: exact, bleu, rouge_l, cosine.
|
||||||
|
"""
|
||||||
|
adapters = {
|
||||||
|
"exact": ExactMatchSimilarity,
|
||||||
|
"bleu": BleuSimilarity,
|
||||||
|
"rouge_l": RougeLSimilarity,
|
||||||
|
"cosine": CosineSimilarity,
|
||||||
|
}
|
||||||
|
cls = adapters.get(metric)
|
||||||
|
if cls is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown eval metric '{metric}'. Choose from: {sorted(adapters)}"
|
||||||
|
)
|
||||||
|
return cls()
|
||||||
@@ -18,6 +18,7 @@ class DSPySyntheticAdapter(SyntheticGeneratorPort):
|
|||||||
def __init__(self, lm: dspy.LM) -> None:
|
def __init__(self, lm: dspy.LM) -> None:
|
||||||
self._lm = lm
|
self._lm = lm
|
||||||
self._generator = SyntheticInputGenerator()
|
self._generator = SyntheticInputGenerator()
|
||||||
|
self.call_count: int = 0
|
||||||
|
|
||||||
def generate_inputs(
|
def generate_inputs(
|
||||||
self,
|
self,
|
||||||
@@ -29,6 +30,7 @@ class DSPySyntheticAdapter(SyntheticGeneratorPort):
|
|||||||
task_description=task_description,
|
task_description=task_description,
|
||||||
n_examples=n_examples,
|
n_examples=n_examples,
|
||||||
)
|
)
|
||||||
|
self.call_count += 1
|
||||||
return [
|
return [
|
||||||
SyntheticExample(
|
SyntheticExample(
|
||||||
input_text=text,
|
input_text=text,
|
||||||
|
|||||||
@@ -91,3 +91,27 @@ def mock_proposer_port() -> AsyncMock:
|
|||||||
text="You are a very helpful assistant. Answer the question precisely."
|
text="You are a very helpful assistant. Answer the question precisely."
|
||||||
)
|
)
|
||||||
return port
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_crossover_port() -> AsyncMock:
|
||||||
|
"""Mock CrossoverPort that combines two parent prompts."""
|
||||||
|
port = AsyncMock()
|
||||||
|
|
||||||
|
async def _crossover(parent_a: Prompt, parent_b: Prompt, task_description: str) -> Prompt:
|
||||||
|
return Prompt(text=f"{parent_a.text} Also, {parent_b.text.lower()}")
|
||||||
|
|
||||||
|
port.crossover = AsyncMock(side_effect=_crossover)
|
||||||
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_mutation_port() -> AsyncMock:
|
||||||
|
"""Mock MutationPort that paraphrases a prompt."""
|
||||||
|
port = AsyncMock()
|
||||||
|
|
||||||
|
async def _mutate(prompt: Prompt, task_description: str, mutation_type: str = "paraphrase") -> Prompt:
|
||||||
|
return Prompt(text=f"[{mutation_type}] {prompt.text}")
|
||||||
|
|
||||||
|
port.mutate = AsyncMock(side_effect=_mutate)
|
||||||
|
return port
|
||||||
|
|||||||
@@ -20,9 +20,10 @@ def mock_lm() -> dspy.LM:
|
|||||||
|
|
||||||
|
|
||||||
class TestDSPyLLMAdapter:
|
class TestDSPyLLMAdapter:
|
||||||
def test_execute_returns_response(self, mock_lm: dspy.LM) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_returns_response(self, mock_lm: dspy.LM) -> None:
|
||||||
adapter = DSPyLLMAdapter(lm=mock_lm)
|
adapter = DSPyLLMAdapter(lm=mock_lm)
|
||||||
prompt = Prompt(text="Answer the question.")
|
prompt = Prompt(text="Answer the question.")
|
||||||
result = adapter.execute(prompt, "What is 2+2?")
|
result = await adapter.execute(prompt, "What is 2+2?")
|
||||||
assert isinstance(result, str)
|
assert isinstance(result, str)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
|||||||
300
tests/integration/test_evolution_integration.py
Normal file
300
tests/integration/test_evolution_integration.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""Integration tests for multi-iteration evolution with mixed accept/reject."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||||
|
from prometheus.application.evaluator import PromptEvaluator
|
||||||
|
from prometheus.application.evolution import EvolutionLoop
|
||||||
|
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory
|
||||||
|
from prometheus.domain.ports import JudgePort, LLMPort, ProposerPort
|
||||||
|
|
||||||
|
|
||||||
|
def _make_eval(scores: list[float]) -> EvalResult:
|
||||||
|
return EvalResult(
|
||||||
|
scores=scores,
|
||||||
|
feedbacks=["feedback"] * len(scores),
|
||||||
|
trajectories=[
|
||||||
|
Trajectory(f"in{i}", f"out{i}", s, "feedback", "prompt")
|
||||||
|
for i, s in enumerate(scores)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiIterationEvolution:
|
||||||
|
"""Tests for the evolution loop across multiple iterations."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def seed_prompt(self) -> Prompt:
|
||||||
|
return Prompt(text="You are a helpful assistant.")
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def task_description(self) -> str:
|
||||||
|
return "Answer factual questions."
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def synthetic_pool(self) -> list[SyntheticExample]:
|
||||||
|
return [SyntheticExample(input_text=f"input {i}", id=i) for i in range(20)]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mixed_accept_reject(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
) -> None:
|
||||||
|
"""Iteration 1: accept, iteration 2: reject, iteration 3: accept."""
|
||||||
|
mock_llm = MagicMock(spec=LLMPort)
|
||||||
|
mock_judge = MagicMock(spec=JudgePort)
|
||||||
|
mock_proposer = MagicMock(spec=ProposerPort)
|
||||||
|
|
||||||
|
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:3]
|
||||||
|
|
||||||
|
# Build eval sequence: initial, then per-iteration (current, new)
|
||||||
|
evals = [
|
||||||
|
_make_eval([0.3, 0.3, 0.3]), # initial seed eval
|
||||||
|
# Iter 1: accept (old=0.4, new=0.8)
|
||||||
|
_make_eval([0.4, 0.4, 0.4]),
|
||||||
|
_make_eval([0.8, 0.8, 0.8]),
|
||||||
|
# Iter 2: reject (old=0.7, new=0.2)
|
||||||
|
_make_eval([0.7, 0.7, 0.7]),
|
||||||
|
_make_eval([0.2, 0.2, 0.2]),
|
||||||
|
# Iter 3: accept (old=0.5, new=0.9)
|
||||||
|
_make_eval([0.5, 0.5, 0.5]),
|
||||||
|
_make_eval([0.9, 0.9, 0.9]),
|
||||||
|
]
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||||
|
|
||||||
|
mock_proposer.propose.side_effect = [
|
||||||
|
Prompt(text="Better prompt v1"),
|
||||||
|
Prompt(text="Worse prompt v2"),
|
||||||
|
Prompt(text="Best prompt v3"),
|
||||||
|
]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=3,
|
||||||
|
minibatch_size=3,
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
assert state.iteration == 3
|
||||||
|
assert state.best_candidate is not None
|
||||||
|
assert state.best_candidate.best_score == pytest.approx(2.7) # 0.9*3
|
||||||
|
assert len(state.history) == 3
|
||||||
|
assert state.history[0]["event"] == "accepted"
|
||||||
|
assert state.history[1]["event"] == "rejected"
|
||||||
|
assert state.history[2]["event"] == "accepted"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_all_rejected_keeps_seed(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
) -> None:
|
||||||
|
"""When all proposals are rejected, the seed prompt stays as best."""
|
||||||
|
mock_llm = MagicMock(spec=LLMPort)
|
||||||
|
mock_judge = MagicMock(spec=JudgePort)
|
||||||
|
mock_proposer = MagicMock(spec=ProposerPort)
|
||||||
|
|
||||||
|
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:3]
|
||||||
|
|
||||||
|
evals = [
|
||||||
|
_make_eval([0.5, 0.5, 0.5]), # initial
|
||||||
|
]
|
||||||
|
for _ in range(3):
|
||||||
|
evals.append(_make_eval([0.5, 0.5, 0.5])) # current
|
||||||
|
evals.append(_make_eval([0.1, 0.1, 0.1])) # worse proposal
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||||
|
|
||||||
|
mock_proposer.propose.side_effect = [
|
||||||
|
Prompt(text="bad v1"),
|
||||||
|
Prompt(text="bad v2"),
|
||||||
|
Prompt(text="bad v3"),
|
||||||
|
]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=3,
|
||||||
|
minibatch_size=3,
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
assert state.best_candidate.prompt.text == seed_prompt.text
|
||||||
|
assert state.best_candidate.best_score == pytest.approx(1.5) # 0.5*3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_all_accepted_chain(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
) -> None:
|
||||||
|
"""All iterations accept, forming an improvement chain."""
|
||||||
|
mock_llm = MagicMock(spec=LLMPort)
|
||||||
|
mock_judge = MagicMock(spec=JudgePort)
|
||||||
|
mock_proposer = MagicMock(spec=ProposerPort)
|
||||||
|
|
||||||
|
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:2]
|
||||||
|
|
||||||
|
evals = [
|
||||||
|
_make_eval([0.2, 0.2]), # initial
|
||||||
|
]
|
||||||
|
for i in range(1, 5):
|
||||||
|
score = 0.2 + i * 0.15
|
||||||
|
evals.append(_make_eval([score, score])) # current
|
||||||
|
evals.append(_make_eval([score + 0.1, score + 0.1])) # new (accepted)
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||||
|
|
||||||
|
mock_proposer.propose.side_effect = [
|
||||||
|
Prompt(text=f"Improved v{i}") for i in range(4)
|
||||||
|
]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=4,
|
||||||
|
minibatch_size=2,
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
assert len(state.candidates) == 5 # seed + 4 accepted
|
||||||
|
assert all(h["event"] == "accepted" for h in state.history)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_recovery_continues_loop(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
) -> None:
|
||||||
|
"""When an iteration errors, the loop continues."""
|
||||||
|
mock_llm = MagicMock(spec=LLMPort)
|
||||||
|
mock_judge = MagicMock(spec=JudgePort)
|
||||||
|
mock_proposer = MagicMock(spec=ProposerPort)
|
||||||
|
|
||||||
|
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:2]
|
||||||
|
|
||||||
|
# Eval sequence for 3 iterations:
|
||||||
|
# - iter 1: evaluate current → propose → evaluate new (accepted)
|
||||||
|
# - iter 2: evaluate current → propose (ERROR, no new eval)
|
||||||
|
# - iter 3: evaluate current → propose → evaluate new (accepted)
|
||||||
|
evals = [
|
||||||
|
_make_eval([0.3, 0.3]), # initial
|
||||||
|
_make_eval([0.5, 0.5]), # iter 1 current
|
||||||
|
_make_eval([0.9, 0.9]), # iter 1 new (accepted)
|
||||||
|
_make_eval([0.5, 0.5]), # iter 2 current (proposer errors after this)
|
||||||
|
_make_eval([0.5, 0.5]), # iter 3 current
|
||||||
|
_make_eval([0.8, 0.8]), # iter 3 new (accepted)
|
||||||
|
]
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||||
|
|
||||||
|
# Proposer raises on iter 2
|
||||||
|
mock_proposer.propose.side_effect = [
|
||||||
|
Prompt(text="good v1"),
|
||||||
|
RuntimeError("LLM timeout"),
|
||||||
|
Prompt(text="good v3"),
|
||||||
|
]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=3,
|
||||||
|
minibatch_size=2,
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
assert state.iteration == 3
|
||||||
|
assert state.history[1]["event"] == "error"
|
||||||
|
assert "LLM timeout" in state.history[1]["error"]
|
||||||
|
assert state.history[0]["event"] == "accepted"
|
||||||
|
assert state.history[2]["event"] == "accepted"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_perfect_score_skips_proposer(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
) -> None:
|
||||||
|
"""When all scores are perfect, no proposition is made."""
|
||||||
|
mock_llm = MagicMock(spec=LLMPort)
|
||||||
|
mock_judge = MagicMock(spec=JudgePort)
|
||||||
|
mock_proposer = MagicMock(spec=ProposerPort)
|
||||||
|
|
||||||
|
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:2]
|
||||||
|
|
||||||
|
perfect_eval = _make_eval([1.0, 1.0])
|
||||||
|
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=5,
|
||||||
|
minibatch_size=2,
|
||||||
|
perfect_score=1.0,
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
mock_proposer.propose.assert_not_called()
|
||||||
|
assert all(h["event"] == "skip_perfect" for h in state.history)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_call_counting(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
task_description: str,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
) -> None:
|
||||||
|
"""Verify LLM call counting: 2*N per eval (execute + judge) + 1 per propose."""
|
||||||
|
mock_llm = MagicMock(spec=LLMPort)
|
||||||
|
mock_judge = MagicMock(spec=JudgePort)
|
||||||
|
mock_proposer = MagicMock(spec=ProposerPort)
|
||||||
|
|
||||||
|
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:3]
|
||||||
|
|
||||||
|
evals = [_make_eval([0.3, 0.3, 0.3])] # initial
|
||||||
|
for _ in range(2):
|
||||||
|
evals.append(_make_eval([0.4, 0.4, 0.4]))
|
||||||
|
evals.append(_make_eval([0.6, 0.6, 0.6]))
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||||
|
|
||||||
|
mock_proposer.propose.side_effect = [
|
||||||
|
Prompt(text="v1"),
|
||||||
|
Prompt(text="v2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=2,
|
||||||
|
minibatch_size=3,
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
# Initial: 2*3=6, Iter1: 2*3 + 1 + 2*3 = 13, Iter2: same = 13
|
||||||
|
# Total: 6 + 13 + 13 = 32
|
||||||
|
assert state.total_llm_calls == 32
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
"""End-to-end pipeline test with mocked LLM calls."""
|
"""End-to-end pipeline test with mocked LLM calls."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||||
from prometheus.application.dto import OptimizationConfig
|
from prometheus.application.dto import OptimizationConfig
|
||||||
@@ -23,9 +25,10 @@ def _make_eval(scores: list[float]) -> EvalResult:
|
|||||||
|
|
||||||
|
|
||||||
class TestFullPipeline:
|
class TestFullPipeline:
|
||||||
def test_pipeline_produces_result(self) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_produces_result(self) -> None:
|
||||||
"""Full pipeline with mocked ports produces an OptimizationResult."""
|
"""Full pipeline with mocked ports produces an OptimizationResult."""
|
||||||
mock_llm = MagicMock(spec=LLMPort)
|
mock_llm = AsyncMock(spec=LLMPort)
|
||||||
mock_llm.execute.return_value = "mock response"
|
mock_llm.execute.return_value = "mock response"
|
||||||
|
|
||||||
mock_judge = MagicMock(spec=JudgePort)
|
mock_judge = MagicMock(spec=JudgePort)
|
||||||
@@ -38,11 +41,11 @@ class TestFullPipeline:
|
|||||||
eval_sequence.append(_make_eval([0.6, 0.6, 0.6, 0.6, 0.6])) # new eval (accepted)
|
eval_sequence.append(_make_eval([0.6, 0.6, 0.6, 0.6, 0.6])) # new eval (accepted)
|
||||||
mock_judge.judge_batch.return_value = [(0.5, "ok")] * 5
|
mock_judge.judge_batch.return_value = [(0.5, "ok")] * 5
|
||||||
|
|
||||||
mock_proposer = MagicMock(spec=ProposerPort)
|
mock_proposer = AsyncMock(spec=ProposerPort)
|
||||||
mock_proposer.propose.return_value = Prompt(text="Improved prompt")
|
mock_proposer.propose.return_value = Prompt(text="Improved prompt")
|
||||||
|
|
||||||
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||||
evaluator.evaluate = MagicMock(side_effect=eval_sequence)
|
evaluator.evaluate = AsyncMock(side_effect=eval_sequence)
|
||||||
|
|
||||||
mock_gen = MagicMock()
|
mock_gen = MagicMock()
|
||||||
mock_gen.generate_inputs.return_value = [
|
mock_gen.generate_inputs.return_value = [
|
||||||
@@ -65,7 +68,7 @@ class TestFullPipeline:
|
|||||||
seed=42,
|
seed=42,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = use_case.execute(config)
|
result = await use_case.execute(config)
|
||||||
|
|
||||||
assert result.initial_prompt == "Answer questions."
|
assert result.initial_prompt == "Answer questions."
|
||||||
assert result.optimized_prompt == "Improved prompt"
|
assert result.optimized_prompt == "Improved prompt"
|
||||||
|
|||||||
199
tests/integration/test_ground_truth_eval.py
Normal file
199
tests/integration/test_ground_truth_eval.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
"""Integration test — ground-truth evaluation end-to-end with real similarity metrics."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from prometheus.application.ground_truth_evaluator import GroundTruthEvaluator
|
||||||
|
from prometheus.domain.entities import GroundTruthExample, Prompt
|
||||||
|
from prometheus.domain.ports import LLMPort
|
||||||
|
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
|
||||||
|
from prometheus.infrastructure.similarity import (
|
||||||
|
BleuSimilarity,
|
||||||
|
CosineSimilarity,
|
||||||
|
ExactMatchSimilarity,
|
||||||
|
RougeLSimilarity,
|
||||||
|
create_similarity_adapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_dataset(items: list[tuple[str, str]]) -> list[GroundTruthExample]:
|
||||||
|
return [
|
||||||
|
GroundTruthExample(input_text=inp, expected_output=exp, id=i)
|
||||||
|
for i, (inp, exp) in enumerate(items)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def qa_dataset():
|
||||||
|
return _make_dataset([
|
||||||
|
("What is the capital of France?", "Paris"),
|
||||||
|
("What is 2+2?", "4"),
|
||||||
|
("What color is the sky?", "blue"),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prompt():
|
||||||
|
return Prompt(text="Answer the following question concisely.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_executor():
|
||||||
|
"""Returns responses that partially match the ground truth."""
|
||||||
|
port = AsyncMock(spec=LLMPort)
|
||||||
|
port.execute.side_effect = [
|
||||||
|
"Paris is the capital of France.",
|
||||||
|
"The answer is 4.",
|
||||||
|
"The sky is blue.",
|
||||||
|
]
|
||||||
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
class TestGroundTruthIntegrationWithExactMatch:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exact_match_on_qa(self, mock_executor, qa_dataset, prompt):
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor,
|
||||||
|
similarity=ExactMatchSimilarity(),
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||||
|
# None of the outputs are exact matches with expected outputs
|
||||||
|
assert all(s == 0.0 for s in result.scores)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exact_match_with_exact_outputs(self, qa_dataset, prompt):
|
||||||
|
exact_executor = AsyncMock(spec=LLMPort)
|
||||||
|
exact_executor.execute.side_effect = ["Paris", "4", "blue"]
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=exact_executor,
|
||||||
|
similarity=ExactMatchSimilarity(),
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||||
|
assert all(s == 1.0 for s in result.scores)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGroundTruthIntegrationWithBleu:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bleu_scores_partial_match(self, mock_executor, qa_dataset, prompt):
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor,
|
||||||
|
similarity=BleuSimilarity(),
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||||
|
assert all(0.0 < s < 1.0 for s in result.scores)
|
||||||
|
assert result.mean_score > 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bleu_perfect_match(self, qa_dataset, prompt):
|
||||||
|
perfect_executor = AsyncMock(spec=LLMPort)
|
||||||
|
perfect_executor.execute.side_effect = ["Paris", "4", "blue"]
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=perfect_executor,
|
||||||
|
similarity=BleuSimilarity(),
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||||
|
assert all(s > 0.0 for s in result.scores)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGroundTruthIntegrationWithRouge:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rouge_l_scores(self, mock_executor, qa_dataset, prompt):
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor,
|
||||||
|
similarity=RougeLSimilarity(),
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||||
|
assert all(s > 0.0 for s in result.scores)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGroundTruthIntegrationWithCosine:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cosine_scores(self, mock_executor, qa_dataset, prompt):
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor,
|
||||||
|
similarity=CosineSimilarity(),
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||||
|
assert all(s > 0.0 for s in result.scores)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatasetLoaderIntegration:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_csv_and_evaluate(self, tmp_path, prompt):
|
||||||
|
csv_file = tmp_path / "eval.csv"
|
||||||
|
csv_file.write_text("input,expected_output\nWhat is 2+2?,4\nWhat color is grass?,green\n")
|
||||||
|
|
||||||
|
loader = FileDatasetLoader()
|
||||||
|
dataset = loader.load(str(csv_file))
|
||||||
|
assert len(dataset) == 2
|
||||||
|
|
||||||
|
executor = AsyncMock(spec=LLMPort)
|
||||||
|
executor.execute.side_effect = ["4", "green"]
|
||||||
|
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=executor,
|
||||||
|
similarity=ExactMatchSimilarity(),
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, dataset)
|
||||||
|
assert all(s == 1.0 for s in result.scores)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_json_and_evaluate(self, tmp_path, prompt):
|
||||||
|
json_file = tmp_path / "eval.json"
|
||||||
|
data = [
|
||||||
|
{"input": "What is 2+2?", "expected_output": "4"},
|
||||||
|
{"input": "What color is grass?", "expected_output": "green"},
|
||||||
|
]
|
||||||
|
json_file.write_text(json.dumps(data))
|
||||||
|
|
||||||
|
loader = FileDatasetLoader()
|
||||||
|
dataset = loader.load(str(json_file))
|
||||||
|
assert len(dataset) == 2
|
||||||
|
|
||||||
|
executor = AsyncMock(spec=LLMPort)
|
||||||
|
executor.execute.side_effect = ["4", "not green"]
|
||||||
|
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=executor,
|
||||||
|
similarity=create_similarity_adapter("bleu"),
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, dataset)
|
||||||
|
# First item should score well, second poorly
|
||||||
|
assert result.scores[0] > result.scores[1]
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricComparison:
|
||||||
|
"""Compare different metrics on the same outputs to ensure they behave differently."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_metrics_give_different_scores(self, qa_dataset, prompt):
|
||||||
|
results = {}
|
||||||
|
for metric_name, metric_cls in [
|
||||||
|
("exact", ExactMatchSimilarity),
|
||||||
|
("bleu", BleuSimilarity),
|
||||||
|
("rouge_l", RougeLSimilarity),
|
||||||
|
("cosine", CosineSimilarity),
|
||||||
|
]:
|
||||||
|
executor = AsyncMock(spec=LLMPort)
|
||||||
|
executor.execute.side_effect = [
|
||||||
|
"Paris is the capital of France.",
|
||||||
|
"The answer is 4.",
|
||||||
|
"The sky is blue.",
|
||||||
|
]
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=executor,
|
||||||
|
similarity=metric_cls(),
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||||
|
results[metric_name] = result.mean_score
|
||||||
|
|
||||||
|
# Exact match should be 0 (no exact matches)
|
||||||
|
assert results["exact"] == 0.0
|
||||||
|
# All other metrics should give partial credit
|
||||||
|
assert results["bleu"] > 0.0
|
||||||
|
assert results["rouge_l"] > 0.0
|
||||||
|
assert results["cosine"] > 0.0
|
||||||
294
tests/unit/test_adapters.py
Normal file
294
tests/unit/test_adapters.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
"""Unit tests for infrastructure adapters — LLM, Judge, Proposer, Synthetic.
|
||||||
|
|
||||||
|
Uses mocked DSPy modules to isolate adapter logic from LLM calls.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory
|
||||||
|
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||||
|
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
||||||
|
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
||||||
|
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
||||||
|
|
||||||
|
|
||||||
|
# --- LLM Adapter ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestDSPyLLMAdapter:
|
||||||
|
"""Tests for DSPyLLMAdapter.execute()."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_lm(self) -> MagicMock:
|
||||||
|
return MagicMock(spec=dspy.LM)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def adapter(self, mock_lm: MagicMock) -> DSPyLLMAdapter:
|
||||||
|
return DSPyLLMAdapter(lm=mock_lm)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_returns_output_string(
|
||||||
|
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
mock_predictor = MagicMock()
|
||||||
|
mock_predictor.return_value = MagicMock(output="Hello response")
|
||||||
|
adapter._predictor = mock_predictor
|
||||||
|
|
||||||
|
prompt = Prompt(text="Say hello.")
|
||||||
|
result = await adapter.execute(prompt, "Hi there")
|
||||||
|
|
||||||
|
assert result == "Hello response"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_passes_prompt_text_and_input(
|
||||||
|
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
mock_predictor = MagicMock()
|
||||||
|
mock_predictor.return_value = MagicMock(output="response")
|
||||||
|
adapter._predictor = mock_predictor
|
||||||
|
|
||||||
|
prompt = Prompt(text="Translate this.")
|
||||||
|
await adapter.execute(prompt, "Hello world")
|
||||||
|
|
||||||
|
mock_predictor.assert_called_once_with(
|
||||||
|
instruction="Translate this.",
|
||||||
|
input_text="Hello world",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_uses_dspy_context(
|
||||||
|
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
mock_predictor = MagicMock()
|
||||||
|
mock_predictor.return_value = MagicMock(output="ok")
|
||||||
|
adapter._predictor = mock_predictor
|
||||||
|
|
||||||
|
with patch("prometheus.infrastructure.llm_adapter.dspy.context") as mock_ctx:
|
||||||
|
await adapter.execute(Prompt(text="test"), "input")
|
||||||
|
mock_ctx.assert_called_once_with(lm=mock_lm)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_converts_output_to_str(
|
||||||
|
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
mock_predictor = MagicMock()
|
||||||
|
mock_predictor.return_value = MagicMock(output=42)
|
||||||
|
adapter._predictor = mock_predictor
|
||||||
|
|
||||||
|
result = await adapter.execute(Prompt(text="test"), "input")
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert result == "42"
|
||||||
|
|
||||||
|
|
||||||
|
# --- Judge Adapter ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestDSPyJudgeAdapter:
|
||||||
|
"""Tests for DSPyJudgeAdapter.judge_batch()."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_lm(self) -> MagicMock:
|
||||||
|
return MagicMock(spec=dspy.LM)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def adapter(self, mock_lm: MagicMock) -> DSPyJudgeAdapter:
|
||||||
|
return DSPyJudgeAdapter(lm=mock_lm)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_judge_batch_returns_scores_and_feedback(
|
||||||
|
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
adapter._judge = MagicMock()
|
||||||
|
adapter._judge.side_effect = [
|
||||||
|
MagicMock(score=0.9, feedback="Excellent."),
|
||||||
|
MagicMock(score=0.4, feedback="Incomplete."),
|
||||||
|
]
|
||||||
|
|
||||||
|
pairs = [("What is 2+2?", "4"), ("Capital of France?", "London")]
|
||||||
|
result = await adapter.judge_batch("math and geography", pairs)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0] == (0.9, "Excellent.")
|
||||||
|
assert result[1] == (0.4, "Incomplete.")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_judge_batch_empty_pairs(
|
||||||
|
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
result = await adapter.judge_batch("task", [])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_judge_batch_uses_dspy_context(
|
||||||
|
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
adapter._judge = MagicMock()
|
||||||
|
adapter._judge.return_value = MagicMock(score=0.5, feedback="ok")
|
||||||
|
|
||||||
|
with patch("prometheus.infrastructure.judge_adapter.dspy.context") as mock_ctx:
|
||||||
|
await adapter.judge_batch("task", [("in", "out")])
|
||||||
|
mock_ctx.assert_called_once_with(lm=mock_lm)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_judge_batch_returns_all_results(
|
||||||
|
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
"""Judge calls run in parallel but all results are returned."""
|
||||||
|
adapter._judge = MagicMock()
|
||||||
|
adapter._judge.side_effect = [
|
||||||
|
MagicMock(score=0.5, feedback="ok"),
|
||||||
|
MagicMock(score=0.7, feedback="better"),
|
||||||
|
MagicMock(score=0.3, feedback="worse"),
|
||||||
|
]
|
||||||
|
|
||||||
|
pairs = [("first", "out1"), ("second", "out2"), ("third", "out3")]
|
||||||
|
results = await adapter.judge_batch("task", pairs)
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
scores = [r[0] for r in results]
|
||||||
|
assert 0.5 in scores
|
||||||
|
assert 0.7 in scores
|
||||||
|
assert 0.3 in scores
|
||||||
|
|
||||||
|
|
||||||
|
# --- Proposer Adapter ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestDSPyProposerAdapter:
|
||||||
|
"""Tests for DSPyProposerAdapter.propose()."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_lm(self) -> MagicMock:
|
||||||
|
return MagicMock(spec=dspy.LM)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def adapter(self, mock_lm: MagicMock) -> DSPyProposerAdapter:
|
||||||
|
return DSPyProposerAdapter(lm=mock_lm)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_propose_returns_new_prompt(
|
||||||
|
self, adapter: DSPyProposerAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
adapter._proposer = MagicMock()
|
||||||
|
adapter._proposer.return_value = MagicMock(
|
||||||
|
new_instruction="Be concise and accurate."
|
||||||
|
)
|
||||||
|
|
||||||
|
current = Prompt(text="Answer questions.")
|
||||||
|
trajectories = [
|
||||||
|
Trajectory("in", "out", 0.3, "too verbose", "Answer questions.")
|
||||||
|
]
|
||||||
|
result = await adapter.propose(current, trajectories, "Q&A task")
|
||||||
|
|
||||||
|
assert isinstance(result, Prompt)
|
||||||
|
assert result.text == "Be concise and accurate."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_propose_uses_dspy_context(
|
||||||
|
self, adapter: DSPyProposerAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
adapter._proposer = MagicMock()
|
||||||
|
adapter._proposer.return_value = MagicMock(new_instruction="improved")
|
||||||
|
|
||||||
|
with patch("prometheus.infrastructure.proposer_adapter.dspy.context") as mock_ctx:
|
||||||
|
await adapter.propose(Prompt(text="t"), [], "task")
|
||||||
|
mock_ctx.assert_called_once_with(lm=mock_lm)
|
||||||
|
|
||||||
|
def test_format_failures_single_trajectory(self) -> None:
|
||||||
|
trajectories = [
|
||||||
|
Trajectory("What is AI?", "A type of robot.", 0.3, "Incomplete definition.", "prompt")
|
||||||
|
]
|
||||||
|
result = DSPyProposerAdapter._format_failures(trajectories)
|
||||||
|
|
||||||
|
assert "What is AI?" in result
|
||||||
|
assert "A type of robot." in result
|
||||||
|
assert "0.30" in result
|
||||||
|
assert "Incomplete definition." in result
|
||||||
|
assert "# Example 1" in result
|
||||||
|
|
||||||
|
def test_format_failures_multiple_trajectories(self) -> None:
|
||||||
|
trajectories = [
|
||||||
|
Trajectory("input1", "output1", 0.4, "bad", "prompt"),
|
||||||
|
Trajectory("input2", "output2", 0.2, "worse", "prompt"),
|
||||||
|
]
|
||||||
|
result = DSPyProposerAdapter._format_failures(trajectories)
|
||||||
|
|
||||||
|
assert "# Example 1" in result
|
||||||
|
assert "# Example 2" in result
|
||||||
|
assert "---" in result
|
||||||
|
assert "input1" in result
|
||||||
|
assert "input2" in result
|
||||||
|
|
||||||
|
def test_format_failures_empty_list(self) -> None:
|
||||||
|
result = DSPyProposerAdapter._format_failures([])
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
|
||||||
|
# --- Synthetic Adapter ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestDSPySyntheticAdapter:
|
||||||
|
"""Tests for DSPySyntheticAdapter.generate_inputs()."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_lm(self) -> MagicMock:
|
||||||
|
return MagicMock(spec=dspy.LM)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def adapter(self, mock_lm: MagicMock) -> DSPySyntheticAdapter:
|
||||||
|
return DSPySyntheticAdapter(lm=mock_lm)
|
||||||
|
|
||||||
|
def test_generate_inputs_returns_examples(
|
||||||
|
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
adapter._generator = MagicMock()
|
||||||
|
adapter._generator.return_value = MagicMock(
|
||||||
|
examples=["What is AI?", "Explain ML.", "What is NLP?"]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = adapter.generate_inputs("AI task", 3)
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
assert all(isinstance(ex, SyntheticExample) for ex in result)
|
||||||
|
assert result[0].input_text == "What is AI?"
|
||||||
|
assert result[0].id == 0
|
||||||
|
assert result[1].id == 1
|
||||||
|
|
||||||
|
def test_generate_inputs_truncates_to_n(
|
||||||
|
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
adapter._generator = MagicMock()
|
||||||
|
adapter._generator.return_value = MagicMock(
|
||||||
|
examples=["q1", "q2", "q3", "q4", "q5"]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = adapter.generate_inputs("task", 3)
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
|
||||||
|
def test_generate_inputs_passes_correct_args(
|
||||||
|
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
adapter._generator = MagicMock()
|
||||||
|
adapter._generator.return_value = MagicMock(examples=["q1"])
|
||||||
|
|
||||||
|
adapter.generate_inputs("my task", 5)
|
||||||
|
|
||||||
|
adapter._generator.assert_called_once_with(
|
||||||
|
task_description="my task",
|
||||||
|
n_examples=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_generate_inputs_empty_list(
|
||||||
|
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||||
|
) -> None:
|
||||||
|
adapter._generator = MagicMock()
|
||||||
|
adapter._generator.return_value = MagicMock(examples=[])
|
||||||
|
|
||||||
|
result = adapter.generate_inputs("task", 0)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
333
tests/unit/test_checkpoint.py
Normal file
333
tests/unit/test_checkpoint.py
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
"""Unit tests for checkpoint & resume functionality."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||||
|
from prometheus.application.evaluator import PromptEvaluator
|
||||||
|
from prometheus.application.evolution import EvolutionLoop
|
||||||
|
from prometheus.domain.entities import (
|
||||||
|
Candidate,
|
||||||
|
EvalResult,
|
||||||
|
OptimizationState,
|
||||||
|
Prompt,
|
||||||
|
SyntheticExample,
|
||||||
|
Trajectory,
|
||||||
|
)
|
||||||
|
from prometheus.infrastructure.checkpoint import JsonCheckpointPersistence
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JsonCheckpointPersistence — save/load round-trip
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonCheckpointPersistence:
|
||||||
|
def test_roundtrip_full_state(self, tmp_path: Path) -> None:
|
||||||
|
"""Saving and loading preserves all fields."""
|
||||||
|
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
|
||||||
|
|
||||||
|
state = OptimizationState(
|
||||||
|
iteration=7,
|
||||||
|
best_candidate=Candidate(
|
||||||
|
prompt=Prompt(text="best prompt", metadata={"source": "test"}),
|
||||||
|
best_score=0.92,
|
||||||
|
generation=5,
|
||||||
|
),
|
||||||
|
candidates=[
|
||||||
|
Candidate(prompt=Prompt(text="p1"), best_score=0.5, generation=0),
|
||||||
|
Candidate(prompt=Prompt(text="p2"), best_score=0.92, generation=5),
|
||||||
|
],
|
||||||
|
synthetic_pool=[
|
||||||
|
SyntheticExample(input_text="q1", category="cat_a", id=0),
|
||||||
|
SyntheticExample(input_text="q2", category="cat_b", id=1),
|
||||||
|
],
|
||||||
|
history=[{"iteration": 1, "event": "accepted", "old_score": 0.5, "new_score": 0.7}],
|
||||||
|
total_llm_calls=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
ckpt.save(state)
|
||||||
|
assert ckpt.latest_exists()
|
||||||
|
|
||||||
|
loaded = ckpt.load()
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.iteration == 7
|
||||||
|
assert loaded.total_llm_calls == 42
|
||||||
|
assert loaded.best_candidate is not None
|
||||||
|
assert loaded.best_candidate.prompt.text == "best prompt"
|
||||||
|
assert loaded.best_candidate.prompt.metadata == {"source": "test"}
|
||||||
|
assert loaded.best_candidate.best_score == 0.92
|
||||||
|
assert len(loaded.candidates) == 2
|
||||||
|
assert len(loaded.synthetic_pool) == 2
|
||||||
|
assert loaded.synthetic_pool[0].input_text == "q1"
|
||||||
|
assert loaded.synthetic_pool[1].category == "cat_b"
|
||||||
|
assert loaded.history[0]["event"] == "accepted"
|
||||||
|
|
||||||
|
def test_load_returns_none_when_no_checkpoint(self, tmp_path: Path) -> None:
|
||||||
|
"""Loading from empty dir returns None."""
|
||||||
|
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "nope")
|
||||||
|
assert ckpt.load() is None
|
||||||
|
assert not ckpt.latest_exists()
|
||||||
|
|
||||||
|
def test_creates_directory_on_save(self, tmp_path: Path) -> None:
|
||||||
|
"""Save creates the directory tree if it doesn't exist."""
|
||||||
|
deep_dir = tmp_path / "a" / "b" / "c"
|
||||||
|
ckpt = JsonCheckpointPersistence(checkpoint_dir=deep_dir)
|
||||||
|
state = OptimizationState(iteration=1)
|
||||||
|
ckpt.save(state)
|
||||||
|
assert (deep_dir / "latest.json").exists()
|
||||||
|
|
||||||
|
def test_overwrites_previous_checkpoint(self, tmp_path: Path) -> None:
|
||||||
|
"""Second save overwrites the first."""
|
||||||
|
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
|
||||||
|
|
||||||
|
ckpt.save(OptimizationState(iteration=1, total_llm_calls=10))
|
||||||
|
ckpt.save(OptimizationState(iteration=5, total_llm_calls=50))
|
||||||
|
|
||||||
|
loaded = ckpt.load()
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.iteration == 5
|
||||||
|
assert loaded.total_llm_calls == 50
|
||||||
|
|
||||||
|
def test_json_is_human_readable(self, tmp_path: Path) -> None:
|
||||||
|
"""Checkpoint file is valid, pretty-printed JSON."""
|
||||||
|
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
|
||||||
|
state = OptimizationState(
|
||||||
|
iteration=3,
|
||||||
|
best_candidate=Candidate(prompt=Prompt(text="hello"), best_score=0.8),
|
||||||
|
)
|
||||||
|
ckpt.save(state)
|
||||||
|
|
||||||
|
raw = json.loads((tmp_path / "latest.json").read_text())
|
||||||
|
assert raw["schema_version"] == 1
|
||||||
|
assert raw["iteration"] == 3
|
||||||
|
assert raw["best_candidate"]["prompt_text"] == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EvolutionLoop — checkpoint integration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEvolutionCheckpoint:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_saved_on_interval(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
) -> None:
|
||||||
|
"""Checkpoint is saved every checkpoint_interval iterations."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
# All iterations accepted so checkpoint triggers
|
||||||
|
good_eval = EvalResult(
|
||||||
|
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||||
|
feedbacks=["ok"] * 5,
|
||||||
|
trajectories=[
|
||||||
|
Trajectory(f"input{i}", f"out{i}", s, "ok", "p")
|
||||||
|
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||||
|
],
|
||||||
|
)
|
||||||
|
better_eval = EvalResult(
|
||||||
|
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
||||||
|
feedbacks=["good"] * 5,
|
||||||
|
trajectories=[],
|
||||||
|
)
|
||||||
|
# initial_eval + 5 iterations (each needs old_eval + new_eval)
|
||||||
|
evaluator.evaluate = AsyncMock(
|
||||||
|
side_effect=[good_eval] # initial
|
||||||
|
+ [good_eval, better_eval] * 5 # 5 iterations
|
||||||
|
)
|
||||||
|
|
||||||
|
proposer = AsyncMock()
|
||||||
|
proposer.propose.return_value = Prompt(text="improved prompt")
|
||||||
|
|
||||||
|
checkpoint_port = MagicMock()
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=5,
|
||||||
|
minibatch_size=5,
|
||||||
|
checkpoint_port=checkpoint_port,
|
||||||
|
checkpoint_interval=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
# Checkpoint at iterations 2, 4 (every 2nd)
|
||||||
|
save_calls = checkpoint_port.save.call_count
|
||||||
|
assert save_calls >= 2 # at least at iters 2 and 4
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_checkpoint_without_port(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
) -> None:
|
||||||
|
"""No checkpointing happens when checkpoint_port is None (default)."""
|
||||||
|
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
perfect_eval = EvalResult(
|
||||||
|
scores=[1.0] * 5,
|
||||||
|
feedbacks=["perfect"] * 5,
|
||||||
|
trajectories=[
|
||||||
|
Trajectory(f"in{i}", f"out{i}", 1.0, "perfect", "p")
|
||||||
|
for i in range(5)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=AsyncMock(),
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=3,
|
||||||
|
minibatch_size=5,
|
||||||
|
checkpoint_port=None,
|
||||||
|
)
|
||||||
|
# Should run without error — no checkpoint port, no crash
|
||||||
|
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_skips_seed_evaluation(
|
||||||
|
self,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
) -> None:
|
||||||
|
"""When initial_state is provided, seed eval is skipped and loop starts from saved iteration."""
|
||||||
|
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
proposer = AsyncMock()
|
||||||
|
proposer.propose.return_value = Prompt(text="new prompt")
|
||||||
|
|
||||||
|
# Only return evaluations for resumed iterations (1 iter: old_eval + new_eval)
|
||||||
|
old_eval = EvalResult(
|
||||||
|
scores=[0.5] * 5,
|
||||||
|
feedbacks=["ok"] * 5,
|
||||||
|
trajectories=[
|
||||||
|
Trajectory(f"in{i}", f"out{i}", 0.5, "ok", "p") for i in range(5)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
new_eval = EvalResult(
|
||||||
|
scores=[0.8] * 5,
|
||||||
|
feedbacks=["good"] * 5,
|
||||||
|
trajectories=[],
|
||||||
|
)
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
|
||||||
|
|
||||||
|
# Create a state simulating checkpoint at iteration 4
|
||||||
|
initial_state = OptimizationState(
|
||||||
|
iteration=4,
|
||||||
|
best_candidate=Candidate(
|
||||||
|
prompt=Prompt(text="checkpoint prompt"), best_score=2.5, generation=4
|
||||||
|
),
|
||||||
|
candidates=[Candidate(prompt=Prompt(text="checkpoint prompt"), best_score=2.5)],
|
||||||
|
total_llm_calls=40,
|
||||||
|
)
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=5, # only iteration 5 remains
|
||||||
|
minibatch_size=5,
|
||||||
|
)
|
||||||
|
state = await loop.run(
|
||||||
|
seed_prompt=Prompt(text="seed"),
|
||||||
|
synthetic_pool=synthetic_pool,
|
||||||
|
task_description=task_description,
|
||||||
|
initial_state=initial_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have run only 1 iteration (iter 5)
|
||||||
|
assert state.iteration == 5
|
||||||
|
# total_llm_calls should include the 40 from checkpoint + new calls
|
||||||
|
assert state.total_llm_calls > 40
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_save_and_resume_roundtrip(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""End-to-end: run a few iterations, checkpoint, resume, finish."""
|
||||||
|
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
old_eval = EvalResult(
|
||||||
|
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||||
|
feedbacks=["ok"] * 5,
|
||||||
|
trajectories=[
|
||||||
|
Trajectory(f"in{i}", f"out{i}", s, "ok", "p")
|
||||||
|
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||||
|
],
|
||||||
|
)
|
||||||
|
new_eval = EvalResult(
|
||||||
|
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
||||||
|
feedbacks=["good"] * 5,
|
||||||
|
trajectories=[],
|
||||||
|
)
|
||||||
|
evaluator.evaluate = AsyncMock(
|
||||||
|
side_effect=[old_eval, old_eval, new_eval, old_eval, new_eval]
|
||||||
|
)
|
||||||
|
proposer = AsyncMock()
|
||||||
|
proposer.propose.return_value = Prompt(text="improved prompt")
|
||||||
|
|
||||||
|
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=2,
|
||||||
|
minibatch_size=5,
|
||||||
|
checkpoint_port=ckpt,
|
||||||
|
checkpoint_interval=1,
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
assert state.iteration == 2
|
||||||
|
assert ckpt.latest_exists()
|
||||||
|
|
||||||
|
# Capture the checkpoint state *before* resume (state is mutated in-place)
|
||||||
|
loaded = ckpt.load()
|
||||||
|
assert loaded is not None
|
||||||
|
saved_llm_calls = loaded.total_llm_calls
|
||||||
|
saved_iteration = loaded.iteration
|
||||||
|
|
||||||
|
# Set up evaluator for resumed run (just 1 more iteration)
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
|
||||||
|
proposer.propose.return_value = Prompt(text="even better prompt")
|
||||||
|
|
||||||
|
loop2 = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=3,
|
||||||
|
minibatch_size=5,
|
||||||
|
checkpoint_port=ckpt,
|
||||||
|
checkpoint_interval=1,
|
||||||
|
)
|
||||||
|
resumed = await loop2.run(
|
||||||
|
seed_prompt, synthetic_pool, task_description,
|
||||||
|
initial_state=loaded,
|
||||||
|
)
|
||||||
|
assert resumed.iteration == 3
|
||||||
|
assert resumed.total_llm_calls > saved_llm_calls
|
||||||
|
assert resumed.iteration > saved_iteration
|
||||||
278
tests/unit/test_cli.py
Normal file
278
tests/unit/test_cli.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
"""Tests for the CLI interface — prometheus optimize, version, etc.
|
||||||
|
|
||||||
|
Uses Typer's CliRunner for isolated command testing.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from prometheus.application.dto import OptimizationResult
|
||||||
|
from prometheus.cli.app import app
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCLIOptimize:
|
||||||
|
"""Tests for the `prometheus optimize` command."""
|
||||||
|
|
||||||
|
def _write_config(self, tmp_path: Path, **overrides: object) -> Path:
|
||||||
|
"""Write a minimal valid config YAML and return its path."""
|
||||||
|
data = {
|
||||||
|
"seed_prompt": "You are a helpful assistant.",
|
||||||
|
"task_description": "Answer factual questions accurately.",
|
||||||
|
}
|
||||||
|
data.update(overrides)
|
||||||
|
config_file = tmp_path / "config.yaml"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(data, f)
|
||||||
|
return config_file
|
||||||
|
|
||||||
|
def test_optimize_with_valid_config(self, tmp_path: Path) -> None:
|
||||||
|
config_file = self._write_config(tmp_path)
|
||||||
|
output_file = tmp_path / "output.yaml"
|
||||||
|
|
||||||
|
mock_result = OptimizationResult(
|
||||||
|
optimized_prompt="Improved prompt",
|
||||||
|
initial_prompt="You are a helpful assistant.",
|
||||||
|
iterations_used=5,
|
||||||
|
total_llm_calls=50,
|
||||||
|
initial_score=0.3,
|
||||||
|
final_score=0.9,
|
||||||
|
improvement=0.6,
|
||||||
|
history=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_uc = AsyncMock()
|
||||||
|
mock_uc.execute.return_value = mock_result
|
||||||
|
|
||||||
|
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||||
|
mock_llm_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||||
|
mock_judge_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||||
|
mock_prop_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||||
|
result = runner.invoke(
|
||||||
|
app,
|
||||||
|
[
|
||||||
|
"optimize",
|
||||||
|
"-i",
|
||||||
|
str(config_file),
|
||||||
|
"-o",
|
||||||
|
str(output_file),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Optimized Prompt" in result.output
|
||||||
|
|
||||||
|
def test_optimize_missing_input_file(self) -> None:
|
||||||
|
result = runner.invoke(
|
||||||
|
app,
|
||||||
|
["optimize", "-i", "/nonexistent/config.yaml"],
|
||||||
|
)
|
||||||
|
assert result.exit_code != 0
|
||||||
|
|
||||||
|
def test_optimize_with_verbose_flag(self, tmp_path: Path) -> None:
|
||||||
|
config_file = self._write_config(tmp_path)
|
||||||
|
output_file = tmp_path / "output.yaml"
|
||||||
|
|
||||||
|
mock_result = OptimizationResult(
|
||||||
|
optimized_prompt="Improved",
|
||||||
|
initial_prompt="test",
|
||||||
|
iterations_used=1,
|
||||||
|
total_llm_calls=10,
|
||||||
|
initial_score=0.3,
|
||||||
|
final_score=0.8,
|
||||||
|
improvement=0.5,
|
||||||
|
history=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_uc = AsyncMock()
|
||||||
|
mock_uc.execute.return_value = mock_result
|
||||||
|
|
||||||
|
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||||
|
mock_llm_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||||
|
mock_judge_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||||
|
mock_prop_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||||
|
result = runner.invoke(
|
||||||
|
app,
|
||||||
|
[
|
||||||
|
"optimize",
|
||||||
|
"-i",
|
||||||
|
str(config_file),
|
||||||
|
"-o",
|
||||||
|
str(output_file),
|
||||||
|
"-v",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
def test_optimize_displays_metrics(self, tmp_path: Path) -> None:
|
||||||
|
config_file = self._write_config(tmp_path)
|
||||||
|
output_file = tmp_path / "output.yaml"
|
||||||
|
|
||||||
|
mock_result = OptimizationResult(
|
||||||
|
optimized_prompt="Better prompt",
|
||||||
|
initial_prompt="test",
|
||||||
|
iterations_used=3,
|
||||||
|
total_llm_calls=30,
|
||||||
|
initial_score=0.40,
|
||||||
|
final_score=0.85,
|
||||||
|
improvement=0.45,
|
||||||
|
history=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_uc = AsyncMock()
|
||||||
|
mock_uc.execute.return_value = mock_result
|
||||||
|
|
||||||
|
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||||
|
mock_llm_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||||
|
mock_judge_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||||
|
mock_prop_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||||
|
result = runner.invoke(
|
||||||
|
app,
|
||||||
|
[
|
||||||
|
"optimize",
|
||||||
|
"-i",
|
||||||
|
str(config_file),
|
||||||
|
"-o",
|
||||||
|
str(output_file),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "0.40" in result.output
|
||||||
|
assert "0.85" in result.output
|
||||||
|
assert "+0.45" in result.output
|
||||||
|
|
||||||
|
def test_optimize_with_max_concurrency_flag(self, tmp_path: Path) -> None:
|
||||||
|
config_file = self._write_config(tmp_path)
|
||||||
|
output_file = tmp_path / "output.yaml"
|
||||||
|
|
||||||
|
mock_result = OptimizationResult(
|
||||||
|
optimized_prompt="Better prompt",
|
||||||
|
initial_prompt="test",
|
||||||
|
iterations_used=1,
|
||||||
|
total_llm_calls=10,
|
||||||
|
initial_score=0.3,
|
||||||
|
final_score=0.8,
|
||||||
|
improvement=0.5,
|
||||||
|
history=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_uc = AsyncMock()
|
||||||
|
mock_uc.execute.return_value = mock_result
|
||||||
|
|
||||||
|
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||||
|
mock_llm_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||||
|
mock_judge_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||||
|
mock_prop_cls.return_value = MagicMock()
|
||||||
|
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||||
|
result = runner.invoke(
|
||||||
|
app,
|
||||||
|
[
|
||||||
|
"optimize",
|
||||||
|
"-i",
|
||||||
|
str(config_file),
|
||||||
|
"-o",
|
||||||
|
str(output_file),
|
||||||
|
"--max-concurrency",
|
||||||
|
"10",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCLIHelp:
|
||||||
|
"""Tests for CLI help and no-args behavior."""
|
||||||
|
|
||||||
|
def test_no_args_shows_help(self) -> None:
|
||||||
|
result = runner.invoke(app, [])
|
||||||
|
# Typer uses exit code 2 when no_args_is_help=True
|
||||||
|
assert result.exit_code in (0, 2)
|
||||||
|
assert "PROMETHEUS" in result.output or "Usage" in result.output
|
||||||
|
|
||||||
|
def test_optimize_help(self) -> None:
|
||||||
|
result = runner.invoke(app, ["optimize", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "input" in result.output.lower() or "INPUT" in result.output
|
||||||
|
|
||||||
|
def test_version_help(self) -> None:
|
||||||
|
result = runner.invoke(app, ["version", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
def test_init_help(self) -> None:
|
||||||
|
result = runner.invoke(app, ["init", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
def test_list_help(self) -> None:
|
||||||
|
result = runner.invoke(app, ["list", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCLIVersion:
|
||||||
|
"""Tests for the `prometheus version` command."""
|
||||||
|
|
||||||
|
def test_version_prints_version(self) -> None:
|
||||||
|
result = runner.invoke(app, ["version"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "PROMETHEUS" in result.output
|
||||||
|
assert "0.1.0" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestCLIList:
|
||||||
|
"""Tests for the `prometheus list` command."""
|
||||||
|
|
||||||
|
def test_list_no_runs(self, tmp_path: Path) -> None:
|
||||||
|
result = runner.invoke(app, ["list", "-d", str(tmp_path)])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "No optimization runs found" in result.output
|
||||||
|
|
||||||
|
def test_list_with_result(self, tmp_path: Path) -> None:
|
||||||
|
result_data = {
|
||||||
|
"optimized_prompt": "Better prompt for testing",
|
||||||
|
"initial_prompt": "test",
|
||||||
|
"iterations_used": 5,
|
||||||
|
"total_llm_calls": 50,
|
||||||
|
"initial_score": 0.30,
|
||||||
|
"final_score": 0.90,
|
||||||
|
"improvement": 0.60,
|
||||||
|
"history": [],
|
||||||
|
}
|
||||||
|
result_file = tmp_path / "output.yaml"
|
||||||
|
import yaml as _yaml
|
||||||
|
with open(result_file, "w") as f:
|
||||||
|
_yaml.dump(result_data, f)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["list", "-d", str(tmp_path)])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "0.30" in result.output
|
||||||
|
assert "0.90" in result.output
|
||||||
|
|
||||||
|
def test_list_nonexistent_directory(self) -> None:
|
||||||
|
result = runner.invoke(app, ["list", "-d", "/nonexistent/dir"])
|
||||||
|
assert result.exit_code == 1
|
||||||
@@ -300,3 +300,33 @@ class TestConfigValidation:
|
|||||||
)
|
)
|
||||||
assert config.max_iterations == 1
|
assert config.max_iterations == 1
|
||||||
assert config.perfect_score == 0.0
|
assert config.perfect_score == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestEvalConfigValidation:
|
||||||
|
"""Tests for ground-truth evaluation config fields."""
|
||||||
|
|
||||||
|
def test_eval_defaults(self) -> None:
|
||||||
|
config = OptimizationConfig(seed_prompt="a", task_description="b")
|
||||||
|
assert config.eval_dataset_path is None
|
||||||
|
assert config.eval_metric == "bleu"
|
||||||
|
|
||||||
|
def test_eval_dataset_path_set(self) -> None:
|
||||||
|
config = OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b",
|
||||||
|
eval_dataset_path="data.csv",
|
||||||
|
)
|
||||||
|
assert config.eval_dataset_path == "data.csv"
|
||||||
|
|
||||||
|
def test_valid_eval_metrics(self) -> None:
|
||||||
|
for metric in ("exact", "bleu", "rouge_l", "cosine", "llm_judge"):
|
||||||
|
config = OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", eval_metric=metric,
|
||||||
|
)
|
||||||
|
assert config.eval_metric == metric
|
||||||
|
|
||||||
|
def test_invalid_eval_metric_raises(self) -> None:
|
||||||
|
with pytest.raises(ValidationError, match="eval_metric must be one of"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b",
|
||||||
|
eval_metric="invalid_metric",
|
||||||
|
)
|
||||||
|
|||||||
86
tests/unit/test_dataset_loader.py
Normal file
86
tests/unit/test_dataset_loader.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""Tests for the ground-truth dataset loader."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.domain.entities import GroundTruthExample
|
||||||
|
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def loader():
|
||||||
|
return FileDatasetLoader()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCsvLoader:
|
||||||
|
def test_load_csv(self, loader, tmp_path):
|
||||||
|
csv_file = tmp_path / "test.csv"
|
||||||
|
csv_file.write_text("input,expected_output\nhello,world\nfoo,bar\n")
|
||||||
|
result = loader.load(str(csv_file))
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0].input_text == "hello"
|
||||||
|
assert result[0].expected_output == "world"
|
||||||
|
assert result[1].input_text == "foo"
|
||||||
|
assert result[1].expected_output == "bar"
|
||||||
|
|
||||||
|
def test_load_csv_skips_empty_input(self, loader, tmp_path):
|
||||||
|
csv_file = tmp_path / "test.csv"
|
||||||
|
csv_file.write_text("input,expected_output\n,bar\nhello,world\n")
|
||||||
|
result = loader.load(str(csv_file))
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].input_text == "hello"
|
||||||
|
|
||||||
|
def test_load_csv_with_whitespace(self, loader, tmp_path):
|
||||||
|
csv_file = tmp_path / "test.csv"
|
||||||
|
csv_file.write_text("input,expected_output\n hello , world \n")
|
||||||
|
result = loader.load(str(csv_file))
|
||||||
|
assert result[0].input_text == "hello"
|
||||||
|
assert result[0].expected_output == "world"
|
||||||
|
|
||||||
|
def test_load_csv_empty_file(self, loader, tmp_path):
|
||||||
|
csv_file = tmp_path / "test.csv"
|
||||||
|
csv_file.write_text("input,expected_output\n")
|
||||||
|
result = loader.load(str(csv_file))
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonLoader:
|
||||||
|
def test_load_json(self, loader, tmp_path):
|
||||||
|
json_file = tmp_path / "test.json"
|
||||||
|
data = [
|
||||||
|
{"input": "hello", "expected_output": "world"},
|
||||||
|
{"input": "foo", "expected_output": "bar"},
|
||||||
|
]
|
||||||
|
json_file.write_text(json.dumps(data))
|
||||||
|
result = loader.load(str(json_file))
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0].input_text == "hello"
|
||||||
|
assert result[0].expected_output == "world"
|
||||||
|
|
||||||
|
def test_load_json_skips_empty_input(self, loader, tmp_path):
|
||||||
|
json_file = tmp_path / "test.json"
|
||||||
|
data = [
|
||||||
|
{"input": "", "expected_output": "bar"},
|
||||||
|
{"input": "hello", "expected_output": "world"},
|
||||||
|
]
|
||||||
|
json_file.write_text(json.dumps(data))
|
||||||
|
result = loader.load(str(json_file))
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_load_json_not_array_raises(self, loader, tmp_path):
|
||||||
|
json_file = tmp_path / "test.json"
|
||||||
|
json_file.write_text(json.dumps({"not": "an array"}))
|
||||||
|
with pytest.raises(ValueError, match="must be an array"):
|
||||||
|
loader.load(str(json_file))
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnsupportedFormat:
|
||||||
|
def test_unsupported_extension_raises(self, loader, tmp_path):
|
||||||
|
txt_file = tmp_path / "test.txt"
|
||||||
|
txt_file.write_text("hello")
|
||||||
|
with pytest.raises(ValueError, match="Unsupported dataset format"):
|
||||||
|
loader.load(str(txt_file))
|
||||||
@@ -278,6 +278,7 @@ class TestPerCallIsolation:
|
|||||||
adapter._judge_dimensions = []
|
adapter._judge_dimensions = []
|
||||||
adapter._dimension_names = ""
|
adapter._dimension_names = ""
|
||||||
adapter._weights = {}
|
adapter._weights = {}
|
||||||
|
adapter.call_count = 0
|
||||||
|
|
||||||
# Mock _judge to fail on first call, succeed on second
|
# Mock _judge to fail on first call, succeed on second
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|||||||
@@ -8,10 +8,30 @@ import pytest
|
|||||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||||
from prometheus.application.evaluator import PromptEvaluator
|
from prometheus.application.evaluator import PromptEvaluator
|
||||||
from prometheus.application.evolution import EvolutionLoop
|
from prometheus.application.evolution import EvolutionLoop
|
||||||
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory
|
from prometheus.domain.entities import (
|
||||||
|
Candidate,
|
||||||
|
EvalResult,
|
||||||
|
Prompt,
|
||||||
|
SyntheticExample,
|
||||||
|
Trajectory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_eval(scores: list[float], label: str = "ok") -> EvalResult:
|
||||||
|
"""Helper to build an EvalResult from a list of scores."""
|
||||||
|
return EvalResult(
|
||||||
|
scores=scores,
|
||||||
|
feedbacks=[label] * len(scores),
|
||||||
|
trajectories=[
|
||||||
|
Trajectory(f"input{i}", f"output{i}", s, label, "prompt")
|
||||||
|
for i, s in enumerate(scores)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestEvolutionLoop:
|
class TestEvolutionLoop:
|
||||||
|
"""Tests for the original single-candidate hill-climbing mode (population_size=1)."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_accepts_improvement(
|
async def test_accepts_improvement(
|
||||||
self,
|
self,
|
||||||
@@ -27,28 +47,9 @@ class TestEvolutionLoop:
|
|||||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
initial_eval = EvalResult(
|
low_eval = _make_eval([0.3, 0.4, 0.3, 0.5, 0.2], "bad")
|
||||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
high_eval = _make_eval([0.8, 0.9, 0.7, 0.8, 0.9], "good")
|
||||||
feedbacks=["bad"] * 5,
|
evaluator.evaluate = AsyncMock(side_effect=[low_eval, low_eval, high_eval])
|
||||||
trajectories=[
|
|
||||||
Trajectory(f"input{i}", f"output{i}", s, "bad", "prompt")
|
|
||||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
|
||||||
],
|
|
||||||
)
|
|
||||||
old_eval = EvalResult(
|
|
||||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
|
||||||
feedbacks=["bad"] * 5,
|
|
||||||
trajectories=[
|
|
||||||
Trajectory(f"input{i}", f"output{i}", s, "bad", "prompt")
|
|
||||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
|
||||||
],
|
|
||||||
)
|
|
||||||
new_eval = EvalResult(
|
|
||||||
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
|
||||||
feedbacks=["good"] * 5,
|
|
||||||
trajectories=[],
|
|
||||||
)
|
|
||||||
evaluator.evaluate = AsyncMock(side_effect=[initial_eval, old_eval, new_eval])
|
|
||||||
|
|
||||||
loop = EvolutionLoop(
|
loop = EvolutionLoop(
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
@@ -57,7 +58,6 @@ class TestEvolutionLoop:
|
|||||||
max_iterations=1,
|
max_iterations=1,
|
||||||
minibatch_size=5,
|
minibatch_size=5,
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
|
||||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
assert state.best_candidate is not None
|
assert state.best_candidate is not None
|
||||||
@@ -78,28 +78,9 @@ class TestEvolutionLoop:
|
|||||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
initial_eval = EvalResult(
|
high_eval = _make_eval([0.7, 0.8, 0.7, 0.8, 0.9], "ok")
|
||||||
scores=[0.7, 0.8, 0.7, 0.8, 0.9],
|
low_eval = _make_eval([0.2, 0.1, 0.3, 0.2, 0.1], "bad")
|
||||||
feedbacks=["ok"] * 5,
|
evaluator.evaluate = AsyncMock(side_effect=[high_eval, high_eval, low_eval])
|
||||||
trajectories=[
|
|
||||||
Trajectory(f"input{i}", f"output{i}", s, "ok", "prompt")
|
|
||||||
for i, s in enumerate([0.7, 0.8, 0.7, 0.8, 0.9])
|
|
||||||
],
|
|
||||||
)
|
|
||||||
old_eval = EvalResult(
|
|
||||||
scores=[0.7, 0.8, 0.7, 0.8, 0.9],
|
|
||||||
feedbacks=["ok"] * 5,
|
|
||||||
trajectories=[
|
|
||||||
Trajectory(f"input{i}", f"output{i}", s, "ok", "prompt")
|
|
||||||
for i, s in enumerate([0.7, 0.8, 0.7, 0.8, 0.9])
|
|
||||||
],
|
|
||||||
)
|
|
||||||
new_eval = EvalResult(
|
|
||||||
scores=[0.2, 0.1, 0.3, 0.2, 0.1],
|
|
||||||
feedbacks=["bad"] * 5,
|
|
||||||
trajectories=[],
|
|
||||||
)
|
|
||||||
evaluator.evaluate = AsyncMock(side_effect=[initial_eval, old_eval, new_eval])
|
|
||||||
|
|
||||||
loop = EvolutionLoop(
|
loop = EvolutionLoop(
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
@@ -108,7 +89,6 @@ class TestEvolutionLoop:
|
|||||||
max_iterations=1,
|
max_iterations=1,
|
||||||
minibatch_size=5,
|
minibatch_size=5,
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
|
||||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
assert state.best_candidate is not None
|
assert state.best_candidate is not None
|
||||||
@@ -129,14 +109,7 @@ class TestEvolutionLoop:
|
|||||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
perfect_eval = EvalResult(
|
perfect_eval = _make_eval([1.0, 1.0, 1.0, 1.0, 1.0], "perfect")
|
||||||
scores=[1.0, 1.0, 1.0, 1.0, 1.0],
|
|
||||||
feedbacks=["perfect"] * 5,
|
|
||||||
trajectories=[
|
|
||||||
Trajectory(f"input{i}", f"output{i}", 1.0, "perfect", "prompt")
|
|
||||||
for i in range(5)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
||||||
|
|
||||||
loop = EvolutionLoop(
|
loop = EvolutionLoop(
|
||||||
@@ -146,7 +119,226 @@ class TestEvolutionLoop:
|
|||||||
max_iterations=3,
|
max_iterations=3,
|
||||||
minibatch_size=5,
|
minibatch_size=5,
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
|
||||||
await loop.run(seed_prompt, synthetic_pool, task_description)
|
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
mock_proposer_port.propose.assert_not_called()
|
mock_proposer_port.propose.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPopulationEvolution:
|
||||||
|
"""Tests for population-based evolution (population_size > 1)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_population_initialization(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
mock_llm_port: AsyncMock,
|
||||||
|
mock_judge_port: AsyncMock,
|
||||||
|
mock_proposer_port: AsyncMock,
|
||||||
|
mock_mutation_port: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Population is initialized with the right number of candidates."""
|
||||||
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
|
evaluator.evaluate = AsyncMock(
|
||||||
|
return_value=_make_eval([0.5] * 5, "ok")
|
||||||
|
)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer_port,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=0, # no iterations, just initialization
|
||||||
|
minibatch_size=5,
|
||||||
|
population_size=4,
|
||||||
|
mutation_port=mock_mutation_port,
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
# 1 seed + 3 mutations = 4 candidates
|
||||||
|
assert len(state.candidates) == 4
|
||||||
|
assert mock_mutation_port.mutate.call_count == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_population_initialization_uses_proposer_fallback(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
mock_llm_port: AsyncMock,
|
||||||
|
mock_judge_port: AsyncMock,
|
||||||
|
mock_proposer_port: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""When no mutation_port is provided, population init falls back to proposer."""
|
||||||
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
|
evaluator.evaluate = AsyncMock(
|
||||||
|
return_value=_make_eval([0.5] * 5, "ok")
|
||||||
|
)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer_port,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=0,
|
||||||
|
minibatch_size=5,
|
||||||
|
population_size=3,
|
||||||
|
# mutation_port intentionally omitted
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
assert len(state.candidates) == 3
|
||||||
|
assert mock_proposer_port.propose.call_count == 2 # 3-1 = 2 init mutations
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_population_iteration_replaces_worst(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
mock_llm_port: AsyncMock,
|
||||||
|
mock_judge_port: AsyncMock,
|
||||||
|
mock_proposer_port: AsyncMock,
|
||||||
|
mock_crossover_port: AsyncMock,
|
||||||
|
mock_mutation_port: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Crossover child replaces worst candidate when its fitness is higher."""
|
||||||
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
# Sequence:
|
||||||
|
# 1. Initial eval (seed)
|
||||||
|
# 2. Population init: 3 mutation calls use proposer.propose(), NOT evaluator.evaluate
|
||||||
|
# 3. Population iteration: crossover produces child → eval child
|
||||||
|
# Only 2 evaluator.evaluate calls total
|
||||||
|
seed_eval = _make_eval([0.5] * 5, "ok")
|
||||||
|
# Crossover child eval - high score to beat worst
|
||||||
|
child_eval = _make_eval([0.9, 0.9, 0.8, 0.9, 0.8], "great")
|
||||||
|
|
||||||
|
all_evals = [seed_eval, child_eval]
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=all_evals)
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer_port,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=1,
|
||||||
|
minibatch_size=5,
|
||||||
|
population_size=4,
|
||||||
|
crossover_rate=1.0,
|
||||||
|
crossover_port=mock_crossover_port,
|
||||||
|
mutation_rate=0.0, # disable post-crossover mutation for determinism
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
accepted_events = [h for h in state.history if h.get("event") == "pop_accepted"]
|
||||||
|
assert len(accepted_events) >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_population_iteration_rejects_inferior_child(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
mock_llm_port: AsyncMock,
|
||||||
|
mock_judge_port: AsyncMock,
|
||||||
|
mock_proposer_port: AsyncMock,
|
||||||
|
mock_crossover_port: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Inferior child is rejected and doesn't replace any candidate."""
|
||||||
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
seed_eval = _make_eval([0.8] * 5, "ok")
|
||||||
|
# Crossover produces very LOW-scoring child
|
||||||
|
child_eval = _make_eval([0.1] * 5, "terrible")
|
||||||
|
|
||||||
|
all_evals = [seed_eval, child_eval]
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=all_evals)
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer_port,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=1,
|
||||||
|
minibatch_size=5,
|
||||||
|
population_size=4,
|
||||||
|
crossover_rate=1.0,
|
||||||
|
crossover_port=mock_crossover_port,
|
||||||
|
mutation_rate=0.0,
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
rejected_events = [h for h in state.history if h.get("event") == "pop_rejected"]
|
||||||
|
assert len(rejected_events) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiversityScore:
|
||||||
|
"""Tests for the diversity/similarity scoring logic."""
|
||||||
|
|
||||||
|
def test_identical_prompts_have_high_similarity(self) -> None:
|
||||||
|
"""Identical prompts should have very high similarity."""
|
||||||
|
identical = Prompt(text="You are a helpful assistant. Answer the question.")
|
||||||
|
pop_a = Candidate(prompt=identical, best_score=4.0, generation=0)
|
||||||
|
pop_b = Candidate(
|
||||||
|
prompt=Prompt(text="Completely different prompt about data analysis."),
|
||||||
|
best_score=3.0,
|
||||||
|
generation=0,
|
||||||
|
)
|
||||||
|
sim_same = EvolutionLoop._compute_diversity_score(identical, [pop_a, pop_b])
|
||||||
|
# Average includes similarity to the different member, so ~0.5 not 0.9+
|
||||||
|
assert sim_same > 0.3
|
||||||
|
|
||||||
|
def test_different_prompts_have_lower_similarity(self) -> None:
|
||||||
|
"""Different prompts should have lower similarity than identical ones."""
|
||||||
|
prompt_a = Prompt(text="You are a helpful assistant. Answer the question.")
|
||||||
|
prompt_b = Prompt(text="Provide detailed analysis of complex data patterns with precision.")
|
||||||
|
pop_a = Candidate(prompt=prompt_a, best_score=4.0, generation=0)
|
||||||
|
pop_b = Candidate(prompt=prompt_b, best_score=3.0, generation=0)
|
||||||
|
sim_a = EvolutionLoop._compute_diversity_score(prompt_a, [pop_a, pop_b])
|
||||||
|
sim_b = EvolutionLoop._compute_diversity_score(prompt_b, [pop_a, pop_b])
|
||||||
|
# Both should be < 1.0 since they're different
|
||||||
|
assert sim_a < 1.0
|
||||||
|
assert sim_b < 1.0
|
||||||
|
|
||||||
|
def test_single_member_population_returns_1(self) -> None:
|
||||||
|
"""Single-member population always returns 1.0 (no penalty)."""
|
||||||
|
prompt = Prompt(text="Any prompt text here.")
|
||||||
|
pop = [Candidate(prompt=prompt, best_score=1.0, generation=0)]
|
||||||
|
sim = EvolutionLoop._compute_diversity_score(prompt, pop)
|
||||||
|
assert sim == 1.0
|
||||||
|
|
||||||
|
def test_empty_prompt_returns_zero(self) -> None:
|
||||||
|
"""Empty prompt text returns 0.0 when population has >1 member."""
|
||||||
|
prompt = Prompt(text="")
|
||||||
|
pop = [
|
||||||
|
Candidate(prompt=Prompt(text="some text"), best_score=1.0, generation=0),
|
||||||
|
Candidate(prompt=Prompt(text="other text"), best_score=2.0, generation=0),
|
||||||
|
]
|
||||||
|
sim = EvolutionLoop._compute_diversity_score(prompt, pop)
|
||||||
|
assert sim == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptDiff:
|
||||||
|
"""Tests for the static _compute_prompt_diff helper."""
|
||||||
|
|
||||||
|
def test_identical_prompts(self) -> None:
|
||||||
|
result = EvolutionLoop._compute_prompt_diff("hello\nworld", "hello\nworld")
|
||||||
|
assert result["lines_added"] == 0
|
||||||
|
assert result["lines_removed"] == 0
|
||||||
|
assert result["chars_delta"] == 0
|
||||||
|
|
||||||
|
def test_added_lines(self) -> None:
|
||||||
|
result = EvolutionLoop._compute_prompt_diff("hello", "hello\nworld")
|
||||||
|
assert result["lines_added"] == 1
|
||||||
|
assert result["lines_removed"] == 0
|
||||||
|
assert result["chars_delta"] == 6 # "\nworld"
|
||||||
|
|
||||||
|
def test_removed_lines(self) -> None:
|
||||||
|
result = EvolutionLoop._compute_prompt_diff("hello\nworld", "hello")
|
||||||
|
assert result["lines_added"] == 0
|
||||||
|
assert result["lines_removed"] == 1
|
||||||
|
|||||||
133
tests/unit/test_ground_truth_evaluator.py
Normal file
133
tests/unit/test_ground_truth_evaluator.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""Tests for GroundTruthEvaluator — execution + similarity comparison."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.application.ground_truth_evaluator import GroundTruthEvaluator
|
||||||
|
from prometheus.domain.entities import EvalResult, GroundTruthExample, Prompt
|
||||||
|
from prometheus.domain.ports import LLMPort, SimilarityPort
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_executor() -> AsyncMock:
|
||||||
|
port = AsyncMock(spec=LLMPort)
|
||||||
|
port.execute.return_value = "Paris is the capital of France."
|
||||||
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_similarity() -> AsyncMock:
|
||||||
|
port = AsyncMock(spec=SimilarityPort)
|
||||||
|
port.compute.return_value = 0.85
|
||||||
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def gt_dataset() -> list[GroundTruthExample]:
|
||||||
|
return [
|
||||||
|
GroundTruthExample(input_text="What is the capital of France?", expected_output="Paris", id=0),
|
||||||
|
GroundTruthExample(input_text="What is 2+2?", expected_output="4", id=1),
|
||||||
|
GroundTruthExample(input_text="What color is the sky?", expected_output="blue", id=2),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prompt() -> Prompt:
|
||||||
|
return Prompt(text="Answer the following question accurately.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestGroundTruthEvaluator:
|
||||||
|
async def test_evaluate_happy_path(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor,
|
||||||
|
similarity=mock_similarity,
|
||||||
|
max_concurrency=2,
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||||
|
|
||||||
|
assert isinstance(result, EvalResult)
|
||||||
|
assert len(result.scores) == 3
|
||||||
|
assert len(result.feedbacks) == 3
|
||||||
|
assert len(result.trajectories) == 3
|
||||||
|
assert all(s == 0.85 for s in result.scores)
|
||||||
|
assert result.mean_score == pytest.approx(0.85)
|
||||||
|
assert result.total_score == pytest.approx(2.55)
|
||||||
|
|
||||||
|
async def test_executor_called_for_each_input(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor, similarity=mock_similarity,
|
||||||
|
)
|
||||||
|
await evaluator.evaluate(prompt, gt_dataset)
|
||||||
|
assert mock_executor.execute.call_count == 3
|
||||||
|
|
||||||
|
async def test_similarity_called_for_each_output(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor, similarity=mock_similarity,
|
||||||
|
)
|
||||||
|
await evaluator.evaluate(prompt, gt_dataset)
|
||||||
|
assert mock_similarity.compute.call_count == 3
|
||||||
|
|
||||||
|
async def test_execution_error_produces_zero_score(self, mock_similarity, gt_dataset, prompt):
|
||||||
|
failing_executor = AsyncMock(spec=LLMPort)
|
||||||
|
failing_executor.execute.side_effect = RuntimeError("API timeout")
|
||||||
|
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=failing_executor, similarity=mock_similarity,
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||||
|
|
||||||
|
assert len(result.scores) == 3
|
||||||
|
# The similarity adapter is called with the error sentinel
|
||||||
|
assert all(isinstance(s, float) for s in result.scores)
|
||||||
|
assert all("[execution error:" in t.output_text for t in result.trajectories)
|
||||||
|
|
||||||
|
async def test_empty_dataset(self, mock_executor, mock_similarity, prompt):
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor, similarity=mock_similarity,
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, [])
|
||||||
|
assert result.scores == []
|
||||||
|
assert result.mean_score == 0.0
|
||||||
|
assert result.total_score == 0.0
|
||||||
|
|
||||||
|
async def test_trajectory_contains_prompt_used(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor, similarity=mock_similarity,
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||||
|
for t in result.trajectories:
|
||||||
|
assert t.prompt_used == prompt.text
|
||||||
|
|
||||||
|
async def test_scores_clamped_to_unit_range(self, mock_executor, gt_dataset, prompt):
|
||||||
|
# Similarity returns a value > 1.0 (should be clamped)
|
||||||
|
over_similarity = AsyncMock(spec=SimilarityPort)
|
||||||
|
over_similarity.compute.return_value = 1.5
|
||||||
|
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor, similarity=over_similarity,
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||||
|
assert all(0.0 <= s <= 1.0 for s in result.scores)
|
||||||
|
|
||||||
|
async def test_feedback_for_exact_match(self, mock_executor, gt_dataset, prompt):
|
||||||
|
exact_similarity = AsyncMock(spec=SimilarityPort)
|
||||||
|
exact_similarity.compute.return_value = 1.0
|
||||||
|
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor, similarity=exact_similarity,
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||||
|
assert all("Exact match" in fb for fb in result.feedbacks)
|
||||||
|
|
||||||
|
async def test_feedback_for_poor_match(self, mock_executor, gt_dataset, prompt):
|
||||||
|
poor_similarity = AsyncMock(spec=SimilarityPort)
|
||||||
|
poor_similarity.compute.return_value = 0.1
|
||||||
|
|
||||||
|
evaluator = GroundTruthEvaluator(
|
||||||
|
executor=mock_executor, similarity=poor_similarity,
|
||||||
|
)
|
||||||
|
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||||
|
assert all("Poor match" in fb for fb in result.feedbacks)
|
||||||
316
tests/unit/test_holdout_validation.py
Normal file
316
tests/unit/test_holdout_validation.py
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
"""Unit tests for hold-out validation and early stopping."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||||
|
from prometheus.application.evaluator import PromptEvaluator
|
||||||
|
from prometheus.application.evolution import EvolutionLoop
|
||||||
|
from prometheus.domain.entities import (
|
||||||
|
Candidate,
|
||||||
|
EvalResult,
|
||||||
|
Prompt,
|
||||||
|
SyntheticExample,
|
||||||
|
Trajectory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_eval(mean_score: float, n: int = 5) -> EvalResult:
|
||||||
|
"""Helper: create an EvalResult with a given mean score."""
|
||||||
|
scores = [mean_score] * n
|
||||||
|
return EvalResult(
|
||||||
|
scores=scores,
|
||||||
|
feedbacks=["feedback"] * n,
|
||||||
|
trajectories=[
|
||||||
|
Trajectory(f"input{i}", f"output{i}", mean_score, "feedback", "prompt")
|
||||||
|
for i in range(n)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBootstrapSplit:
|
||||||
|
"""Tests for SyntheticBootstrap.split_pool."""
|
||||||
|
|
||||||
|
def test_split_produces_correct_sizes(self):
|
||||||
|
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(20)]
|
||||||
|
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
|
||||||
|
assert len(train) + len(val) == 20
|
||||||
|
assert len(val) == 6 # 20 * 0.3 = 6
|
||||||
|
assert len(train) == 14
|
||||||
|
|
||||||
|
def test_split_zero_fraction_returns_all_train(self):
|
||||||
|
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(10)]
|
||||||
|
train, val = SyntheticBootstrap.split_pool(pool, 0.0)
|
||||||
|
assert len(train) == 10
|
||||||
|
assert len(val) == 0
|
||||||
|
|
||||||
|
def test_split_single_element(self):
|
||||||
|
pool = [SyntheticExample(input_text="only", id=0)]
|
||||||
|
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
|
||||||
|
assert len(train) == 1
|
||||||
|
assert len(val) == 0
|
||||||
|
|
||||||
|
def test_split_deterministic_with_seed(self):
|
||||||
|
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(50)]
|
||||||
|
train1, val1 = SyntheticBootstrap.split_pool(pool, 0.3, rng=MagicMock(wraps=__import__("random").Random(42)))
|
||||||
|
train2, val2 = SyntheticBootstrap.split_pool(pool, 0.3, rng=MagicMock(wraps=__import__("random").Random(42)))
|
||||||
|
assert [ex.id for ex in train1] == [ex.id for ex in train2]
|
||||||
|
assert [ex.id for ex in val1] == [ex.id for ex in val2]
|
||||||
|
|
||||||
|
def test_split_no_overlap(self):
|
||||||
|
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(30)]
|
||||||
|
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
|
||||||
|
train_ids = {ex.id for ex in train}
|
||||||
|
val_ids = {ex.id for ex in val}
|
||||||
|
assert train_ids.isdisjoint(val_ids)
|
||||||
|
assert train_ids | val_ids == {ex.id for ex in pool}
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidationEvaluation:
|
||||||
|
"""Tests for hold-out evaluation during evolution."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validation_pool_evaluated_after_each_iteration(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
mock_llm_port: AsyncMock,
|
||||||
|
mock_judge_port: AsyncMock,
|
||||||
|
mock_proposer_port: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""When a validation pool is provided, the best candidate is evaluated on it."""
|
||||||
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
# Initial eval (train) + validation eval + iteration train eval + new prompt eval + validation eval
|
||||||
|
train_eval = _make_eval(0.5)
|
||||||
|
val_eval = _make_eval(0.6)
|
||||||
|
new_eval = _make_eval(0.7)
|
||||||
|
val_eval_2 = _make_eval(0.65)
|
||||||
|
|
||||||
|
evaluator.evaluate = AsyncMock(
|
||||||
|
side_effect=[train_eval, val_eval, train_eval, new_eval, val_eval_2]
|
||||||
|
)
|
||||||
|
|
||||||
|
validation_pool = synthetic_pool[-6:]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer_port,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=1,
|
||||||
|
minibatch_size=5,
|
||||||
|
)
|
||||||
|
state = await loop.run(
|
||||||
|
seed_prompt, synthetic_pool, task_description,
|
||||||
|
validation_pool=validation_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have validation metrics in state
|
||||||
|
assert state.best_validation_score is not None
|
||||||
|
# History should contain validation_eval entries
|
||||||
|
val_events = [h for h in state.history if h["event"] == "validation_eval"]
|
||||||
|
assert len(val_events) >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_validation_without_pool(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
mock_llm_port: AsyncMock,
|
||||||
|
mock_judge_port: AsyncMock,
|
||||||
|
mock_proposer_port: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Without a validation pool, no validation is performed."""
|
||||||
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
train_eval = _make_eval(0.5)
|
||||||
|
old_eval = _make_eval(0.5)
|
||||||
|
new_eval = _make_eval(0.7)
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=[train_eval, old_eval, new_eval])
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer_port,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=1,
|
||||||
|
minibatch_size=5,
|
||||||
|
)
|
||||||
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
|
assert state.best_validation_score is None
|
||||||
|
assert not state.early_stopped
|
||||||
|
val_events = [h for h in state.history if h["event"] == "validation_eval"]
|
||||||
|
assert len(val_events) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestEarlyStopping:
|
||||||
|
"""Tests for early stopping when validation score degrades."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_early_stop_triggers_on_patience_exceeded(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
mock_llm_port: AsyncMock,
|
||||||
|
mock_judge_port: AsyncMock,
|
||||||
|
mock_proposer_port: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Early stopping triggers when validation doesn't improve for K iterations."""
|
||||||
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
patience = 3
|
||||||
|
# Build eval sequence:
|
||||||
|
# 1. Initial train eval
|
||||||
|
# 2. Initial validation eval (0.5)
|
||||||
|
# Then for each of 3 iterations:
|
||||||
|
# - train eval (current best)
|
||||||
|
# - train eval (new prompt - accepted)
|
||||||
|
# - validation eval (degrading)
|
||||||
|
evals = [
|
||||||
|
_make_eval(0.5), # initial train
|
||||||
|
_make_eval(0.5), # initial validation
|
||||||
|
]
|
||||||
|
for i in range(patience):
|
||||||
|
evals.extend([
|
||||||
|
_make_eval(0.5 + i * 0.1), # current eval (train)
|
||||||
|
_make_eval(0.6 + i * 0.1), # new eval (train) - accepted
|
||||||
|
_make_eval(0.4), # validation eval (degrading)
|
||||||
|
])
|
||||||
|
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||||
|
|
||||||
|
validation_pool = synthetic_pool[-5:]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer_port,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=10, # would go further without early stop
|
||||||
|
minibatch_size=5,
|
||||||
|
early_stop_patience=patience,
|
||||||
|
)
|
||||||
|
state = await loop.run(
|
||||||
|
seed_prompt, synthetic_pool, task_description,
|
||||||
|
validation_pool=validation_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.early_stopped is True
|
||||||
|
assert state.iteration == patience
|
||||||
|
assert state.best_validation_score is not None
|
||||||
|
# Should have an early_stop event in history
|
||||||
|
early_stop_events = [h for h in state.history if h["event"] == "early_stop"]
|
||||||
|
assert len(early_stop_events) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_early_stop_does_not_trigger_when_improving(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
mock_llm_port: AsyncMock,
|
||||||
|
mock_judge_port: AsyncMock,
|
||||||
|
mock_proposer_port: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""When validation keeps improving, early stopping does not trigger."""
|
||||||
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
evals = [
|
||||||
|
_make_eval(0.3), # initial train
|
||||||
|
_make_eval(0.3), # initial validation
|
||||||
|
]
|
||||||
|
# 3 iterations, each with improving validation
|
||||||
|
for i in range(3):
|
||||||
|
evals.extend([
|
||||||
|
_make_eval(0.3 + i * 0.1), # current train eval
|
||||||
|
_make_eval(0.4 + i * 0.1), # new train eval (accepted)
|
||||||
|
_make_eval(0.3 + (i + 1) * 0.1), # validation eval (improving)
|
||||||
|
])
|
||||||
|
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||||
|
|
||||||
|
validation_pool = synthetic_pool[-5:]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer_port,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=3,
|
||||||
|
minibatch_size=5,
|
||||||
|
early_stop_patience=5,
|
||||||
|
)
|
||||||
|
state = await loop.run(
|
||||||
|
seed_prompt, synthetic_pool, task_description,
|
||||||
|
validation_pool=validation_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.early_stopped is False
|
||||||
|
assert state.iteration == 3
|
||||||
|
assert state.best_validation_score is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validation_patience_resets_on_improvement(
|
||||||
|
self,
|
||||||
|
seed_prompt: Prompt,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
mock_llm_port: AsyncMock,
|
||||||
|
mock_judge_port: AsyncMock,
|
||||||
|
mock_proposer_port: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Patience counter resets when validation improves after degrading."""
|
||||||
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||||
|
|
||||||
|
evals = [
|
||||||
|
_make_eval(0.3), # initial train
|
||||||
|
_make_eval(0.3), # initial validation
|
||||||
|
# iter 1: degrade
|
||||||
|
_make_eval(0.3), # current train
|
||||||
|
_make_eval(0.5), # new train (accepted)
|
||||||
|
_make_eval(0.2), # validation degrade (patience=1)
|
||||||
|
# iter 2: degrade
|
||||||
|
_make_eval(0.5), # current train
|
||||||
|
_make_eval(0.6), # new train (accepted)
|
||||||
|
_make_eval(0.2), # validation degrade (patience=2)
|
||||||
|
# iter 3: improve! (resets patience)
|
||||||
|
_make_eval(0.6), # current train
|
||||||
|
_make_eval(0.7), # new train (accepted)
|
||||||
|
_make_eval(0.4), # validation improve (patience=0)
|
||||||
|
# iter 4: degrade again
|
||||||
|
_make_eval(0.7), # current train
|
||||||
|
_make_eval(0.8), # new train (accepted)
|
||||||
|
_make_eval(0.2), # validation degrade (patience=1)
|
||||||
|
]
|
||||||
|
|
||||||
|
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||||
|
validation_pool = synthetic_pool[-5:]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=mock_proposer_port,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=4,
|
||||||
|
minibatch_size=5,
|
||||||
|
early_stop_patience=3,
|
||||||
|
)
|
||||||
|
state = await loop.run(
|
||||||
|
seed_prompt, synthetic_pool, task_description,
|
||||||
|
validation_pool=validation_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.early_stopped is False
|
||||||
|
assert state.iteration == 4
|
||||||
189
tests/unit/test_logging.py
Normal file
189
tests/unit/test_logging.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""Unit tests for structured logging configuration."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from prometheus.cli.logging_setup import configure_logging, get_logger
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigureLogging:
|
||||||
|
def _count_handlers(self, name: str = "prometheus") -> int:
|
||||||
|
return len(logging.getLogger(name).handlers)
|
||||||
|
|
||||||
|
def test_default_creates_console_handler(self) -> None:
|
||||||
|
configure_logging(level=logging.INFO)
|
||||||
|
prom = logging.getLogger("prometheus")
|
||||||
|
assert len(prom.handlers) == 1
|
||||||
|
assert isinstance(prom.handlers[0], logging.StreamHandler)
|
||||||
|
prom.handlers.clear()
|
||||||
|
|
||||||
|
def test_json_format_produces_valid_json(self, capsys) -> None:
|
||||||
|
configure_logging(level=logging.INFO, log_format="json")
|
||||||
|
logger = get_logger("test_json")
|
||||||
|
logger.info("hello", extra={"structured": {"key": "value"}})
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
# Output goes to stderr
|
||||||
|
line = captured.err.strip().split("\n")[-1]
|
||||||
|
data = json.loads(line)
|
||||||
|
assert data["message"] == "hello"
|
||||||
|
assert data["structured"]["key"] == "value"
|
||||||
|
assert data["level"] == "INFO"
|
||||||
|
assert "timestamp" in data
|
||||||
|
|
||||||
|
logging.getLogger("prometheus").handlers.clear()
|
||||||
|
|
||||||
|
def test_text_format_includes_structured_extras(self, capsys) -> None:
|
||||||
|
configure_logging(level=logging.INFO, log_format="text")
|
||||||
|
logger = get_logger("test_text")
|
||||||
|
logger.info("msg", extra={"structured": {"foo": "bar"}})
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "foo=bar" in captured.err
|
||||||
|
|
||||||
|
logging.getLogger("prometheus").handlers.clear()
|
||||||
|
|
||||||
|
def test_debug_level_shows_debug_messages(self, capsys) -> None:
|
||||||
|
configure_logging(level=logging.DEBUG)
|
||||||
|
logger = get_logger("test_debug")
|
||||||
|
logger.debug("debug msg")
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "debug msg" in captured.err
|
||||||
|
|
||||||
|
logging.getLogger("prometheus").handlers.clear()
|
||||||
|
|
||||||
|
def test_warning_level_hides_debug_messages(self, capsys) -> None:
|
||||||
|
configure_logging(level=logging.WARNING)
|
||||||
|
logger = get_logger("test_warn")
|
||||||
|
logger.debug("should not appear")
|
||||||
|
logger.info("also hidden")
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "should not appear" not in captured.err
|
||||||
|
assert "also hidden" not in captured.err
|
||||||
|
|
||||||
|
logging.getLogger("prometheus").handlers.clear()
|
||||||
|
|
||||||
|
def test_file_handler_writes_to_file(self, tmp_path: Path) -> None:
|
||||||
|
log_file = tmp_path / "test.log"
|
||||||
|
configure_logging(level=logging.INFO, log_file=str(log_file))
|
||||||
|
logger = get_logger("test_file")
|
||||||
|
logger.info("file message")
|
||||||
|
|
||||||
|
prom = logging.getLogger("prometheus")
|
||||||
|
# Flush handlers
|
||||||
|
for h in prom.handlers:
|
||||||
|
h.flush()
|
||||||
|
prom.handlers.clear()
|
||||||
|
|
||||||
|
content = log_file.read_text()
|
||||||
|
assert "file message" in content
|
||||||
|
|
||||||
|
def test_json_file_output(self, tmp_path: Path) -> None:
|
||||||
|
log_file = tmp_path / "test.json.log"
|
||||||
|
configure_logging(level=logging.INFO, log_format="json", log_file=str(log_file))
|
||||||
|
logger = get_logger("test_json_file")
|
||||||
|
logger.info("json file msg", extra={"structured": {"x": 1}})
|
||||||
|
|
||||||
|
prom = logging.getLogger("prometheus")
|
||||||
|
for h in prom.handlers:
|
||||||
|
h.flush()
|
||||||
|
prom.handlers.clear()
|
||||||
|
|
||||||
|
content = log_file.read_text().strip()
|
||||||
|
data = json.loads(content)
|
||||||
|
assert data["message"] == "json file msg"
|
||||||
|
assert data["structured"]["x"] == 1
|
||||||
|
|
||||||
|
def test_reconfigure_clears_old_handlers(self) -> None:
|
||||||
|
configure_logging(level=logging.INFO)
|
||||||
|
configure_logging(level=logging.DEBUG)
|
||||||
|
prom = logging.getLogger("prometheus")
|
||||||
|
assert len(prom.handlers) == 1
|
||||||
|
prom.handlers.clear()
|
||||||
|
|
||||||
|
def test_propagate_false_prevents_duplicate_output(self, capsys) -> None:
|
||||||
|
configure_logging(level=logging.INFO)
|
||||||
|
prom = logging.getLogger("prometheus")
|
||||||
|
assert prom.propagate is False
|
||||||
|
prom.handlers.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLogger:
|
||||||
|
def test_returns_child_of_prometheus(self) -> None:
|
||||||
|
logger = get_logger("mymodule")
|
||||||
|
assert logger.name == "prometheus.mymodule"
|
||||||
|
|
||||||
|
def test_inherits_level_from_parent(self) -> None:
|
||||||
|
configure_logging(level=logging.DEBUG)
|
||||||
|
logger = get_logger("child")
|
||||||
|
assert logger.getEffectiveLevel() <= logging.DEBUG
|
||||||
|
logging.getLogger("prometheus").handlers.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonFormatter:
|
||||||
|
def test_exception_included(self, capsys) -> None:
|
||||||
|
configure_logging(level=logging.ERROR, log_format="json")
|
||||||
|
logger = get_logger("test_exc")
|
||||||
|
try:
|
||||||
|
raise ValueError("boom")
|
||||||
|
except ValueError:
|
||||||
|
logger.error("failed", exc_info=True)
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
line = captured.err.strip().split("\n")[-1]
|
||||||
|
data = json.loads(line)
|
||||||
|
assert "ValueError: boom" in data["exception"]
|
||||||
|
|
||||||
|
logging.getLogger("prometheus").handlers.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoggingCLIIntegration:
|
||||||
|
"""Tests for CLI flags that configure logging."""
|
||||||
|
|
||||||
|
def test_verbose_flag_enables_info(self, tmp_path: Path) -> None:
|
||||||
|
"""Simulate what -v does — configure_logging at INFO level."""
|
||||||
|
configure_logging(level=logging.INFO)
|
||||||
|
logger = get_logger("evolution")
|
||||||
|
logger.info("test message")
|
||||||
|
|
||||||
|
prom = logging.getLogger("prometheus")
|
||||||
|
assert len(prom.handlers) == 1
|
||||||
|
prom.handlers.clear()
|
||||||
|
|
||||||
|
def test_debug_flag_enables_debug(self) -> None:
|
||||||
|
"""Simulate what --debug does — configure_logging at DEBUG level."""
|
||||||
|
configure_logging(level=logging.DEBUG)
|
||||||
|
logger = get_logger("evolution")
|
||||||
|
logger.debug("debug message")
|
||||||
|
|
||||||
|
prom = logging.getLogger("prometheus")
|
||||||
|
assert prom.level == logging.DEBUG
|
||||||
|
prom.handlers.clear()
|
||||||
|
|
||||||
|
def test_log_format_invalid_rejected(self) -> None:
|
||||||
|
"""Invalid log_format should be caught by OptimizationConfig validator."""
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from prometheus.application.dto import OptimizationConfig
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError, match="log_format must be one of"):
|
||||||
|
OptimizationConfig(
|
||||||
|
seed_prompt="a",
|
||||||
|
task_description="b",
|
||||||
|
log_format="xml",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_log_format_text_and_json_accepted(self) -> None:
|
||||||
|
"""Both text and json log_format values should be valid."""
|
||||||
|
from prometheus.application.dto import OptimizationConfig
|
||||||
|
|
||||||
|
for fmt in ("text", "json"):
|
||||||
|
config = OptimizationConfig(
|
||||||
|
seed_prompt="a", task_description="b", log_format=fmt,
|
||||||
|
)
|
||||||
|
assert config.log_format == fmt
|
||||||
96
tests/unit/test_scoring_extended.py
Normal file
96
tests/unit/test_scoring_extended.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""Additional unit tests for scoring edge cases."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.domain.entities import EvalResult, Trajectory
|
||||||
|
from prometheus.domain.scoring import normalize_score, should_accept
|
||||||
|
|
||||||
|
|
||||||
|
def _make_eval(scores: list[float]) -> EvalResult:
|
||||||
|
return EvalResult(
|
||||||
|
scores=scores,
|
||||||
|
feedbacks=[""] * len(scores),
|
||||||
|
trajectories=[
|
||||||
|
Trajectory(f"in{i}", f"out{i}", s, "", "p")
|
||||||
|
for i, s in enumerate(scores)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestShouldAcceptEdgeCases:
|
||||||
|
"""Extended edge-case tests for should_accept."""
|
||||||
|
|
||||||
|
def test_tiny_improvement_accepted(self) -> None:
|
||||||
|
old = _make_eval([0.5])
|
||||||
|
new = _make_eval([0.5001])
|
||||||
|
assert should_accept(old, new) is True
|
||||||
|
|
||||||
|
def test_tiny_improvement_below_threshold(self) -> None:
|
||||||
|
old = _make_eval([0.5])
|
||||||
|
new = _make_eval([0.5001])
|
||||||
|
assert should_accept(old, new, min_improvement=0.01) is False
|
||||||
|
|
||||||
|
def test_zero_scores_equal(self) -> None:
|
||||||
|
old = _make_eval([0.0, 0.0])
|
||||||
|
new = _make_eval([0.0, 0.0])
|
||||||
|
assert should_accept(old, new) is False
|
||||||
|
|
||||||
|
def test_negative_to_zero_not_accepted(self) -> None:
|
||||||
|
"""Scores should be [0,1] but test should_accept with edge values."""
|
||||||
|
old = _make_eval([-0.1])
|
||||||
|
new = _make_eval([0.0])
|
||||||
|
assert should_accept(old, new) is True
|
||||||
|
|
||||||
|
def test_large_improvement(self) -> None:
|
||||||
|
old = _make_eval([0.0, 0.0, 0.0])
|
||||||
|
new = _make_eval([1.0, 1.0, 1.0])
|
||||||
|
assert should_accept(old, new) is True
|
||||||
|
|
||||||
|
def test_single_score_improvement(self) -> None:
|
||||||
|
old = _make_eval([0.4])
|
||||||
|
new = _make_eval([0.5])
|
||||||
|
assert should_accept(old, new) is True
|
||||||
|
|
||||||
|
def test_min_improvement_exactly_met(self) -> None:
|
||||||
|
"""When improvement exactly equals min_improvement, still rejected (strict >)."""
|
||||||
|
old = _make_eval([0.5])
|
||||||
|
new = _make_eval([0.7])
|
||||||
|
assert should_accept(old, new, min_improvement=0.2) is False
|
||||||
|
|
||||||
|
def test_min_improvement_just_over(self) -> None:
|
||||||
|
old = _make_eval([0.5])
|
||||||
|
new = _make_eval([0.7001])
|
||||||
|
assert should_accept(old, new, min_improvement=0.2) is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeScoreEdgeCases:
|
||||||
|
"""Extended edge-case tests for normalize_score."""
|
||||||
|
|
||||||
|
def test_exact_bounds(self) -> None:
|
||||||
|
assert normalize_score(0.0) == 0.0
|
||||||
|
assert normalize_score(1.0) == 1.0
|
||||||
|
|
||||||
|
def test_very_large_value(self) -> None:
|
||||||
|
assert normalize_score(1e10) == 1.0
|
||||||
|
|
||||||
|
def test_very_negative_value(self) -> None:
|
||||||
|
assert normalize_score(-1e10) == 0.0
|
||||||
|
|
||||||
|
def test_custom_bounds_at_edges(self) -> None:
|
||||||
|
assert normalize_score(5.0, min_val=0.0, max_val=10.0) == 5.0
|
||||||
|
assert normalize_score(0.0, min_val=0.0, max_val=10.0) == 0.0
|
||||||
|
assert normalize_score(10.0, min_val=0.0, max_val=10.0) == 10.0
|
||||||
|
|
||||||
|
def test_negative_custom_range(self) -> None:
|
||||||
|
assert normalize_score(0.0, min_val=-5.0, max_val=5.0) == 0.0
|
||||||
|
assert normalize_score(-3.0, min_val=-5.0, max_val=5.0) == -3.0
|
||||||
|
assert normalize_score(-10.0, min_val=-5.0, max_val=5.0) == -5.0
|
||||||
|
|
||||||
|
def test_zero_span_range(self) -> None:
|
||||||
|
"""When min == max, clamps to min."""
|
||||||
|
assert normalize_score(5.0, min_val=5.0, max_val=5.0) == 5.0
|
||||||
|
assert normalize_score(0.0, min_val=5.0, max_val=5.0) == 5.0
|
||||||
|
|
||||||
|
def test_fractional_score(self) -> None:
|
||||||
|
assert normalize_score(0.3333) == pytest.approx(0.3333)
|
||||||
133
tests/unit/test_similarity.py
Normal file
133
tests/unit/test_similarity.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""Tests for similarity adapters — exact, BLEU, ROUGE-L, cosine."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.infrastructure.similarity import (
|
||||||
|
BleuSimilarity,
|
||||||
|
CosineSimilarity,
|
||||||
|
ExactMatchSimilarity,
|
||||||
|
RougeLSimilarity,
|
||||||
|
create_similarity_adapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExactMatchSimilarity:
|
||||||
|
def test_exact_match(self):
|
||||||
|
s = ExactMatchSimilarity()
|
||||||
|
assert s.compute("Hello World", "Hello World") == 1.0
|
||||||
|
|
||||||
|
def test_case_insensitive(self):
|
||||||
|
s = ExactMatchSimilarity()
|
||||||
|
assert s.compute("hello world", "HELLO WORLD") == 1.0
|
||||||
|
|
||||||
|
def test_whitespace_trimmed(self):
|
||||||
|
s = ExactMatchSimilarity()
|
||||||
|
assert s.compute(" hello ", "hello") == 1.0
|
||||||
|
|
||||||
|
def test_no_match(self):
|
||||||
|
s = ExactMatchSimilarity()
|
||||||
|
assert s.compute("hello", "world") == 0.0
|
||||||
|
|
||||||
|
def test_partial_no_match(self):
|
||||||
|
s = ExactMatchSimilarity()
|
||||||
|
assert s.compute("hello world", "hello") == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestBleuSimilarity:
|
||||||
|
def test_perfect_match(self):
|
||||||
|
s = BleuSimilarity()
|
||||||
|
assert s.compute("the cat sat on the mat", "the cat sat on the mat") == 1.0
|
||||||
|
|
||||||
|
def test_no_overlap(self):
|
||||||
|
s = BleuSimilarity()
|
||||||
|
assert s.compute("aaa bbb ccc", "ddd eee fff") == 0.0
|
||||||
|
|
||||||
|
def test_partial_overlap(self):
|
||||||
|
s = BleuSimilarity()
|
||||||
|
score = s.compute("the cat sat", "the cat")
|
||||||
|
assert 0.0 < score < 1.0
|
||||||
|
|
||||||
|
def test_empty_prediction(self):
|
||||||
|
s = BleuSimilarity()
|
||||||
|
assert s.compute("", "hello world") == 0.0
|
||||||
|
|
||||||
|
def test_empty_expected(self):
|
||||||
|
s = BleuSimilarity()
|
||||||
|
assert s.compute("hello world", "") == 0.0
|
||||||
|
|
||||||
|
def test_both_empty(self):
|
||||||
|
s = BleuSimilarity()
|
||||||
|
assert s.compute("", "") == 0.0
|
||||||
|
|
||||||
|
def test_shorter_prediction_gets_brevity_penalty(self):
|
||||||
|
s = BleuSimilarity()
|
||||||
|
short = s.compute("cat", "the cat sat on the mat")
|
||||||
|
full = s.compute("the cat sat on the mat", "the cat sat on the mat")
|
||||||
|
assert short < full
|
||||||
|
|
||||||
|
|
||||||
|
class TestRougeLSimilarity:
|
||||||
|
def test_perfect_match(self):
|
||||||
|
s = RougeLSimilarity()
|
||||||
|
assert s.compute("the cat sat", "the cat sat") == 1.0
|
||||||
|
|
||||||
|
def test_no_overlap(self):
|
||||||
|
s = RougeLSimilarity()
|
||||||
|
assert s.compute("aaa bbb", "ccc ddd") == 0.0
|
||||||
|
|
||||||
|
def test_partial_overlap(self):
|
||||||
|
s = RougeLSimilarity()
|
||||||
|
score = s.compute("the cat sat on the mat", "the cat on the rug")
|
||||||
|
assert 0.0 < score < 1.0
|
||||||
|
|
||||||
|
def test_empty_prediction(self):
|
||||||
|
s = RougeLSimilarity()
|
||||||
|
assert s.compute("", "hello") == 0.0
|
||||||
|
|
||||||
|
def test_subsequence(self):
|
||||||
|
s = RougeLSimilarity()
|
||||||
|
# "cat mat" is a subsequence of "the cat sat on the mat"
|
||||||
|
score = s.compute("cat mat", "the cat sat on the mat")
|
||||||
|
assert score > 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCosineSimilarity:
|
||||||
|
def test_identical_texts(self):
|
||||||
|
s = CosineSimilarity()
|
||||||
|
assert s.compute("hello world", "hello world") == pytest.approx(1.0)
|
||||||
|
|
||||||
|
def test_no_overlap(self):
|
||||||
|
s = CosineSimilarity()
|
||||||
|
assert s.compute("aaa bbb", "ccc ddd") == 0.0
|
||||||
|
|
||||||
|
def test_partial_overlap(self):
|
||||||
|
s = CosineSimilarity()
|
||||||
|
score = s.compute("hello world foo", "hello world bar")
|
||||||
|
assert 0.0 < score < 1.0
|
||||||
|
|
||||||
|
def test_empty_prediction(self):
|
||||||
|
s = CosineSimilarity()
|
||||||
|
assert s.compute("", "hello") == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSimilarityAdapter:
|
||||||
|
def test_create_exact(self):
|
||||||
|
adapter = create_similarity_adapter("exact")
|
||||||
|
assert isinstance(adapter, ExactMatchSimilarity)
|
||||||
|
|
||||||
|
def test_create_bleu(self):
|
||||||
|
adapter = create_similarity_adapter("bleu")
|
||||||
|
assert isinstance(adapter, BleuSimilarity)
|
||||||
|
|
||||||
|
def test_create_rouge_l(self):
|
||||||
|
adapter = create_similarity_adapter("rouge_l")
|
||||||
|
assert isinstance(adapter, RougeLSimilarity)
|
||||||
|
|
||||||
|
def test_create_cosine(self):
|
||||||
|
adapter = create_similarity_adapter("cosine")
|
||||||
|
assert isinstance(adapter, CosineSimilarity)
|
||||||
|
|
||||||
|
def test_unknown_metric_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Unknown eval metric"):
|
||||||
|
create_similarity_adapter("nonexistent")
|
||||||
233
tests/unit/test_use_cases.py
Normal file
233
tests/unit/test_use_cases.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""Unit tests for OptimizePromptUseCase — direct orchestration tests."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||||
|
from prometheus.application.dto import OptimizationConfig, OptimizationResult
|
||||||
|
from prometheus.application.evaluator import PromptEvaluator
|
||||||
|
from prometheus.application.evolution import EvolutionLoop
|
||||||
|
from prometheus.application.use_cases import OptimizePromptUseCase
|
||||||
|
from prometheus.domain.entities import (
|
||||||
|
Candidate,
|
||||||
|
EvalResult,
|
||||||
|
OptimizationState,
|
||||||
|
Prompt,
|
||||||
|
SyntheticExample,
|
||||||
|
Trajectory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_eval(scores: list[float]) -> EvalResult:
|
||||||
|
return EvalResult(
|
||||||
|
scores=scores,
|
||||||
|
feedbacks=["feedback"] * len(scores),
|
||||||
|
trajectories=[
|
||||||
|
Trajectory(f"in{i}", f"out{i}", s, "feedback", "prompt")
|
||||||
|
for i, s in enumerate(scores)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_state(
|
||||||
|
iterations: int = 3,
|
||||||
|
initial_score: float = 0.3,
|
||||||
|
final_score: float = 0.8,
|
||||||
|
accepted: bool = True,
|
||||||
|
) -> OptimizationState:
|
||||||
|
seed = Candidate(prompt=Prompt(text="seed"), best_score=initial_score, generation=0)
|
||||||
|
best = Candidate(
|
||||||
|
prompt=Prompt(text="optimized" if accepted else "seed"),
|
||||||
|
best_score=final_score,
|
||||||
|
generation=iterations if accepted else 0,
|
||||||
|
)
|
||||||
|
history = []
|
||||||
|
for i in range(1, iterations + 1):
|
||||||
|
event = "accepted" if accepted else "rejected"
|
||||||
|
history.append({"iteration": i, "event": event, "old_score": 0.3, "new_score": 0.8})
|
||||||
|
|
||||||
|
return OptimizationState(
|
||||||
|
iteration=iterations,
|
||||||
|
best_candidate=best,
|
||||||
|
candidates=[seed, best] if accepted else [seed],
|
||||||
|
total_llm_calls=iterations * 11 + 10,
|
||||||
|
history=history,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOptimizePromptUseCaseExecute:
|
||||||
|
"""Tests for the execute() orchestration method."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_evaluator(self) -> MagicMock:
|
||||||
|
return MagicMock(spec=PromptEvaluator)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_proposer(self) -> MagicMock:
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_bootstrap(self) -> MagicMock:
|
||||||
|
return MagicMock(spec=SyntheticBootstrap)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def use_case(
|
||||||
|
self,
|
||||||
|
mock_evaluator: MagicMock,
|
||||||
|
mock_proposer: MagicMock,
|
||||||
|
mock_bootstrap: MagicMock,
|
||||||
|
) -> OptimizePromptUseCase:
|
||||||
|
return OptimizePromptUseCase(
|
||||||
|
evaluator=mock_evaluator,
|
||||||
|
proposer=mock_proposer,
|
||||||
|
bootstrap=mock_bootstrap,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config(self) -> OptimizationConfig:
|
||||||
|
return OptimizationConfig(
|
||||||
|
seed_prompt="Answer the question.",
|
||||||
|
task_description="Q&A task",
|
||||||
|
max_iterations=5,
|
||||||
|
n_synthetic_inputs=20,
|
||||||
|
minibatch_size=5,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_optimization_result(
|
||||||
|
self,
|
||||||
|
use_case: OptimizePromptUseCase,
|
||||||
|
mock_bootstrap: MagicMock,
|
||||||
|
config: OptimizationConfig,
|
||||||
|
) -> None:
|
||||||
|
mock_bootstrap.run.return_value = [
|
||||||
|
SyntheticExample(input_text=f"q{i}", id=i) for i in range(20)
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_state = _make_state(iterations=3, initial_score=0.3, final_score=0.9)
|
||||||
|
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||||
|
result = await use_case.execute(config)
|
||||||
|
|
||||||
|
assert isinstance(result, OptimizationResult)
|
||||||
|
assert result.initial_prompt == "Answer the question."
|
||||||
|
assert result.final_score == 0.9
|
||||||
|
assert result.improvement == pytest.approx(0.6)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bootstrap_called_with_config_params(
|
||||||
|
self,
|
||||||
|
use_case: OptimizePromptUseCase,
|
||||||
|
mock_bootstrap: MagicMock,
|
||||||
|
config: OptimizationConfig,
|
||||||
|
) -> None:
|
||||||
|
mock_bootstrap.run.return_value = []
|
||||||
|
mock_state = _make_state()
|
||||||
|
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||||
|
await use_case.execute(config)
|
||||||
|
|
||||||
|
mock_bootstrap.run.assert_called_once_with(
|
||||||
|
task_description="Q&A task",
|
||||||
|
n_examples=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_evolution_loop_configured_from_config(
|
||||||
|
self,
|
||||||
|
use_case: OptimizePromptUseCase,
|
||||||
|
mock_bootstrap: MagicMock,
|
||||||
|
config: OptimizationConfig,
|
||||||
|
) -> None:
|
||||||
|
mock_bootstrap.run.return_value = []
|
||||||
|
mock_state = _make_state()
|
||||||
|
|
||||||
|
with patch.object(EvolutionLoop, "run", return_value=mock_state) as mock_run:
|
||||||
|
await use_case.execute(config)
|
||||||
|
|
||||||
|
# Verify the loop was instantiated with correct params
|
||||||
|
mock_run.assert_called_once()
|
||||||
|
call_args = mock_run.call_args
|
||||||
|
seed_prompt = call_args[0][0]
|
||||||
|
assert seed_prompt.text == "Answer the question."
|
||||||
|
synthetic_pool = call_args[0][1]
|
||||||
|
assert len(synthetic_pool) == 0 # bootstrap returned empty
|
||||||
|
assert call_args[0][2] == "Q&A task"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_total_llm_calls_includes_bootstrap_call(
|
||||||
|
self,
|
||||||
|
use_case: OptimizePromptUseCase,
|
||||||
|
mock_bootstrap: MagicMock,
|
||||||
|
config: OptimizationConfig,
|
||||||
|
) -> None:
|
||||||
|
mock_bootstrap.run.return_value = []
|
||||||
|
mock_state = _make_state(iterations=3)
|
||||||
|
# total_llm_calls from state + 1 for bootstrap
|
||||||
|
expected = mock_state.total_llm_calls + 1
|
||||||
|
|
||||||
|
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||||
|
result = await use_case.execute(config)
|
||||||
|
|
||||||
|
assert result.total_llm_calls == expected
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_candidates_fallback(
|
||||||
|
self,
|
||||||
|
use_case: OptimizePromptUseCase,
|
||||||
|
mock_bootstrap: MagicMock,
|
||||||
|
config: OptimizationConfig,
|
||||||
|
) -> None:
|
||||||
|
mock_bootstrap.run.return_value = [
|
||||||
|
SyntheticExample(input_text=f"q{i}", id=i) for i in range(20)
|
||||||
|
]
|
||||||
|
mock_state = OptimizationState(
|
||||||
|
iteration=0,
|
||||||
|
best_candidate=None,
|
||||||
|
candidates=[],
|
||||||
|
total_llm_calls=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||||
|
result = await use_case.execute(config)
|
||||||
|
|
||||||
|
assert result.optimized_prompt == "Answer the question."
|
||||||
|
assert result.initial_score == 0.0
|
||||||
|
assert result.final_score == 0.0
|
||||||
|
assert result.improvement == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_iterations_used_matches_state(
|
||||||
|
self,
|
||||||
|
use_case: OptimizePromptUseCase,
|
||||||
|
mock_bootstrap: MagicMock,
|
||||||
|
config: OptimizationConfig,
|
||||||
|
) -> None:
|
||||||
|
mock_bootstrap.run.return_value = []
|
||||||
|
mock_state = _make_state(iterations=7)
|
||||||
|
|
||||||
|
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||||
|
result = await use_case.execute(config)
|
||||||
|
|
||||||
|
assert result.iterations_used == 7
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_passed_through(
|
||||||
|
self,
|
||||||
|
use_case: OptimizePromptUseCase,
|
||||||
|
mock_bootstrap: MagicMock,
|
||||||
|
config: OptimizationConfig,
|
||||||
|
) -> None:
|
||||||
|
mock_bootstrap.run.return_value = []
|
||||||
|
history = [
|
||||||
|
{"iteration": 1, "event": "accepted"},
|
||||||
|
{"iteration": 2, "event": "rejected"},
|
||||||
|
]
|
||||||
|
mock_state = _make_state()
|
||||||
|
mock_state.history = history
|
||||||
|
|
||||||
|
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||||
|
result = await use_case.execute(config)
|
||||||
|
|
||||||
|
assert result.history == history
|
||||||
Reference in New Issue
Block a user