Compare commits
6 Commits
837a44970f
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a5bf2ad59c | ||
|
|
b9745566c8 | ||
|
|
336774a164 | ||
|
|
c92ca4a2b8 | ||
|
|
e2d111ce5b | ||
|
|
f516ca4be6 |
369
docs/FEATURE_ROADMAP.md
Normal file
369
docs/FEATURE_ROADMAP.md
Normal file
@@ -0,0 +1,369 @@
|
||||
# PROMETHEUS Feature Roadmap
|
||||
|
||||
> Complete codebase review — features needed for production-grade prompt optimization.
|
||||
> Generated from v0.1.0 architecture review (2026-03-29).
|
||||
|
||||
---
|
||||
|
||||
## Legend
|
||||
|
||||
| Marker | Meaning |
|
||||
|--------|---------|
|
||||
| **CLI** | Exposed as a CLI option/flag |
|
||||
| **Config** | YAML config field |
|
||||
| **Internal** | No user-facing surface, architectural improvement |
|
||||
| **P1** | Critical / must-have for reliability |
|
||||
| **P2** | High value, should-have |
|
||||
| **P3** | Nice-to-have, deferred to later versions |
|
||||
|
||||
---
|
||||
|
||||
## 1. Multi-Model Routing (P1)
|
||||
|
||||
**Current state:** `OptimizationConfig` defines four model slots (`task_model`, `judge_model`, `proposer_model`, `synth_model`), but `cli/app.py` only configures a single global DSPy LM from `task_model`. All adapters silently use the same model regardless of config.
|
||||
|
||||
**Feature:**
|
||||
- Each adapter (`DSPyLLMAdapter`, `DSPyJudgeAdapter`, `DSPyProposerAdapter`, `DSPySyntheticAdapter`) must instantiate its own `dspy.LM` from the corresponding config field.
|
||||
- Support per-model `api_base` and `api_key_env` overrides (e.g., judge on GPT-4o, propose on a cheaper model).
|
||||
|
||||
**Surface:** Config (already partially defined) — `judge_model`, `proposer_model`, `synth_model` become functional. No new CLI flags needed; the YAML already has the fields.
|
||||
|
||||
**Scope:** Infrastructure layer (`llm_adapter.py`, `judge_adapter.py`, `proposer_adapter.py`, `synth_adapter.py`) + `cli/app.py` DI wiring.
|
||||
|
||||
---
|
||||
|
||||
## 2. Async / Parallel Execution (P1)
|
||||
|
||||
**Current state:** All LLM calls (execute, judge, propose) are sequential. A single iteration with `minibatch_size=5` makes ~11 sequential LLM calls. Wall-clock time scales linearly with minibatch size.
|
||||
|
||||
**Feature:**
|
||||
- Parallelize execution of the prompt across a minibatch (`asyncio.gather` or `dspy.Parallel`).
|
||||
- Parallelize judge calls within a batch.
|
||||
- Keep the proposer sequential (single call per iteration).
|
||||
|
||||
**Surface:** Internal. Optionally exposed via `--max-concurrency` CLI flag and `max_concurrency` YAML field.
|
||||
|
||||
**Scope:** `evaluator.py`, `judge_adapter.py`, `llm_adapter.py`.
|
||||
|
||||
---
|
||||
|
||||
## 3. Robust Error Handling & Retry (P1)
|
||||
|
||||
**Current state:** The evolution loop catches broad `Exception` per iteration and logs it, then continues. Individual LLM call failures (timeouts, rate limits, malformed responses) are not retried. DSPy module fallbacks only cover parsing, not network errors.
|
||||
|
||||
**Feature:**
|
||||
- Retry with exponential backoff for transient errors (rate limits, timeouts, 5xx).
|
||||
- Configurable `max_retries` and `retry_delay_base`.
|
||||
- Circuit breaker: if N consecutive iterations fail, pause and alert.
|
||||
- Per-call error isolation: one bad minibatch item shouldn't fail the whole evaluation.
|
||||
|
||||
**Surface:** `--max-retries` CLI flag, `max_retries` Config field. `--error-strategy` (skip | retry | abort) CLI flag.
|
||||
|
||||
**Scope:** Infrastructure adapters + evolution loop.
|
||||
|
||||
---
|
||||
|
||||
## 4. Checkpoint & Resume (P2)
|
||||
|
||||
**Current state:** If a long optimization run crashes or is interrupted, all progress is lost. There is no intermediate state persistence.
|
||||
|
||||
**Feature:**
|
||||
- Save `OptimizationState` to disk every K iterations (or every accepted improvement).
|
||||
- Resume from the latest checkpoint file on restart.
|
||||
- Checkpoint includes: current best candidate, all candidates, iteration number, LLM call count, RNG seed state.
|
||||
|
||||
**Surface:** `--checkpoint-dir` CLI flag (default: `.prometheus/checkpoints/`). `--resume` CLI flag to resume from latest checkpoint. `checkpoint_interval` Config field.
|
||||
|
||||
**Scope:** New `CheckpointPort` in domain, `JsonCheckpointPersistence` in infrastructure, modifications to `EvolutionLoop.run()`.
|
||||
|
||||
---
|
||||
|
||||
## 5. Population-Based Evolution (P2)
|
||||
|
||||
**Current state:** The evolution loop keeps only a single best candidate (hill climbing). No diversity, no crossover, no population dynamics. The `Candidate` entity has `generation` and `parent_id` fields that suggest population support was planned.
|
||||
|
||||
**Feature:**
|
||||
- Maintain a population of K candidates (e.g., top-K by score or Pareto front).
|
||||
- Crossover: combine instructions from two parent candidates.
|
||||
- Mutation operators: paraphrase, constrain, generalize, specialize.
|
||||
- Diversity maintenance: penalize candidates too similar to existing ones (cosine similarity or edit distance).
|
||||
|
||||
**Surface:** `--population-size` CLI flag, `population_size` Config field. `--crossover-rate`, `--mutation-rate` CLI flags.
|
||||
|
||||
**Scope:** `EvolutionLoop` refactor, new `CrossoverPort` and `MutationPort` in domain, new DSPy signatures for crossover/mutation in infrastructure.
|
||||
|
||||
---
|
||||
|
||||
## 6. Hold-Out Validation (P2)
|
||||
|
||||
**Current state:** The same synthetic inputs are used for both optimization and evaluation. No train/test split. Risk of overfitting to synthetic inputs.
|
||||
|
||||
**Feature:**
|
||||
- Split synthetic pool into train (e.g., 70%) and validation (30%) sets.
|
||||
- Evolution uses train minibatches for accept/reject decisions.
|
||||
- After each iteration, evaluate the best candidate on the hold-out set.
|
||||
- Report both train and validation scores in results.
|
||||
- Optional early stopping if validation score degrades for K consecutive iterations.
|
||||
|
||||
**Surface:** `--validation-split` CLI flag (default: 0.3). `--early-stop-patience` CLI flag (default: 5). Config fields: `validation_split`, `early_stop_patience`.
|
||||
|
||||
**Scope:** `SyntheticBootstrap`, `EvolutionLoop`, `OptimizationResult` (add validation metrics).
|
||||
|
||||
---
|
||||
|
||||
## 7. Custom Judge Criteria (P2)
|
||||
|
||||
**Current state:** The judge uses a hardcoded rubric in `JudgeOutput` DSPy signature ("score 0.0-1.0" with generic quality assessment). Users cannot customize evaluation criteria.
|
||||
|
||||
**Feature:**
|
||||
- Allow users to define custom judge rubrics, criteria, and scoring scales.
|
||||
- Support multi-dimensional scoring (e.g., accuracy: 0-10, clarity: 0-10, safety: 0-10) with configurable weights.
|
||||
- Allow `perfect_score` to reflect the custom scale.
|
||||
|
||||
**Surface:** `judge_criteria` YAML field (free text). `judge_dimensions` YAML field (list of `{name, weight, description}`). CLI: `--judge-criteria` for quick overrides.
|
||||
|
||||
**Scope:** `JudgeOutput` signature (dynamic instructions), `JudgePort`, `DSPyJudgeAdapter`, `scoring.py` (weighted aggregation).
|
||||
|
||||
---
|
||||
|
||||
## 8. Real-World Evaluation Harness (P2)
|
||||
|
||||
**Current state:** The system only evaluates against synthetic inputs. There is no way to test optimized prompts against real inputs with known-good outputs.
|
||||
|
||||
**Feature:**
|
||||
- Accept an optional evaluation dataset (CSV/JSON with `input` and `expected_output` columns).
|
||||
- When provided, use exact/semantic similarity matching against expected outputs instead of (or in addition to) LLM-as-Judge.
|
||||
- Report metrics: accuracy, BLEU, ROUGE, or embedding cosine similarity vs expected.
|
||||
|
||||
**Surface:** `--eval-dataset` CLI flag. `eval_dataset_path` Config field. `--eval-metric` CLI flag (exact | semantic | llm_judge).
|
||||
|
||||
**Scope:** New `GroundTruthEvaluator` in application, new `SimilarityPort` in domain, dataset loader in infrastructure.
|
||||
|
||||
---
|
||||
|
||||
## 9. Logging & Observability (P2)
|
||||
|
||||
**Current state:** Verbose mode (`-v`) configures Python's `logging` module but no handler is attached (Bug #4 in TEST_REPORT.md). No structured logging, no tracing.
|
||||
|
||||
**Feature:**
|
||||
- Proper structured logging with configurable levels (DEBUG, INFO, WARNING, ERROR).
|
||||
- JSON-formatted log output for machine parsing.
|
||||
- Per-iteration trace: minibatch sample IDs, execution outputs, judge scores, proposer prompt diff.
|
||||
- Optional OpenTelemetry export for distributed tracing.
|
||||
|
||||
**Surface:** `-v` / `--verbose` enables INFO level. `--debug` enables DEBUG level. `--log-format` (text | json). `--log-file` for file output. Config fields: `log_level`, `log_format`, `log_file`.
|
||||
|
||||
**Scope:** `cli/app.py` (logging setup), `evolution.py` (structured traces), new `TracingPort` in domain.
|
||||
|
||||
---
|
||||
|
||||
## 10. CLI Improvements (P2)
|
||||
|
||||
**Current state:** Single `optimize` command. Known Typer 0.24 bug absorbing subcommands (Bug #1). No `version`, `init`, or `list-results` commands.
|
||||
|
||||
**Feature:**
|
||||
- Fix Typer subcommand routing.
|
||||
- `prometheus version` — show version.
|
||||
- `prometheus init` — scaffold a config YAML interactively.
|
||||
- `prometheus list` — list past optimization runs.
|
||||
- `prometheus diff` — compare two result files (before/after prompt diff, score improvement).
|
||||
- `prometheus eval` — evaluate a prompt against a dataset without optimization.
|
||||
|
||||
**Surface:** CLI subcommands.
|
||||
|
||||
**Scope:** `cli/app.py` restructured into `cli/commands/` with one module per command.
|
||||
|
||||
---
|
||||
|
||||
## 11. Input Validation & Schema Enforcement (P2)
|
||||
|
||||
**Current state:** Config YAML is parsed as a raw dict with no schema validation. Missing or wrong-type fields cause cryptic errors deep in the pipeline.
|
||||
|
||||
**Feature:**
|
||||
- Validate input YAML against a Pydantic schema (leveraging the existing `pydantic` dependency).
|
||||
- Provide clear, actionable error messages for missing/invalid fields.
|
||||
- Support config migration/upgrade from older versions.
|
||||
|
||||
**Surface:** Internal. Errors surface as clear CLI messages.
|
||||
|
||||
**Scope:** `OptimizationConfig` converted to Pydantic model with validators, `cli/app.py` validation step before pipeline execution.
|
||||
|
||||
---
|
||||
|
||||
## 12. Adaptive Minibatch Sizing (P3)
|
||||
|
||||
**Current state:** Minibatch size is static throughout the run. Small batches are noisy; large batches are expensive.
|
||||
|
||||
**Feature:**
|
||||
- Start with a small minibatch for quick early iterations.
|
||||
- Increase minibatch size as the prompt improves (higher confidence needed for marginal gains).
|
||||
- Shrink if too many evaluations fail (cost optimization).
|
||||
|
||||
**Surface:** `--adaptive-minibatch` CLI flag (boolean toggle). `minibatch_size` becomes `minibatch_size_min` and `minibatch_size_max` in config.
|
||||
|
||||
**Scope:** `EvolutionLoop`, `SyntheticBootstrap`.
|
||||
|
||||
---
|
||||
|
||||
## 13. Prompt Diversity Tracking (P3)
|
||||
|
||||
**Current state:** No visibility into how much the prompt is actually changing between iterations. A "successful" optimization might just rephrase without structural change.
|
||||
|
||||
**Feature:**
|
||||
- Compute edit distance (Levenshtein) or embedding cosine similarity between consecutive prompts.
|
||||
- Report diversity metrics in the result.
|
||||
- Flag stagnation (N iterations with <epsilon change).
|
||||
|
||||
**Surface:** Internal. Reported in `OptimizationResult.history` entries.
|
||||
|
||||
**Scope:** `EvolutionLoop`, `OptimizationResult` (add diversity field per history entry).
|
||||
|
||||
---
|
||||
|
||||
## 14. Temperature & Sampling Control (P3)
|
||||
|
||||
**Current state:** No way to control LLM temperature, top_p, or other sampling parameters for any of the four model slots. DSPy defaults apply.
|
||||
|
||||
**Feature:**
|
||||
- Per-model-slot temperature and sampling parameters.
|
||||
- Higher temperature for proposer (creativity), lower for judge (consistency).
|
||||
|
||||
**Surface:** `task_temperature`, `judge_temperature`, `proposer_temperature`, `synth_temperature` Config fields. `--temperature` CLI flag for global override.
|
||||
|
||||
**Scope:** `cli/app.py` (DSPy LM configuration), infrastructure adapters.
|
||||
|
||||
---
|
||||
|
||||
## 15. Cost Estimation & Budget Caps (P3)
|
||||
|
||||
**Current state:** `total_llm_calls` is tracked (inaccurately). No cost estimation, no budget caps.
|
||||
|
||||
**Feature:**
|
||||
- Estimate cost per run based on model pricing and approximate token counts.
|
||||
- Allow users to set a budget cap (`--max-cost-usd`).
|
||||
- Report estimated cost in the result.
|
||||
|
||||
**Surface:** `--max-cost-usd` CLI flag. `max_cost_usd` Config field. Cost breakdown in result output.
|
||||
|
||||
**Scope:** `cli/app.py`, `OptimizationResult` (add cost fields), token counting in adapters.
|
||||
|
||||
---
|
||||
|
||||
## 16. Multi-Objective Optimization (P3)
|
||||
|
||||
**Current state:** Single scalar score from the judge. The `Prompt` entity comment mentions "Pareto tracking" but it's not implemented.
|
||||
|
||||
**Feature:**
|
||||
- Optimize for multiple objectives simultaneously (quality, latency, token efficiency, safety).
|
||||
- Maintain a Pareto front of non-dominated candidates.
|
||||
- Allow users to set objective weights or constraints.
|
||||
|
||||
**Surface:** `objectives` Config field (list of `{name, weight, judge_criteria}`). CLI: `--objective` repeatable flag.
|
||||
|
||||
**Scope:** `EvolutionLoop` (Pareto front), `scoring.py` (multi-objective acceptance), `OptimizationResult` (Pareto set).
|
||||
|
||||
---
|
||||
|
||||
## 17. Export Optimized Prompt (P3)
|
||||
|
||||
**Current state:** The optimized prompt is embedded in the YAML result file. No easy way to extract it for use.
|
||||
|
||||
**Feature:**
|
||||
- `prometheus export` command to extract the optimized prompt as plain text.
|
||||
- Support multiple export formats: plain text, Markdown, JSON, LangChain template, DSPy module.
|
||||
- Copy to clipboard option.
|
||||
|
||||
**Surface:** `prometheus export --format <txt|md|json|langchain|dspy>` CLI subcommand. `--clipboard` flag.
|
||||
|
||||
**Scope:** New `cli/commands/export.py`, format renderers in infrastructure.
|
||||
|
||||
---
|
||||
|
||||
## 18. Config Profiles / Presets (P3)
|
||||
|
||||
**Current state:** Every run requires a full config YAML. Common patterns (fast iterate, thorough optimize, cheap run) are not captured.
|
||||
|
||||
**Feature:**
|
||||
- Named profiles: `fast`, `thorough`, `economy`, `research`.
|
||||
- Profile overrides individual config fields.
|
||||
- User-defined profiles stored in `~/.prometheus/profiles/`.
|
||||
|
||||
**Surface:** `--profile` CLI flag. `prometheus profile list` / `prometheus profile create` subcommands.
|
||||
|
||||
**Scope:** `cli/app.py`, new `ProfileManager` in application.
|
||||
|
||||
---
|
||||
|
||||
## Summary Table
|
||||
|
||||
| # | Feature | Priority | CLI Surface | Config Surface | Estimated Scope |
|
||||
|---|---------|----------|-------------|----------------|-----------------|
|
||||
| 1 | Multi-Model Routing | P1 | Existing | Existing | Small |
|
||||
| 2 | Async / Parallel Execution | P1 | `--max-concurrency` | `max_concurrency` | Medium |
|
||||
| 3 | Error Handling & Retry | P1 | `--max-retries`, `--error-strategy` | `max_retries`, `error_strategy` | Medium |
|
||||
| 4 | Checkpoint & Resume | P2 | `--checkpoint-dir`, `--resume` | `checkpoint_interval` | Medium |
|
||||
| 5 | Population-Based Evolution | P2 | `--population-size`, `--crossover-rate` | `population_size`, `crossover_rate` | Large |
|
||||
| 6 | Hold-Out Validation | P2 | `--validation-split`, `--early-stop-patience` | `validation_split`, `early_stop_patience` | Medium |
|
||||
| 7 | Custom Judge Criteria | P2 | `--judge-criteria` | `judge_criteria`, `judge_dimensions` | Medium |
|
||||
| 8 | Real-World Eval Harness | P2 | `--eval-dataset`, `--eval-metric` | `eval_dataset_path` | Large |
|
||||
| 9 | Logging & Observability | P2 | `--debug`, `--log-format`, `--log-file` | `log_level`, `log_format` | Medium |
|
||||
| 10 | CLI Improvements | P2 | Subcommands | — | Medium |
|
||||
| 11 | Input Validation | P2 | — (error messages) | — | Small |
|
||||
| 12 | Adaptive Minibatch | P3 | `--adaptive-minibatch` | `minibatch_size_min/max` | Small |
|
||||
| 13 | Prompt Diversity Tracking | P3 | — | — | Small |
|
||||
| 14 | Temperature & Sampling | P3 | `--temperature` | `*_temperature` | Small |
|
||||
| 15 | Cost Estimation | P3 | `--max-cost-usd` | `max_cost_usd` | Small |
|
||||
| 16 | Multi-Objective Optimization | P3 | `--objective` | `objectives` | Large |
|
||||
| 17 | Export Optimized Prompt | P3 | `prometheus export` | — | Small |
|
||||
| 18 | Config Profiles / Presets | P3 | `--profile` | — | Small |
|
||||
|
||||
---
|
||||
|
||||
## Known Bugs (from TEST_REPORT.md and code review)
|
||||
|
||||
| # | Bug | Severity | File |
|
||||
|---|-----|----------|------|
|
||||
| 1 | Multi-model config not wired — all adapters use single global LM | HIGH | `cli/app.py`, all adapters |
|
||||
| 2 | `DSPyLLMAdapter` accepts `model` param but never uses it | HIGH | `infrastructure/llm_adapter.py` |
|
||||
| 3 | CLI subcommand `optimize` absorbed by Typer 0.24 | HIGH | `cli/app.py` |
|
||||
| 4 | Verbose logging produces no output — no handler configured | MEDIUM | `cli/app.py` |
|
||||
| 5 | `total_llm_calls` counter is inaccurate | LOW | `application/use_cases.py`, `evolution.py` |
|
||||
| 6 | `normalize_score()` is dead code — never called | LOW | `domain/scoring.py` |
|
||||
| 7 | `AppSettings` is never imported or used | LOW | `config.py` |
|
||||
| 8 | No LLM error handling in evolution loop | MEDIUM | `evolution.py` |
|
||||
| 9 | Unpinned dependencies (dspy, typer) | LOW | `pyproject.toml` |
|
||||
|
||||
---
|
||||
|
||||
## Test Coverage Gaps
|
||||
|
||||
| Area | Current | Needed |
|
||||
|------|---------|--------|
|
||||
| CLI commands | 0 tests | Unit + integration for each subcommand |
|
||||
| Config validation | 0 tests | Schema validation, missing fields, type errors |
|
||||
| Evolution loop | 3 tests (single iteration each) | Multi-iteration, mixed accept/reject, failure recovery |
|
||||
| Integration pipeline | 1 test (happy path only) | Error paths, mixed results, real adapters |
|
||||
| Adapter coverage | 1 adapter tested | All 4 adapters + error scenarios |
|
||||
| Use case orchestration | 1 indirect test | Direct unit tests for `OptimizePromptUseCase` |
|
||||
|
||||
---
|
||||
|
||||
## Recommended Implementation Order
|
||||
|
||||
### Phase 1 — Production Reliability (P1)
|
||||
1. Fix multi-model routing (#1) — highest impact, smallest scope
|
||||
2. Add error handling & retry (#3) — essential for production runs
|
||||
3. Implement async/parallel execution (#2) — biggest wall-clock improvement
|
||||
|
||||
### Phase 2 — Optimization Quality (P2)
|
||||
4. Input validation (#11) — small scope, high reliability gain
|
||||
5. Logging & observability (#9) — enables debugging long runs
|
||||
6. CLI improvements (#10) — fix Typer bug, add basic commands
|
||||
7. Hold-out validation (#6) — prevents overfitting
|
||||
8. Checkpoint & resume (#4) — essential for long runs
|
||||
9. Custom judge criteria (#7) — enables domain-specific optimization
|
||||
|
||||
### Phase 3 — Advanced Features (P3)
|
||||
10. Population-based evolution (#5)
|
||||
11. Real-world eval harness (#8)
|
||||
12. Remaining P3 features as demand dictates
|
||||
@@ -5,17 +5,18 @@ description = "Prompt evolution without reference data"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"dspy>=2.6,<3.0",
|
||||
"typer>=0.15,<0.20",
|
||||
"pydantic>=2.10",
|
||||
"pydantic-settings>=2.7",
|
||||
"pyyaml>=6.0",
|
||||
"rich>=13.9",
|
||||
"dspy==2.6.27",
|
||||
"typer==0.19.2",
|
||||
"pydantic==2.12.5",
|
||||
"pydantic-settings==2.13.1",
|
||||
"pyyaml==6.0.3",
|
||||
"rich==14.3.3",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.3",
|
||||
"pytest-asyncio>=0.24",
|
||||
"pytest-cov>=6.0",
|
||||
"ruff>=0.9",
|
||||
"mypy>=1.14",
|
||||
@@ -37,11 +38,14 @@ target-version = "py312"
|
||||
python_version = "3.12"
|
||||
strict = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["dspy", "dspy.*"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["prometheus.infrastructure.*", "prometheus.cli.app"]
|
||||
module = ["prometheus.infrastructure.*", "prometheus.cli.app", "prometheus.cli.commands.*"]
|
||||
disable_error_code = ["misc", "import-untyped"]
|
||||
|
||||
|
||||
@@ -22,6 +22,24 @@ class SyntheticBootstrap:
|
||||
self._generator = generator
|
||||
self._rng = random.Random(seed)
|
||||
|
||||
@staticmethod
|
||||
def split_pool(
|
||||
pool: list[SyntheticExample],
|
||||
validation_fraction: float,
|
||||
rng: random.Random | None = None,
|
||||
) -> tuple[list[SyntheticExample], list[SyntheticExample]]:
|
||||
"""Split *pool* into (train, validation) sets.
|
||||
|
||||
Returns (pool, []) when *validation_fraction* is 0.
|
||||
"""
|
||||
if validation_fraction <= 0.0 or len(pool) < 2:
|
||||
return pool, []
|
||||
n_val = max(1, int(len(pool) * validation_fraction))
|
||||
shuffled = list(pool)
|
||||
_rng = rng or random.Random(42)
|
||||
_rng.shuffle(shuffled)
|
||||
return shuffled[:-n_val], shuffled[-n_val:]
|
||||
|
||||
def run(self, task_description: str, n_examples: int) -> list[SyntheticExample]:
|
||||
"""Generate the synthetic pool in a single call.
|
||||
|
||||
|
||||
@@ -4,34 +4,211 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
@dataclass
|
||||
class OptimizationConfig:
|
||||
"""Complete configuration for a PROMETHEUS run."""
|
||||
|
||||
# --- Prompt ---
|
||||
seed_prompt: str
|
||||
task_description: str
|
||||
# Current config schema version.
|
||||
CONFIG_VERSION = 1
|
||||
|
||||
_ERROR_STRATEGY_VALUES = {"skip", "retry", "abort"}
|
||||
_EVAL_METRIC_VALUES = {"exact", "bleu", "rouge_l", "cosine", "llm_judge"}
|
||||
|
||||
|
||||
class JudgeDimension(BaseModel):
|
||||
"""A single evaluation dimension for multi-dimensional scoring."""
|
||||
|
||||
name: str = Field(min_length=1, description="Dimension name (e.g. accuracy, clarity, safety).")
|
||||
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="Weight for this dimension (0.0–1.0).")
|
||||
description: str = Field(default="", description="What this dimension measures.")
|
||||
|
||||
|
||||
class OptimizationConfig(BaseModel):
|
||||
"""Complete configuration for a PROMETHEUS run.
|
||||
|
||||
Validated with Pydantic so missing or wrong-type fields produce clear,
|
||||
actionable error messages at the CLI boundary instead of cryptic failures
|
||||
deep in the pipeline.
|
||||
"""
|
||||
|
||||
model_config = {"extra": "forbid"}
|
||||
|
||||
# --- Schema version (for migration support) ---
|
||||
config_version: int = Field(
|
||||
default=CONFIG_VERSION,
|
||||
description="Config schema version. Set automatically; used for migration.",
|
||||
)
|
||||
|
||||
# --- Prompt (required) ---
|
||||
seed_prompt: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="The initial prompt to optimize.",
|
||||
)
|
||||
task_description: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Description of the task the prompt should accomplish.",
|
||||
)
|
||||
|
||||
# --- Models ---
|
||||
task_model: str = "openai/gpt-4o-mini"
|
||||
judge_model: str = "openai/gpt-4o"
|
||||
proposer_model: str = "openai/gpt-4o"
|
||||
synth_model: str = "openai/gpt-4o"
|
||||
task_model: str = Field(default="openai/gpt-4o-mini", min_length=1)
|
||||
judge_model: str = Field(default="openai/gpt-4o", min_length=1)
|
||||
proposer_model: str = Field(default="openai/gpt-4o", min_length=1)
|
||||
synth_model: str = Field(default="openai/gpt-4o", min_length=1)
|
||||
|
||||
# --- Per-model API overrides (optional, fall back to global api_base/api_key_env) ---
|
||||
task_api_base: str | None = None
|
||||
task_api_key_env: str | None = None
|
||||
judge_api_base: str | None = None
|
||||
judge_api_key_env: str | None = None
|
||||
proposer_api_base: str | None = None
|
||||
proposer_api_key_env: str | None = None
|
||||
synth_api_base: str | None = None
|
||||
synth_api_key_env: str | None = None
|
||||
|
||||
# --- Global API settings (optional) ---
|
||||
api_base: str | None = None
|
||||
api_key_env: str | None = None
|
||||
|
||||
# --- Evolution parameters ---
|
||||
max_iterations: int = 30
|
||||
n_synthetic_inputs: int = 20
|
||||
minibatch_size: int = 5
|
||||
perfect_score: float = 1.0
|
||||
max_iterations: int = Field(default=30, ge=1, description="Maximum evolution iterations.")
|
||||
n_synthetic_inputs: int = Field(default=20, ge=1, description="Number of synthetic inputs to generate.")
|
||||
minibatch_size: int = Field(default=5, ge=1, description="Inputs per evaluation minibatch.")
|
||||
perfect_score: float = Field(default=1.0, ge=0.0, le=1.0)
|
||||
|
||||
# --- Population-based evolution ---
|
||||
population_size: int = Field(
|
||||
default=1,
|
||||
ge=1,
|
||||
description="Number of candidates in the evolution population. 1 = single-candidate hill climbing (backward compat).",
|
||||
)
|
||||
crossover_rate: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Probability of applying crossover vs reflective mutation.",
|
||||
)
|
||||
mutation_rate: float = Field(
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Probability of applying a mutation operator after crossover or proposal.",
|
||||
)
|
||||
diversity_penalty: float = Field(
|
||||
default=0.1,
|
||||
ge=0.0,
|
||||
description="Penalty weight for similarity to existing population members.",
|
||||
)
|
||||
|
||||
# --- Reproducibility ---
|
||||
seed: int = 42
|
||||
seed: int = Field(default=42, ge=0)
|
||||
|
||||
# --- Concurrency ---
|
||||
max_concurrency: int = Field(default=5, ge=1, description="Max parallel LLM calls.")
|
||||
|
||||
# --- Error handling ---
|
||||
max_retries: int = Field(default=3, ge=0, description="Max retry attempts for transient errors.")
|
||||
retry_delay_base: float = Field(default=1.0, gt=0, description="Base delay in seconds for retry backoff.")
|
||||
circuit_breaker_threshold: int = Field(default=5, ge=1, description="Consecutive failures before circuit opens.")
|
||||
error_strategy: str = Field(default="retry", description="Error handling strategy: skip | retry | abort.")
|
||||
|
||||
# --- Logging & observability ---
|
||||
debug: bool = Field(default=False, description="Enable DEBUG-level logging.")
|
||||
log_format: str = Field(default="text", description="Log output format: text | json.")
|
||||
log_file: str | None = Field(default=None, description="Optional file path for log output.")
|
||||
|
||||
# --- Checkpoint & resume ---
|
||||
checkpoint_dir: str | None = Field(
|
||||
default=None,
|
||||
description="Directory for checkpoint files. Set to enable checkpointing.",
|
||||
)
|
||||
checkpoint_interval: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
description="Save a checkpoint every N iterations (and on every accepted improvement).",
|
||||
)
|
||||
resume: bool = Field(
|
||||
default=False,
|
||||
description="Resume from the latest checkpoint in checkpoint_dir.",
|
||||
)
|
||||
|
||||
# --- Output ---
|
||||
output_path: str = "output.yaml"
|
||||
output_path: str = Field(default="output.yaml", min_length=1)
|
||||
verbose: bool = False
|
||||
|
||||
# --- Hold-out validation ---
|
||||
validation_split: float = Field(
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
lt=1.0,
|
||||
description="Fraction of synthetic pool reserved for validation (0 = disabled).",
|
||||
)
|
||||
early_stop_patience: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
description="Stop if validation score degrades for this many consecutive iterations.",
|
||||
)
|
||||
|
||||
# --- Judge criteria & multi-dimensional scoring ---
|
||||
judge_criteria: str | None = Field(
|
||||
default=None,
|
||||
description="Custom judge rubric or evaluation criteria override (free text).",
|
||||
)
|
||||
judge_dimensions: list[JudgeDimension] | None = Field(
|
||||
default=None,
|
||||
description="Multi-dimensional scoring dimensions with configurable weights.",
|
||||
)
|
||||
|
||||
# --- Ground-truth evaluation ---
|
||||
eval_dataset_path: str | None = Field(
|
||||
default=None,
|
||||
min_length=1,
|
||||
description="Path to a CSV/JSON dataset with 'input' and 'expected_output' columns.",
|
||||
)
|
||||
eval_metric: str = Field(
|
||||
default="bleu",
|
||||
description="Similarity metric for ground-truth eval: exact | bleu | rouge_l | cosine | llm_judge.",
|
||||
)
|
||||
|
||||
@field_validator("log_format")
|
||||
@classmethod
|
||||
def _validate_log_format(cls, v: str) -> str:
|
||||
allowed = {"text", "json"}
|
||||
if v not in allowed:
|
||||
raise ValueError(f"log_format must be one of {sorted(allowed)}, got '{v}'")
|
||||
return v
|
||||
|
||||
@field_validator("error_strategy")
|
||||
@classmethod
|
||||
def _validate_error_strategy(cls, v: str) -> str:
|
||||
if v not in _ERROR_STRATEGY_VALUES:
|
||||
raise ValueError(
|
||||
f"error_strategy must be one of {sorted(_ERROR_STRATEGY_VALUES)}, got '{v}'"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("eval_metric")
|
||||
@classmethod
|
||||
def _validate_eval_metric(cls, v: str) -> str:
|
||||
if v not in _EVAL_METRIC_VALUES:
|
||||
raise ValueError(
|
||||
f"eval_metric must be one of {sorted(_EVAL_METRIC_VALUES)}, got '{v}'"
|
||||
)
|
||||
return v
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _migrate_config(cls, data: Any) -> Any:
|
||||
"""Apply migration transforms for older config versions."""
|
||||
if isinstance(data, dict):
|
||||
version = data.get("config_version", CONFIG_VERSION)
|
||||
# Future migrations go here, e.g.:
|
||||
# if version < 2:
|
||||
# data = _migrate_v1_to_v2(data)
|
||||
# Always stamp current version.
|
||||
data["config_version"] = CONFIG_VERSION
|
||||
return data
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizationResult:
|
||||
@@ -45,3 +222,7 @@ class OptimizationResult:
|
||||
final_score: float
|
||||
improvement: float
|
||||
history: list[dict[str, Any]] = field(default_factory=list)
|
||||
# Hold-out validation metrics (populated when validation_split > 0)
|
||||
final_validation_score: float | None = None
|
||||
best_validation_score: float | None = None
|
||||
early_stopped: bool = False
|
||||
|
||||
@@ -6,6 +6,9 @@ Combines candidate prompt execution + LLM-as-Judge evaluation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from prometheus.domain.entities import (
|
||||
EvalResult,
|
||||
Prompt,
|
||||
@@ -13,6 +16,9 @@ from prometheus.domain.entities import (
|
||||
Trajectory,
|
||||
)
|
||||
from prometheus.domain.ports import JudgePort, LLMPort
|
||||
from prometheus.domain.scoring import normalize_score
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptEvaluator:
|
||||
@@ -21,13 +27,23 @@ class PromptEvaluator:
|
||||
Pipeline: execute → judge → build trajectories.
|
||||
Replaces GEPA's EvaluatorFn. Instead of comparing to ground truth,
|
||||
uses an LLM-as-Judge.
|
||||
|
||||
Execution and judge calls run in parallel (bounded by *max_concurrency*).
|
||||
Per-call isolation: a failure on one minibatch item produces a
|
||||
zero-score trajectory instead of crashing the whole batch.
|
||||
"""
|
||||
|
||||
def __init__(self, executor: LLMPort, judge: JudgePort):
|
||||
def __init__(
|
||||
self,
|
||||
executor: LLMPort,
|
||||
judge: JudgePort,
|
||||
max_concurrency: int = 5,
|
||||
):
|
||||
self._executor = executor
|
||||
self._judge = judge
|
||||
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
def evaluate(
|
||||
async def evaluate(
|
||||
self,
|
||||
prompt: Prompt,
|
||||
minibatch: list[SyntheticExample],
|
||||
@@ -36,19 +52,20 @@ class PromptEvaluator:
|
||||
"""Evaluate the prompt on the minibatch.
|
||||
|
||||
Steps:
|
||||
1. Execute the prompt on each input in the minibatch
|
||||
1. Execute the prompt on each input in the minibatch (parallel)
|
||||
2. Judge each (input, output) pair
|
||||
3. Build trajectories with feedback
|
||||
"""
|
||||
# Step 1: Execution
|
||||
outputs: list[str] = []
|
||||
for example in minibatch:
|
||||
raw_output = self._executor.execute(prompt, example.input_text)
|
||||
outputs.append(raw_output)
|
||||
# Step 1: Parallel execution (per-item isolation)
|
||||
output_coros = [
|
||||
self._execute_single(prompt, example)
|
||||
for example in minibatch
|
||||
]
|
||||
outputs = await asyncio.gather(*output_coros)
|
||||
|
||||
# Step 2: Judgement
|
||||
# Step 2: Judgement (judge_adapter handles its own per-call isolation + parallelism)
|
||||
pairs = [(ex.input_text, out) for ex, out in zip(minibatch, outputs)]
|
||||
judge_results = self._judge.judge_batch(task_description, pairs)
|
||||
judge_results = await self._judge.judge_batch(task_description, pairs)
|
||||
|
||||
# Step 3: Build trajectories
|
||||
scores: list[float] = []
|
||||
@@ -56,6 +73,7 @@ class PromptEvaluator:
|
||||
trajectories: list[Trajectory] = []
|
||||
for i, (example, output) in enumerate(zip(minibatch, outputs)):
|
||||
score, feedback = judge_results[i]
|
||||
score = normalize_score(score)
|
||||
scores.append(score)
|
||||
feedbacks.append(feedback)
|
||||
trajectories.append(
|
||||
@@ -73,3 +91,17 @@ class PromptEvaluator:
|
||||
feedbacks=feedbacks,
|
||||
trajectories=trajectories,
|
||||
)
|
||||
|
||||
async def _execute_single(
|
||||
self, prompt: Prompt, example: SyntheticExample
|
||||
) -> str:
|
||||
async with self._semaphore:
|
||||
try:
|
||||
return await self._executor.execute(prompt, example.input_text)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Execution failed for input '%s…': %s",
|
||||
example.input_text[:40],
|
||||
exc,
|
||||
)
|
||||
return f"[execution error: {exc}]"
|
||||
|
||||
@@ -2,33 +2,51 @@
|
||||
Evolution loop — core PROMETHEUS engine.
|
||||
|
||||
Orchestrates the select → evaluate → propose → accept cycle.
|
||||
Equivalent to GEPAEngine.run(), adapted to work without a valset.
|
||||
Supports two modes:
|
||||
- Single-candidate hill climbing (population_size=1, backward compat)
|
||||
- Population-based evolution with crossover & mutation (population_size>1)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.cli.logging_setup import get_logger
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
OptimizationState,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
)
|
||||
from prometheus.domain.ports import ProposerPort
|
||||
from prometheus.domain.ports import (
|
||||
CheckpointPort,
|
||||
CrossoverPort,
|
||||
MutationPort,
|
||||
ProposerPort,
|
||||
)
|
||||
from prometheus.domain.scoring import should_accept
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger("evolution")
|
||||
|
||||
|
||||
class CircuitBreakerOpen(Exception):
|
||||
"""Raised when the circuit breaker trips due to too many consecutive failures."""
|
||||
|
||||
|
||||
class EvolutionLoop:
|
||||
"""Main evolution loop.
|
||||
|
||||
Design:
|
||||
- Keeps only the best candidate (no full population).
|
||||
- Simplifies vs GEPA (no Pareto, no merge).
|
||||
- Population support deferred to v2.
|
||||
- population_size=1: classic single-candidate hill climbing (backward compat).
|
||||
- population_size>1: population-based evolution with crossover, mutation,
|
||||
and diversity maintenance.
|
||||
|
||||
Error handling:
|
||||
- Transient errors are retried by adapters.
|
||||
- Circuit breaker trips after N consecutive iteration failures.
|
||||
- error_strategy controls what happens on non-transient errors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -40,6 +58,19 @@ class EvolutionLoop:
|
||||
minibatch_size: int = 5,
|
||||
perfect_score: float = 1.0,
|
||||
verbose: bool = False,
|
||||
circuit_breaker_threshold: int = 5,
|
||||
error_strategy: str = "retry",
|
||||
checkpoint_port: CheckpointPort | None = None,
|
||||
checkpoint_interval: int = 5,
|
||||
# --- Population-based evolution params ---
|
||||
population_size: int = 1,
|
||||
crossover_rate: float = 0.5,
|
||||
mutation_rate: float = 0.3,
|
||||
diversity_penalty: float = 0.1,
|
||||
crossover_port: CrossoverPort | None = None,
|
||||
mutation_port: MutationPort | None = None,
|
||||
# --- Hold-out validation params ---
|
||||
early_stop_patience: int = 5,
|
||||
):
|
||||
self._evaluator = evaluator
|
||||
self._proposer = proposer
|
||||
@@ -48,127 +79,663 @@ class EvolutionLoop:
|
||||
self._minibatch_size = minibatch_size
|
||||
self._perfect_score = perfect_score
|
||||
self._verbose = verbose
|
||||
self._circuit_breaker_threshold = circuit_breaker_threshold
|
||||
self._error_strategy = error_strategy
|
||||
self._checkpoint_port = checkpoint_port
|
||||
self._checkpoint_interval = checkpoint_interval
|
||||
self._population_size = population_size
|
||||
self._crossover_rate = crossover_rate
|
||||
self._mutation_rate = mutation_rate
|
||||
self._diversity_penalty = diversity_penalty
|
||||
self._crossover_port = crossover_port
|
||||
self._mutation_port = mutation_port
|
||||
self._early_stop_patience = early_stop_patience
|
||||
|
||||
def run(
|
||||
async def run(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
initial_state: OptimizationState | None = None,
|
||||
validation_pool: list[SyntheticExample] | None = None,
|
||||
) -> OptimizationState:
|
||||
"""Execute the complete evolution loop."""
|
||||
state = OptimizationState()
|
||||
"""Execute the complete evolution loop.
|
||||
|
||||
# Evaluate the seed
|
||||
initial_batch = self._bootstrap.sample_minibatch(
|
||||
synthetic_pool, self._minibatch_size
|
||||
)
|
||||
initial_eval = self._evaluator.evaluate(
|
||||
seed_prompt, initial_batch, task_description
|
||||
)
|
||||
state.total_llm_calls += 2 * self._minibatch_size # N executions + N judge calls
|
||||
If *initial_state* is provided (from a checkpoint), resume from that
|
||||
point — skipping the seed evaluation and continuing at the saved iteration.
|
||||
|
||||
best_candidate = Candidate(
|
||||
prompt=seed_prompt,
|
||||
best_score=initial_eval.total_score,
|
||||
generation=0,
|
||||
)
|
||||
state.best_candidate = best_candidate
|
||||
state.candidates.append(best_candidate)
|
||||
self._log(f"Initial score: {initial_eval.total_score:.2f}")
|
||||
If *validation_pool* is provided (non-empty), the best candidate is
|
||||
evaluated on the hold-out set after each iteration and early stopping
|
||||
is applied when validation score degrades for ``early_stop_patience``
|
||||
consecutive iterations.
|
||||
"""
|
||||
state = initial_state or OptimizationState()
|
||||
consecutive_failures = 0
|
||||
|
||||
# Hold-out validation tracking
|
||||
has_validation = bool(validation_pool)
|
||||
best_validation_score: float = -1.0
|
||||
validation_patience_counter: int = 0
|
||||
|
||||
# Only evaluate the seed when starting fresh (no checkpoint resume)
|
||||
if initial_state is None:
|
||||
initial_batch = self._bootstrap.sample_minibatch(
|
||||
synthetic_pool, self._minibatch_size
|
||||
)
|
||||
initial_eval = await self._evaluator.evaluate(
|
||||
seed_prompt, initial_batch, task_description
|
||||
)
|
||||
state.total_llm_calls += 2 * self._minibatch_size # N executions + N judge calls
|
||||
|
||||
seed_candidate = Candidate(
|
||||
prompt=seed_prompt,
|
||||
best_score=initial_eval.total_score,
|
||||
generation=0,
|
||||
)
|
||||
state.best_candidate = seed_candidate
|
||||
state.candidates.append(seed_candidate)
|
||||
logger.info(
|
||||
"Initial evaluation complete",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "initial_eval",
|
||||
"score": round(initial_eval.total_score, 4),
|
||||
"minibatch_size": self._minibatch_size,
|
||||
"sample_ids": [ex.id for ex in initial_batch],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Evaluate seed on validation set
|
||||
if has_validation and state.best_candidate is not None:
|
||||
val_eval = await self._evaluator.evaluate(
|
||||
state.best_candidate.prompt, validation_pool, task_description
|
||||
)
|
||||
state.total_llm_calls += 2 * len(validation_pool)
|
||||
best_validation_score = val_eval.mean_score
|
||||
logger.info(
|
||||
"Initial validation evaluation",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "validation_eval",
|
||||
"iteration": 0,
|
||||
"validation_score": round(best_validation_score, 4),
|
||||
"validation_pool_size": len(validation_pool),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Population initialization: seed the population with mutations
|
||||
if self._population_size > 1:
|
||||
await self._initialize_population(
|
||||
state, seed_prompt, seed_candidate, task_description
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Resuming from checkpoint",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "resume",
|
||||
"iteration": state.iteration,
|
||||
"total_llm_calls": state.total_llm_calls,
|
||||
},
|
||||
},
|
||||
)
|
||||
# Restore validation tracking from state history
|
||||
if has_validation:
|
||||
for entry in reversed(state.history):
|
||||
if entry.get("event") == "validation_eval":
|
||||
best_validation_score = entry.get("best_validation_score", -1.0)
|
||||
validation_patience_counter = entry.get("validation_patience", 0)
|
||||
break
|
||||
|
||||
# Determine starting iteration
|
||||
start_iteration = state.iteration + 1
|
||||
|
||||
# Main loop
|
||||
for i in range(1, self._max_iterations + 1):
|
||||
for i in range(start_iteration, self._max_iterations + 1):
|
||||
state.iteration = i
|
||||
|
||||
try:
|
||||
# 1. Sample a fresh minibatch
|
||||
batch = self._bootstrap.sample_minibatch(
|
||||
synthetic_pool, self._minibatch_size
|
||||
)
|
||||
|
||||
# 2. Evaluate the current candidate
|
||||
current_eval = self._evaluator.evaluate(
|
||||
best_candidate.prompt, batch, task_description
|
||||
)
|
||||
state.total_llm_calls += 2 * self._minibatch_size
|
||||
|
||||
# 3. Skip if perfect
|
||||
if all(s >= self._perfect_score for s in current_eval.scores):
|
||||
self._log(f"Iter {i}: All scores perfect, skipping.")
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "skip_perfect",
|
||||
"current_score": current_eval.total_score,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 4. Propose a new prompt (reflective mutation)
|
||||
new_prompt = self._proposer.propose(
|
||||
best_candidate.prompt,
|
||||
current_eval.trajectories,
|
||||
task_description,
|
||||
)
|
||||
state.total_llm_calls += 1 # 1 proposition call
|
||||
|
||||
# 5. Evaluate the new prompt on the same minibatch
|
||||
new_eval = self._evaluator.evaluate(
|
||||
new_prompt, batch, task_description
|
||||
)
|
||||
state.total_llm_calls += 2 * self._minibatch_size
|
||||
|
||||
# 6. Accept or reject
|
||||
if should_accept(current_eval, new_eval):
|
||||
best_candidate = Candidate(
|
||||
prompt=new_prompt,
|
||||
best_score=new_eval.total_score,
|
||||
generation=i,
|
||||
parent_id=id(best_candidate),
|
||||
)
|
||||
state.best_candidate = best_candidate
|
||||
state.candidates.append(best_candidate)
|
||||
self._log(
|
||||
f"Iter {i}: ACCEPTED "
|
||||
f"({current_eval.total_score:.2f} -> {new_eval.total_score:.2f})"
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "accepted",
|
||||
"old_score": current_eval.total_score,
|
||||
"new_score": new_eval.total_score,
|
||||
"improvement": new_eval.total_score
|
||||
- current_eval.total_score,
|
||||
}
|
||||
if self._population_size > 1 and len(state.candidates) > 1:
|
||||
await self._run_population_iteration(
|
||||
i, state, synthetic_pool, task_description
|
||||
)
|
||||
else:
|
||||
self._log(
|
||||
f"Iter {i}: REJECTED "
|
||||
f"({new_eval.total_score:.2f} <= {current_eval.total_score:.2f})"
|
||||
await self._run_single_iteration(
|
||||
i, state, synthetic_pool, task_description
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "rejected",
|
||||
"old_score": current_eval.total_score,
|
||||
"new_score": new_eval.total_score,
|
||||
}
|
||||
consecutive_failures = 0
|
||||
|
||||
# Hold-out validation: evaluate best candidate on validation set
|
||||
if has_validation and state.best_candidate is not None:
|
||||
val_eval = await self._evaluator.evaluate(
|
||||
state.best_candidate.prompt, validation_pool, task_description
|
||||
)
|
||||
state.total_llm_calls += 2 * len(validation_pool)
|
||||
current_val_score = val_eval.mean_score
|
||||
|
||||
if current_val_score > best_validation_score:
|
||||
best_validation_score = current_val_score
|
||||
validation_patience_counter = 0
|
||||
else:
|
||||
validation_patience_counter += 1
|
||||
|
||||
state.history.append({
|
||||
"iteration": i,
|
||||
"event": "validation_eval",
|
||||
"validation_score": round(current_val_score, 4),
|
||||
"best_validation_score": round(best_validation_score, 4),
|
||||
"validation_patience": validation_patience_counter,
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Validation evaluation",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "validation_eval",
|
||||
"iteration": i,
|
||||
"validation_score": round(current_val_score, 4),
|
||||
"best_validation_score": round(best_validation_score, 4),
|
||||
"patience": f"{validation_patience_counter}/{self._early_stop_patience}",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if validation_patience_counter >= self._early_stop_patience:
|
||||
logger.warning(
|
||||
"Early stopping triggered — validation score did not improve for %d iterations",
|
||||
self._early_stop_patience,
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "early_stop",
|
||||
"iteration": i,
|
||||
"best_validation_score": round(best_validation_score, 4),
|
||||
"patience": self._early_stop_patience,
|
||||
},
|
||||
},
|
||||
)
|
||||
state.history.append({
|
||||
"iteration": i,
|
||||
"event": "early_stop",
|
||||
"best_validation_score": round(best_validation_score, 4),
|
||||
"patience": self._early_stop_patience,
|
||||
})
|
||||
state.best_validation_score = best_validation_score
|
||||
state.early_stopped = True
|
||||
if self._checkpoint_port is not None:
|
||||
self._checkpoint_port.save(state)
|
||||
break
|
||||
|
||||
# Checkpoint on accepted improvement (detected via state change)
|
||||
self._maybe_checkpoint(state)
|
||||
|
||||
except Exception as exc:
|
||||
self._log(f"Iter {i}: ERROR — {exc}. Skipping iteration.")
|
||||
consecutive_failures += 1
|
||||
logger.error(
|
||||
"Iteration error",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "iteration_error",
|
||||
"iteration": i,
|
||||
"consecutive_failures": consecutive_failures,
|
||||
"error": str(exc),
|
||||
},
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "error",
|
||||
"error": str(exc),
|
||||
"consecutive_failures": consecutive_failures,
|
||||
}
|
||||
)
|
||||
|
||||
# Check circuit breaker
|
||||
if consecutive_failures >= self._circuit_breaker_threshold:
|
||||
logger.warning(
|
||||
"Circuit breaker tripped",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "circuit_breaker",
|
||||
"iteration": i,
|
||||
"consecutive_failures": consecutive_failures,
|
||||
"error_strategy": self._error_strategy,
|
||||
},
|
||||
},
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "circuit_breaker",
|
||||
"consecutive_failures": consecutive_failures,
|
||||
}
|
||||
)
|
||||
if self._error_strategy == "abort":
|
||||
raise CircuitBreakerOpen(
|
||||
f"Circuit breaker tripped after "
|
||||
f"{consecutive_failures} consecutive failures"
|
||||
) from exc
|
||||
# skip / retry strategies: save checkpoint, then stop the loop gracefully
|
||||
if self._checkpoint_port is not None:
|
||||
self._checkpoint_port.save(state)
|
||||
break
|
||||
|
||||
if self._error_strategy == "abort":
|
||||
raise
|
||||
# skip / retry: continue to next iteration
|
||||
continue
|
||||
|
||||
# Store final validation metadata on state
|
||||
if has_validation:
|
||||
state.best_validation_score = best_validation_score
|
||||
|
||||
return state
|
||||
|
||||
def _log(self, msg: str) -> None:
|
||||
if self._verbose:
|
||||
logger.info("[PROMETHEUS] %s", msg)
|
||||
# ------------------------------------------------------------------
|
||||
# Population initialization
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _initialize_population(
|
||||
self,
|
||||
state: OptimizationState,
|
||||
seed_prompt: Prompt,
|
||||
seed_candidate: Candidate,
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""Fill the population with mutated variants of the seed prompt."""
|
||||
n_needed = self._population_size - 1
|
||||
mutation_types = ["paraphrase", "constrain", "generalize", "specialize"]
|
||||
|
||||
for idx in range(n_needed):
|
||||
mutation_type = mutation_types[idx % len(mutation_types)]
|
||||
|
||||
if self._mutation_port is not None:
|
||||
new_prompt = await self._mutation_port.mutate(
|
||||
seed_prompt, task_description, mutation_type
|
||||
)
|
||||
else:
|
||||
# Fallback: use proposer for reflective mutation
|
||||
new_prompt = await self._proposer.propose(
|
||||
seed_prompt, [], task_description
|
||||
)
|
||||
state.total_llm_calls += 1
|
||||
new_candidate = Candidate(
|
||||
prompt=new_prompt,
|
||||
best_score=seed_candidate.best_score, # estimate until evaluated
|
||||
generation=0,
|
||||
parent_id=id(seed_candidate),
|
||||
)
|
||||
state.candidates.append(new_candidate)
|
||||
|
||||
logger.info(
|
||||
"Population initialized",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "population_init",
|
||||
"population_size": len(state.candidates),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Single-candidate iteration (original hill-climbing)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _run_single_iteration(
|
||||
self,
|
||||
i: int,
|
||||
state: OptimizationState,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""Execute a single-candidate iteration. Mutates *state* in-place."""
|
||||
best_candidate = state.best_candidate # type: ignore[assignment]
|
||||
|
||||
# 1. Sample a fresh minibatch
|
||||
batch = self._bootstrap.sample_minibatch(
|
||||
synthetic_pool, self._minibatch_size
|
||||
)
|
||||
sample_ids = [ex.id for ex in batch]
|
||||
|
||||
# 2. Evaluate the current candidate
|
||||
current_eval = await self._evaluator.evaluate(
|
||||
best_candidate.prompt, batch, task_description
|
||||
)
|
||||
state.total_llm_calls += 2 * self._minibatch_size
|
||||
|
||||
logger.debug(
|
||||
"Iteration minibatch evaluated",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "minibatch_eval",
|
||||
"iteration": i,
|
||||
"sample_ids": sample_ids,
|
||||
"scores": [round(s, 4) for s in current_eval.scores],
|
||||
"total_score": round(current_eval.total_score, 4),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# 3. Skip if perfect
|
||||
if all(s >= self._perfect_score for s in current_eval.scores):
|
||||
logger.info(
|
||||
"Iteration skipped — all scores perfect",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "skip_perfect",
|
||||
"iteration": i,
|
||||
"total_score": round(current_eval.total_score, 4),
|
||||
},
|
||||
},
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "skip_perfect",
|
||||
"current_score": current_eval.total_score,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# 4. Propose a new prompt (reflective mutation) — sequential
|
||||
state.total_llm_calls += 1
|
||||
new_prompt = await self._proposer.propose(
|
||||
best_candidate.prompt,
|
||||
current_eval.trajectories,
|
||||
task_description,
|
||||
)
|
||||
|
||||
prompt_diff = self._compute_prompt_diff(
|
||||
best_candidate.prompt.text, new_prompt.text
|
||||
)
|
||||
logger.debug(
|
||||
"Proposed new prompt",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "proposer_output",
|
||||
"iteration": i,
|
||||
"prompt_diff": prompt_diff,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# 5. Evaluate the new prompt on the same minibatch
|
||||
new_eval = await self._evaluator.evaluate(
|
||||
new_prompt, batch, task_description
|
||||
)
|
||||
state.total_llm_calls += 2 * self._minibatch_size
|
||||
|
||||
# 6. Accept or reject
|
||||
if should_accept(current_eval, new_eval):
|
||||
new_candidate = Candidate(
|
||||
prompt=new_prompt,
|
||||
best_score=new_eval.total_score,
|
||||
generation=i,
|
||||
parent_id=id(best_candidate),
|
||||
)
|
||||
state.best_candidate = new_candidate
|
||||
state.candidates.append(new_candidate)
|
||||
logger.info(
|
||||
"Iteration accepted",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "accepted",
|
||||
"iteration": i,
|
||||
"old_score": round(current_eval.total_score, 4),
|
||||
"new_score": round(new_eval.total_score, 4),
|
||||
"improvement": round(
|
||||
new_eval.total_score - current_eval.total_score, 4
|
||||
),
|
||||
"sample_ids": sample_ids,
|
||||
"new_scores": [round(s, 4) for s in new_eval.scores],
|
||||
"prompt_diff": prompt_diff,
|
||||
},
|
||||
},
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "accepted",
|
||||
"old_score": current_eval.total_score,
|
||||
"new_score": new_eval.total_score,
|
||||
"improvement": new_eval.total_score
|
||||
- current_eval.total_score,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Iteration rejected",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "rejected",
|
||||
"iteration": i,
|
||||
"old_score": round(current_eval.total_score, 4),
|
||||
"new_score": round(new_eval.total_score, 4),
|
||||
"sample_ids": sample_ids,
|
||||
"new_scores": [round(s, 4) for s in new_eval.scores],
|
||||
"prompt_diff": prompt_diff,
|
||||
},
|
||||
},
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "rejected",
|
||||
"old_score": current_eval.total_score,
|
||||
"new_score": new_eval.total_score,
|
||||
}
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Population-based iteration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _run_population_iteration(
|
||||
self,
|
||||
i: int,
|
||||
state: OptimizationState,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""Execute a population-based iteration. Mutates *state* in-place."""
|
||||
population = state.candidates
|
||||
|
||||
# 1. Sample a fresh minibatch
|
||||
batch = self._bootstrap.sample_minibatch(
|
||||
synthetic_pool, self._minibatch_size
|
||||
)
|
||||
sample_ids = [ex.id for ex in batch]
|
||||
|
||||
# 2. Select two parents via tournament selection
|
||||
parent_a = self._tournament_select(population)
|
||||
parent_b = self._tournament_select(population)
|
||||
|
||||
# 3. Generate child: crossover or reflective mutation
|
||||
use_crossover = (
|
||||
random.random() < self._crossover_rate
|
||||
and self._crossover_port is not None
|
||||
)
|
||||
if use_crossover:
|
||||
state.total_llm_calls += 1
|
||||
child_prompt = await self._crossover_port.crossover(
|
||||
parent_a.prompt, parent_b.prompt, task_description
|
||||
)
|
||||
origin = "crossover"
|
||||
else:
|
||||
# Reflective mutation: evaluate a parent, propose improvement
|
||||
state.total_llm_calls += 2 * self._minibatch_size
|
||||
parent_eval = await self._evaluator.evaluate(
|
||||
parent_a.prompt, batch, task_description
|
||||
)
|
||||
state.total_llm_calls += 1
|
||||
child_prompt = await self._proposer.propose(
|
||||
parent_a.prompt,
|
||||
parent_eval.trajectories,
|
||||
task_description,
|
||||
)
|
||||
origin = "reflective"
|
||||
|
||||
# 4. Optional mutation
|
||||
if random.random() < self._mutation_rate and self._mutation_port is not None:
|
||||
mutation_type = random.choice(
|
||||
["paraphrase", "constrain", "generalize", "specialize"]
|
||||
)
|
||||
state.total_llm_calls += 1
|
||||
child_prompt = await self._mutation_port.mutate(
|
||||
child_prompt, task_description, mutation_type
|
||||
)
|
||||
origin += f"+mutation({mutation_type})"
|
||||
|
||||
# 5. Evaluate the child
|
||||
state.total_llm_calls += 2 * self._minibatch_size
|
||||
child_eval = await self._evaluator.evaluate(
|
||||
child_prompt, batch, task_description
|
||||
)
|
||||
|
||||
# 6. Compute fitness with diversity penalty
|
||||
child_score = child_eval.total_score
|
||||
diversity_sim = self._compute_diversity_score(
|
||||
child_prompt, population
|
||||
)
|
||||
child_fitness = child_score - self._diversity_penalty * (1.0 - diversity_sim)
|
||||
|
||||
# 7. Find worst candidate and replace if child is better
|
||||
worst_idx = min(
|
||||
range(len(population)),
|
||||
key=lambda idx: population[idx].best_score,
|
||||
)
|
||||
worst_fitness = (
|
||||
population[worst_idx].best_score
|
||||
- self._diversity_penalty * (1.0 - self._compute_diversity_score(
|
||||
population[worst_idx].prompt, population
|
||||
))
|
||||
)
|
||||
|
||||
accepted = child_fitness > worst_fitness
|
||||
|
||||
if accepted:
|
||||
new_candidate = Candidate(
|
||||
prompt=child_prompt,
|
||||
best_score=child_score,
|
||||
generation=i,
|
||||
parent_id=id(parent_a),
|
||||
)
|
||||
population[worst_idx] = new_candidate
|
||||
|
||||
# Update best if this child is the new best
|
||||
if child_score > (state.best_candidate.best_score if state.best_candidate else 0):
|
||||
state.best_candidate = new_candidate
|
||||
|
||||
logger.info(
|
||||
"Population iteration accepted",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "pop_accepted",
|
||||
"iteration": i,
|
||||
"origin": origin,
|
||||
"child_score": round(child_score, 4),
|
||||
"child_fitness": round(child_fitness, 4),
|
||||
"diversity_sim": round(diversity_sim, 4),
|
||||
"replaced_idx": worst_idx,
|
||||
"sample_ids": sample_ids,
|
||||
},
|
||||
},
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "pop_accepted",
|
||||
"origin": origin,
|
||||
"child_score": child_score,
|
||||
"child_fitness": child_fitness,
|
||||
"diversity_sim": diversity_sim,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Population iteration rejected",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "pop_rejected",
|
||||
"iteration": i,
|
||||
"origin": origin,
|
||||
"child_score": round(child_score, 4),
|
||||
"child_fitness": round(child_fitness, 4),
|
||||
"worst_fitness": round(worst_fitness, 4),
|
||||
"sample_ids": sample_ids,
|
||||
},
|
||||
},
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "pop_rejected",
|
||||
"origin": origin,
|
||||
"child_score": child_score,
|
||||
"child_fitness": child_fitness,
|
||||
"worst_fitness": worst_fitness,
|
||||
}
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Selection and diversity helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _tournament_select(
|
||||
self,
|
||||
population: list[Candidate],
|
||||
tournament_size: int = 3,
|
||||
) -> Candidate:
|
||||
"""Tournament selection: pick the best from a random subset."""
|
||||
k = min(tournament_size, len(population))
|
||||
contestants = random.sample(population, k)
|
||||
return max(contestants, key=lambda c: c.best_score)
|
||||
|
||||
@staticmethod
|
||||
def _compute_diversity_score(
|
||||
prompt: Prompt,
|
||||
population: list[Candidate],
|
||||
) -> float:
|
||||
"""Compute the average Jaccard similarity between *prompt* and all
|
||||
population members. Returns 1.0 when population has only one member
|
||||
(no diversity penalty)."""
|
||||
if len(population) <= 1:
|
||||
return 1.0
|
||||
|
||||
prompt_words = set(prompt.text.lower().split())
|
||||
if not prompt_words:
|
||||
return 0.0
|
||||
|
||||
similarities: list[float] = []
|
||||
for candidate in population:
|
||||
other_words = set(candidate.prompt.text.lower().split())
|
||||
if not other_words:
|
||||
continue
|
||||
intersection = prompt_words & other_words
|
||||
union = prompt_words | other_words
|
||||
sim = len(intersection) / len(union) if union else 0.0
|
||||
similarities.append(sim)
|
||||
|
||||
# Average similarity (lower = more diverse)
|
||||
return sum(similarities) / len(similarities) if similarities else 0.0
|
||||
|
||||
@staticmethod
|
||||
def _compute_prompt_diff(old: str, new: str) -> dict[str, int]:
|
||||
"""Compute a simple diff summary between two prompts."""
|
||||
old_lines = set(old.splitlines())
|
||||
new_lines = set(new.splitlines())
|
||||
return {
|
||||
"lines_added": len(new_lines - old_lines),
|
||||
"lines_removed": len(old_lines - new_lines),
|
||||
"chars_delta": len(new) - len(old),
|
||||
}
|
||||
|
||||
def _maybe_checkpoint(self, state: OptimizationState) -> None:
|
||||
"""Save a checkpoint if the interval is met or on accepted improvements."""
|
||||
if self._checkpoint_port is None:
|
||||
return
|
||||
if state.iteration % self._checkpoint_interval == 0:
|
||||
self._checkpoint_port.save(state)
|
||||
|
||||
116
src/prometheus/application/ground_truth_evaluator.py
Normal file
116
src/prometheus/application/ground_truth_evaluator.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Ground-truth evaluator — execution + similarity comparison.
|
||||
|
||||
Produces a quality signal *with* ground truth by comparing model outputs
|
||||
against expected outputs using a configurable similarity metric.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from prometheus.domain.entities import (
|
||||
EvalResult,
|
||||
GroundTruthExample,
|
||||
Prompt,
|
||||
Trajectory,
|
||||
)
|
||||
from prometheus.domain.ports import LLMPort, SimilarityPort
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GroundTruthEvaluator:
|
||||
"""Evaluates a prompt against a ground-truth dataset.
|
||||
|
||||
Pipeline: execute → compare with similarity metric → build trajectories.
|
||||
Unlike PromptEvaluator (which uses LLM-as-Judge), this compares outputs
|
||||
directly against known-good expected outputs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: LLMPort,
|
||||
similarity: SimilarityPort,
|
||||
max_concurrency: int = 5,
|
||||
):
|
||||
self._executor = executor
|
||||
self._similarity = similarity
|
||||
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
prompt: Prompt,
|
||||
dataset: list[GroundTruthExample],
|
||||
) -> EvalResult:
|
||||
"""Evaluate the prompt on the ground-truth dataset.
|
||||
|
||||
Steps:
|
||||
1. Execute the prompt on each input (parallel, bounded)
|
||||
2. Compare each output against expected using similarity metric
|
||||
3. Build trajectories with feedback
|
||||
"""
|
||||
# Step 1: Parallel execution (per-item isolation)
|
||||
output_coros = [
|
||||
self._execute_single(prompt, example) for example in dataset
|
||||
]
|
||||
outputs = await asyncio.gather(*output_coros)
|
||||
|
||||
# Step 2: Compute similarity scores
|
||||
scores: list[float] = []
|
||||
feedbacks: list[str] = []
|
||||
trajectories: list[Trajectory] = []
|
||||
|
||||
for example, output in zip(dataset, outputs):
|
||||
score = self._similarity.compute(output, example.expected_output)
|
||||
score = max(0.0, min(1.0, score)) # normalize to [0, 1]
|
||||
scores.append(score)
|
||||
feedback = self._build_feedback(output, example.expected_output, score)
|
||||
feedbacks.append(feedback)
|
||||
trajectories.append(
|
||||
Trajectory(
|
||||
input_text=example.input_text,
|
||||
output_text=output,
|
||||
score=score,
|
||||
feedback=feedback,
|
||||
prompt_used=prompt.text,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Ground-truth evaluation complete: %d items, mean_score=%.4f",
|
||||
len(dataset),
|
||||
sum(scores) / len(scores) if scores else 0.0,
|
||||
)
|
||||
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=feedbacks,
|
||||
trajectories=trajectories,
|
||||
)
|
||||
|
||||
async def _execute_single(
|
||||
self, prompt: Prompt, example: GroundTruthExample
|
||||
) -> str:
|
||||
async with self._semaphore:
|
||||
try:
|
||||
return await self._executor.execute(prompt, example.input_text)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Execution failed for input '%s…': %s",
|
||||
example.input_text[:40],
|
||||
exc,
|
||||
)
|
||||
return f"[execution error: {exc}]"
|
||||
|
||||
@staticmethod
|
||||
def _build_feedback(output: str, expected: str, score: float) -> str:
|
||||
"""Build human-readable feedback for a ground-truth comparison."""
|
||||
if score >= 0.99:
|
||||
return "Exact match."
|
||||
elif score >= 0.7:
|
||||
return f"Close match (score={score:.2f}). Expected: {expected[:100]}"
|
||||
elif score >= 0.3:
|
||||
return f"Partial match (score={score:.2f}). Expected: {expected[:100]}"
|
||||
else:
|
||||
return f"Poor match (score={score:.2f}). Expected: {expected[:100]}"
|
||||
@@ -10,8 +10,16 @@ from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.dto import OptimizationConfig, OptimizationResult
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.cli.logging_setup import get_logger
|
||||
from prometheus.domain.entities import Prompt
|
||||
from prometheus.domain.ports import ProposerPort
|
||||
from prometheus.domain.ports import (
|
||||
CheckpointPort,
|
||||
CrossoverPort,
|
||||
MutationPort,
|
||||
ProposerPort,
|
||||
)
|
||||
|
||||
logger = get_logger("use_cases")
|
||||
|
||||
|
||||
class OptimizePromptUseCase:
|
||||
@@ -25,24 +33,60 @@ class OptimizePromptUseCase:
|
||||
evaluator: PromptEvaluator,
|
||||
proposer: ProposerPort,
|
||||
bootstrap: SyntheticBootstrap,
|
||||
checkpoint_port: CheckpointPort | None = None,
|
||||
crossover_port: CrossoverPort | None = None,
|
||||
mutation_port: MutationPort | None = None,
|
||||
):
|
||||
self._evaluator = evaluator
|
||||
self._proposer = proposer
|
||||
self._bootstrap = bootstrap
|
||||
self._checkpoint_port = checkpoint_port
|
||||
self._crossover_port = crossover_port
|
||||
self._mutation_port = mutation_port
|
||||
|
||||
def execute(self, config: OptimizationConfig) -> OptimizationResult:
|
||||
async def execute(self, config: OptimizationConfig) -> OptimizationResult:
|
||||
"""Full pipeline:
|
||||
1. Bootstrap → generate synthetic inputs
|
||||
2. Evolution → optimization loop
|
||||
2. Evolution → optimization loop (with optional checkpoint resume)
|
||||
3. Return result
|
||||
"""
|
||||
# Phase 0: Bootstrap
|
||||
synthetic_pool = self._bootstrap.run(
|
||||
task_description=config.task_description,
|
||||
n_examples=config.n_synthetic_inputs,
|
||||
)
|
||||
# Phase 0: Bootstrap (skip synthetic generation on resume if pool was saved)
|
||||
initial_state = None
|
||||
if config.resume and self._checkpoint_port is not None:
|
||||
initial_state = self._checkpoint_port.load()
|
||||
if initial_state is not None and initial_state.synthetic_pool:
|
||||
synthetic_pool = initial_state.synthetic_pool
|
||||
logger.info(
|
||||
"Resumed checkpoint includes %d synthetic inputs — skipping bootstrap",
|
||||
extra={"structured": {"event": "resume_skip_bootstrap", "pool_size": len(synthetic_pool)}},
|
||||
)
|
||||
else:
|
||||
synthetic_pool = self._bootstrap.run(
|
||||
task_description=config.task_description,
|
||||
n_examples=config.n_synthetic_inputs,
|
||||
)
|
||||
else:
|
||||
synthetic_pool = self._bootstrap.run(
|
||||
task_description=config.task_description,
|
||||
n_examples=config.n_synthetic_inputs,
|
||||
)
|
||||
|
||||
# Phase 1: Evolution
|
||||
# Split into train / validation if configured
|
||||
validation_pool: list = []
|
||||
if config.validation_split > 0:
|
||||
synthetic_pool, validation_pool = SyntheticBootstrap.split_pool(
|
||||
synthetic_pool, config.validation_split,
|
||||
)
|
||||
logger.info(
|
||||
"Split synthetic pool: %d train, %d validation (%.0f%% hold-out)",
|
||||
len(synthetic_pool), len(validation_pool),
|
||||
config.validation_split * 100,
|
||||
extra={"structured": {
|
||||
"event": "pool_split",
|
||||
"train_size": len(synthetic_pool),
|
||||
"val_size": len(validation_pool),
|
||||
}},
|
||||
)
|
||||
loop = EvolutionLoop(
|
||||
evaluator=self._evaluator,
|
||||
proposer=self._proposer,
|
||||
@@ -51,9 +95,24 @@ class OptimizePromptUseCase:
|
||||
minibatch_size=config.minibatch_size,
|
||||
perfect_score=config.perfect_score,
|
||||
verbose=config.verbose,
|
||||
circuit_breaker_threshold=config.circuit_breaker_threshold,
|
||||
error_strategy=config.error_strategy,
|
||||
checkpoint_port=self._checkpoint_port,
|
||||
checkpoint_interval=config.checkpoint_interval,
|
||||
population_size=config.population_size,
|
||||
crossover_rate=config.crossover_rate,
|
||||
mutation_rate=config.mutation_rate,
|
||||
diversity_penalty=config.diversity_penalty,
|
||||
crossover_port=self._crossover_port,
|
||||
mutation_port=self._mutation_port,
|
||||
early_stop_patience=config.early_stop_patience,
|
||||
)
|
||||
seed_prompt = Prompt(text=config.seed_prompt)
|
||||
state = loop.run(seed_prompt, synthetic_pool, config.task_description)
|
||||
state = await loop.run(
|
||||
seed_prompt, synthetic_pool, config.task_description,
|
||||
initial_state=initial_state,
|
||||
validation_pool=validation_pool or None,
|
||||
)
|
||||
|
||||
# Phase 2: Result
|
||||
initial_score = (
|
||||
@@ -69,9 +128,12 @@ class OptimizePromptUseCase:
|
||||
),
|
||||
initial_prompt=config.seed_prompt,
|
||||
iterations_used=state.iteration,
|
||||
total_llm_calls=state.total_llm_calls + 1, # +1 for bootstrap
|
||||
total_llm_calls=state.total_llm_calls + 1, # +1 for bootstrap synthesis call
|
||||
initial_score=initial_score,
|
||||
final_score=final_score,
|
||||
improvement=final_score - initial_score,
|
||||
history=state.history,
|
||||
final_validation_score=state.best_validation_score,
|
||||
best_validation_score=state.best_validation_score,
|
||||
early_stopped=state.early_stopped,
|
||||
)
|
||||
|
||||
@@ -1,29 +1,13 @@
|
||||
"""
|
||||
CLI — user entry point.
|
||||
|
||||
Typer interface with -i (input) and -o (output) options.
|
||||
Registers all subcommands and delegates to cli/commands/.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
|
||||
import dspy
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.dto import OptimizationConfig, OptimizationResult
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.use_cases import OptimizePromptUseCase
|
||||
from prometheus.infrastructure.file_io import YamlPersistence
|
||||
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
||||
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
||||
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
||||
from prometheus.cli.commands import init, list_runs, optimize, version
|
||||
|
||||
app = typer.Typer(
|
||||
name="prometheus",
|
||||
@@ -31,137 +15,12 @@ app = typer.Typer(
|
||||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@app.command()
|
||||
def optimize(
|
||||
input: str = typer.Option(
|
||||
...,
|
||||
"-i",
|
||||
"--input",
|
||||
help="Path to input YAML config file.",
|
||||
exists=True,
|
||||
readable=True,
|
||||
),
|
||||
output: str = typer.Option(
|
||||
"output.yaml",
|
||||
"-o",
|
||||
"--output",
|
||||
help="Path to output YAML result file.",
|
||||
),
|
||||
verbose: bool = typer.Option(
|
||||
False,
|
||||
"-v",
|
||||
"--verbose",
|
||||
help="Print detailed progress.",
|
||||
),
|
||||
) -> None:
|
||||
"""Optimize a prompt without any reference data.
|
||||
|
||||
Usage:
|
||||
prometheus optimize -i config.yaml -o result.yaml
|
||||
"""
|
||||
# Configure verbose logging
|
||||
if verbose:
|
||||
logging.basicConfig(level=logging.INFO, format="[PROMETHEUS] %(message)s")
|
||||
|
||||
console.print(
|
||||
Panel.fit(
|
||||
"PROMETHEUS — Prompt Evolution Engine",
|
||||
subtitle="No reference data required",
|
||||
)
|
||||
)
|
||||
|
||||
# 1. Load config
|
||||
persistence = YamlPersistence()
|
||||
raw_config = persistence.read_config(input)
|
||||
config = OptimizationConfig(
|
||||
seed_prompt=raw_config["seed_prompt"],
|
||||
task_description=raw_config["task_description"],
|
||||
task_model=raw_config.get("task_model", "openai/gpt-4o-mini"),
|
||||
judge_model=raw_config.get("judge_model", "openai/gpt-4o"),
|
||||
proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"),
|
||||
synth_model=raw_config.get("synth_model", "openai/gpt-4o"),
|
||||
max_iterations=raw_config.get("max_iterations", 30),
|
||||
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
|
||||
minibatch_size=raw_config.get("minibatch_size", 5),
|
||||
seed=raw_config.get("seed", 42),
|
||||
output_path=output,
|
||||
verbose=verbose,
|
||||
)
|
||||
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
|
||||
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
|
||||
|
||||
# 2. Configure DSPy with optional api_base/api_key from config
|
||||
lm_kwargs: dict = {}
|
||||
api_base = raw_config.get("api_base")
|
||||
api_key_env = raw_config.get("api_key_env")
|
||||
if api_base:
|
||||
lm_kwargs["api_base"] = api_base
|
||||
if api_key_env:
|
||||
lm_kwargs["api_key"] = os.environ.get(api_key_env, "")
|
||||
task_lm = dspy.LM(config.task_model, **lm_kwargs)
|
||||
dspy.configure(lm=task_lm)
|
||||
|
||||
# 3. Build adapters (Dependency Injection)
|
||||
synth_adapter = DSPySyntheticAdapter()
|
||||
llm_adapter = DSPyLLMAdapter(model=config.task_model)
|
||||
judge_adapter = DSPyJudgeAdapter()
|
||||
proposer_adapter = DSPyProposerAdapter()
|
||||
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
||||
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter)
|
||||
use_case = OptimizePromptUseCase(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer_adapter,
|
||||
bootstrap=bootstrap,
|
||||
)
|
||||
|
||||
# 4. Execute
|
||||
with console.status("[bold green]Evolving prompt..."):
|
||||
result = use_case.execute(config)
|
||||
|
||||
# 5. Display results
|
||||
_display_result(result)
|
||||
|
||||
# 6. Save
|
||||
_save_result(persistence, output, result)
|
||||
console.print(f"\n[green]Results saved to {output}[/green]")
|
||||
|
||||
|
||||
def _display_result(result: OptimizationResult) -> None:
|
||||
"""Display a Rich summary in the terminal."""
|
||||
console.print()
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold green]Optimized Prompt[/bold green]\n\n{result.optimized_prompt}",
|
||||
title="Result",
|
||||
)
|
||||
)
|
||||
table = Table(title="Metrics")
|
||||
table.add_column("Metric", style="cyan")
|
||||
table.add_column("Value", style="bold")
|
||||
table.add_row("Initial Score", f"{result.initial_score:.2f}")
|
||||
table.add_row("Final Score", f"{result.final_score:.2f}")
|
||||
table.add_row("Improvement", f"{result.improvement:+.2f}")
|
||||
table.add_row("Iterations", str(result.iterations_used))
|
||||
table.add_row("LLM Calls", str(result.total_llm_calls))
|
||||
console.print(table)
|
||||
|
||||
|
||||
def _save_result(
|
||||
persistence: YamlPersistence,
|
||||
path: str,
|
||||
result: OptimizationResult,
|
||||
) -> None:
|
||||
"""Save the result as YAML."""
|
||||
persistence.write_result(path, asdict(result))
|
||||
|
||||
|
||||
@app.command(hidden=True)
|
||||
def _help() -> None:
|
||||
"""Internal placeholder to force multi-command Typer behavior."""
|
||||
pass
|
||||
# Register all subcommands — having multiple commands fixes the
|
||||
# Typer 0.24+ bug where a single-command app absorbs the subcommand.
|
||||
optimize.register(app)
|
||||
version.register(app)
|
||||
init.register(app)
|
||||
list_runs.register(app)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
1
src/prometheus/cli/commands/__init__.py
Normal file
1
src/prometheus/cli/commands/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""CLI command modules."""
|
||||
97
src/prometheus/cli/commands/init.py
Normal file
97
src/prometheus/cli/commands/init.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""prometheus init — scaffold a config YAML interactively."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
|
||||
_TEMPLATE = """\
|
||||
# PROMETHEUS configuration
|
||||
# Generated by `prometheus init`
|
||||
|
||||
# --- Required ---
|
||||
seed_prompt: {seed_prompt}
|
||||
task_description: {task_description}
|
||||
|
||||
# --- Models ---
|
||||
task_model: {task_model}
|
||||
judge_model: {judge_model}
|
||||
proposer_model: {proposer_model}
|
||||
synth_model: {synth_model}
|
||||
|
||||
# --- Global API settings (optional) ---
|
||||
# api_base: https://api.openai.com/v1
|
||||
# api_key_env: OPENAI_API_KEY
|
||||
|
||||
# --- Evolution parameters ---
|
||||
max_iterations: 30
|
||||
n_synthetic_inputs: 20
|
||||
minibatch_size: 5
|
||||
perfect_score: 1.0
|
||||
seed: 42
|
||||
|
||||
# --- Concurrency ---
|
||||
max_concurrency: 5
|
||||
|
||||
# --- Error handling ---
|
||||
max_retries: 3
|
||||
retry_delay_base: 1.0
|
||||
circuit_breaker_threshold: 5
|
||||
error_strategy: retry
|
||||
"""
|
||||
|
||||
|
||||
def register(app: typer.Typer) -> None:
|
||||
"""Register the init command on the Typer app."""
|
||||
|
||||
@app.command()
|
||||
def init(
|
||||
output: str = typer.Option(
|
||||
"config.yaml",
|
||||
"-o",
|
||||
"--output",
|
||||
help="Path for the generated config file.",
|
||||
),
|
||||
) -> None:
|
||||
"""Interactively scaffold a PROMETHEUS config YAML.
|
||||
|
||||
Prompts for required fields and writes a ready-to-edit config file.
|
||||
"""
|
||||
target = Path(output)
|
||||
|
||||
if target.exists() and not typer.confirm(
|
||||
f"{output} already exists. Overwrite?", default=False
|
||||
):
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
seed_prompt: str = typer.prompt("Seed prompt")
|
||||
task_description: str = typer.prompt("Task description")
|
||||
task_model: str = typer.prompt("Task model", default="openai/gpt-4o-mini")
|
||||
judge_model: str = typer.prompt("Judge model", default="openai/gpt-4o")
|
||||
proposer_model: str = typer.prompt("Proposer model", default="openai/gpt-4o")
|
||||
synth_model: str = typer.prompt("Synth model", default="openai/gpt-4o")
|
||||
|
||||
content = _TEMPLATE.format(
|
||||
seed_prompt=_yaml_string(seed_prompt),
|
||||
task_description=_yaml_string(task_description),
|
||||
task_model=task_model,
|
||||
judge_model=judge_model,
|
||||
proposer_model=proposer_model,
|
||||
synth_model=synth_model,
|
||||
)
|
||||
|
||||
target.write_text(content, encoding="utf-8")
|
||||
console.print(f"[green]Config written to {output}[/green]")
|
||||
console.print("[dim]Edit it as needed, then run: prometheus optimize -i config.yaml[/dim]")
|
||||
|
||||
|
||||
def _yaml_string(value: str) -> str:
|
||||
"""Quote a string for YAML if it contains special characters."""
|
||||
if any(ch in value for ch in (":", "#", "'", '"', "\n", "{", "}", "[", "]", ",")):
|
||||
escaped = value.replace("'", "''")
|
||||
return f"'{escaped}'"
|
||||
return value
|
||||
101
src/prometheus/cli/commands/list_runs.py
Normal file
101
src/prometheus/cli/commands/list_runs.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""prometheus list — list past optimization runs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import glob as globmod
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
console = Console()
|
||||
|
||||
_DEFAULT_PATTERNS = ("output.yaml", "results/*.yaml", "*.result.yaml")
|
||||
|
||||
|
||||
def register(app: typer.Typer) -> None:
|
||||
"""Register the list command on the Typer app."""
|
||||
|
||||
@app.command("list")
|
||||
def list_runs(
|
||||
directory: str = typer.Option(
|
||||
".",
|
||||
"-d",
|
||||
"--directory",
|
||||
help="Directory to scan for result YAML files.",
|
||||
),
|
||||
) -> None:
|
||||
"""List past optimization runs found in result YAML files.
|
||||
|
||||
Scans the given directory for YAML files that look like PROMETHEUS
|
||||
output (they contain 'optimized_prompt' and 'final_score' keys) and
|
||||
displays a summary table.
|
||||
"""
|
||||
base = Path(directory)
|
||||
if not base.is_dir():
|
||||
console.print(f"[red]Directory not found: {directory}[/red]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
runs: list[dict] = []
|
||||
|
||||
for pattern in _DEFAULT_PATTERNS:
|
||||
for path_str in globmod.glob(str(base / pattern), recursive=False):
|
||||
_try_read_run(path_str, runs)
|
||||
|
||||
# Also try nested directories one level deep
|
||||
for path_str in globmod.glob(str(base / "**/*.yaml"), recursive=True):
|
||||
_try_read_run(path_str, runs)
|
||||
|
||||
if not runs:
|
||||
console.print("[dim]No optimization runs found.[/dim]")
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
# Deduplicate by path
|
||||
seen: set[str] = set()
|
||||
unique_runs: list[dict] = []
|
||||
for run in runs:
|
||||
if run["path"] not in seen:
|
||||
seen.add(run["path"])
|
||||
unique_runs.append(run)
|
||||
|
||||
table = Table(title="PROMETHEUS Runs")
|
||||
table.add_column("File", style="cyan")
|
||||
table.add_column("Initial", justify="right")
|
||||
table.add_column("Final", justify="right", style="green")
|
||||
table.add_column("Delta", justify="right")
|
||||
table.add_column("Iters", justify="right")
|
||||
table.add_column("Prompt (first 60 chars)", style="dim")
|
||||
|
||||
for run in sorted(unique_runs, key=lambda r: r["path"]):
|
||||
table.add_row(
|
||||
run["path"],
|
||||
f"{run['initial_score']:.2f}",
|
||||
f"{run['final_score']:.2f}",
|
||||
f"{run['improvement']:+.2f}",
|
||||
str(run["iterations"]),
|
||||
run["prompt_preview"],
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def _try_read_run(path_str: str, runs: list[dict]) -> None:
|
||||
"""Try to parse a YAML file as a PROMETHEUS result and append metadata."""
|
||||
try:
|
||||
with open(path_str, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
return
|
||||
if "optimized_prompt" not in data or "final_score" not in data:
|
||||
return
|
||||
runs.append({
|
||||
"path": path_str,
|
||||
"initial_score": float(data.get("initial_score", 0.0)),
|
||||
"final_score": float(data.get("final_score", 0.0)),
|
||||
"improvement": float(data.get("improvement", 0.0)),
|
||||
"iterations": int(data.get("iterations_used", 0)),
|
||||
"prompt_preview": str(data.get("optimized_prompt", ""))[:60],
|
||||
})
|
||||
except (OSError, yaml.YAMLError, ValueError, TypeError):
|
||||
pass
|
||||
449
src/prometheus/cli/commands/optimize.py
Normal file
449
src/prometheus/cli/commands/optimize.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""prometheus optimize — run prompt optimization."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
|
||||
import dspy
|
||||
import typer
|
||||
from pydantic import ValidationError
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.dto import JudgeDimension, OptimizationConfig, OptimizationResult
|
||||
from prometheus.domain.entities import EvalResult, Prompt
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.ground_truth_evaluator import GroundTruthEvaluator
|
||||
from prometheus.application.use_cases import OptimizePromptUseCase
|
||||
from prometheus.cli.logging_setup import configure_logging
|
||||
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
|
||||
from prometheus.infrastructure.checkpoint import JsonCheckpointPersistence
|
||||
from prometheus.infrastructure.file_io import YamlPersistence
|
||||
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
||||
from prometheus.infrastructure.crossover_adapter import DSPyCrossoverAdapter
|
||||
from prometheus.infrastructure.mutation_adapter import DSPyMutationAdapter
|
||||
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
||||
from prometheus.infrastructure.similarity import create_similarity_adapter
|
||||
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def register(app: typer.Typer) -> None:
|
||||
"""Register the optimize command on the Typer app."""
|
||||
|
||||
@app.command()
|
||||
def optimize(
|
||||
input: str = typer.Option(
|
||||
...,
|
||||
"-i",
|
||||
"--input",
|
||||
help="Path to input YAML config file.",
|
||||
exists=True,
|
||||
readable=True,
|
||||
),
|
||||
output: str = typer.Option(
|
||||
"output.yaml",
|
||||
"-o",
|
||||
"--output",
|
||||
help="Path to output YAML result file.",
|
||||
),
|
||||
verbose: bool = typer.Option(
|
||||
False,
|
||||
"-v",
|
||||
"--verbose",
|
||||
help="Print detailed progress (INFO level).",
|
||||
),
|
||||
debug: bool = typer.Option(
|
||||
False,
|
||||
"--debug",
|
||||
help="Enable DEBUG-level logging (overrides -v).",
|
||||
),
|
||||
log_format: str = typer.Option(
|
||||
"text",
|
||||
"--log-format",
|
||||
help="Log output format: text | json.",
|
||||
),
|
||||
log_file: str | None = typer.Option(
|
||||
None,
|
||||
"--log-file",
|
||||
help="Optional file path to write logs to.",
|
||||
),
|
||||
max_retries: int = typer.Option(
|
||||
3,
|
||||
"--max-retries",
|
||||
help="Max retry attempts for transient LLM errors (429, timeout, 5xx).",
|
||||
),
|
||||
error_strategy: str = typer.Option(
|
||||
"retry",
|
||||
"--error-strategy",
|
||||
help="How to handle errors: skip | retry | abort.",
|
||||
),
|
||||
max_concurrency: int = typer.Option(
|
||||
5,
|
||||
"--max-concurrency",
|
||||
help="Max parallel LLM calls for minibatch execution and judging.",
|
||||
),
|
||||
eval_dataset: str | None = typer.Option(
|
||||
None,
|
||||
"--eval-dataset",
|
||||
help="Path to a CSV/JSON dataset with 'input' and 'expected_output' columns.",
|
||||
),
|
||||
eval_metric: str = typer.Option(
|
||||
"bleu",
|
||||
"--eval-metric",
|
||||
help="Similarity metric for ground-truth eval: exact | bleu | rouge_l | cosine | llm_judge.",
|
||||
),
|
||||
checkpoint_dir: str | None = typer.Option(
|
||||
None,
|
||||
"--checkpoint-dir",
|
||||
help="Directory for checkpoint files. Enables periodic checkpointing.",
|
||||
),
|
||||
checkpoint_interval: int = typer.Option(
|
||||
5,
|
||||
"--checkpoint-interval",
|
||||
help="Save a checkpoint every N iterations.",
|
||||
),
|
||||
resume: bool = typer.Option(
|
||||
False,
|
||||
"--resume",
|
||||
help="Resume from the latest checkpoint in --checkpoint-dir.",
|
||||
),
|
||||
population_size: int = typer.Option(
|
||||
1,
|
||||
"--population-size",
|
||||
help="Number of candidates in the evolution population. 1 = single-candidate hill climbing.",
|
||||
),
|
||||
crossover_rate: float = typer.Option(
|
||||
0.5,
|
||||
"--crossover-rate",
|
||||
help="Probability of applying crossover vs reflective mutation (0.0–1.0). Only used when --population-size > 1.",
|
||||
),
|
||||
mutation_rate: float = typer.Option(
|
||||
0.3,
|
||||
"--mutation-rate",
|
||||
help="Probability of applying a mutation operator after crossover/proposal (0.0–1.0). Only used when --population-size > 1.",
|
||||
),
|
||||
validation_split: float = typer.Option(
|
||||
0.3,
|
||||
"--validation-split",
|
||||
help="Fraction of synthetic pool reserved for hold-out validation (0.0–0.9). 0 disables validation.",
|
||||
),
|
||||
early_stop_patience: int = typer.Option(
|
||||
5,
|
||||
"--early-stop-patience",
|
||||
help="Stop if validation score does not improve for this many consecutive iterations.",
|
||||
),
|
||||
judge_criteria: str | None = typer.Option(
|
||||
None,
|
||||
"--judge-criteria",
|
||||
help="Custom judge rubric or evaluation criteria override (free text).",
|
||||
),
|
||||
) -> None:
|
||||
"""Optimize a prompt without any reference data.
|
||||
|
||||
Usage:
|
||||
prometheus optimize -i config.yaml -o result.yaml
|
||||
prometheus optimize -i config.yaml --eval-dataset data.csv --eval-metric bleu
|
||||
prometheus optimize -i config.yaml --checkpoint-dir .prometheus/checkpoints --resume
|
||||
"""
|
||||
asyncio.run(
|
||||
_async_optimize(
|
||||
input, output, verbose, debug, log_format, log_file,
|
||||
max_retries, error_strategy, max_concurrency,
|
||||
eval_dataset, eval_metric,
|
||||
checkpoint_dir, checkpoint_interval, resume,
|
||||
population_size, crossover_rate, mutation_rate,
|
||||
validation_split, early_stop_patience,
|
||||
judge_criteria,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _async_optimize(
|
||||
input: str,
|
||||
output: str,
|
||||
verbose: bool,
|
||||
debug: bool,
|
||||
log_format: str,
|
||||
log_file: str | None,
|
||||
max_retries: int,
|
||||
error_strategy: str,
|
||||
max_concurrency: int,
|
||||
eval_dataset: str | None,
|
||||
eval_metric: str,
|
||||
checkpoint_dir: str | None,
|
||||
checkpoint_interval: int,
|
||||
resume: bool,
|
||||
population_size: int = 1,
|
||||
crossover_rate: float = 0.5,
|
||||
mutation_rate: float = 0.3,
|
||||
validation_split: float = 0.3,
|
||||
early_stop_patience: int = 5,
|
||||
judge_criteria: str | None = None,
|
||||
) -> None:
|
||||
# Configure structured logging
|
||||
if debug:
|
||||
log_level = logging.DEBUG
|
||||
elif verbose:
|
||||
log_level = logging.INFO
|
||||
else:
|
||||
log_level = logging.WARNING
|
||||
configure_logging(level=log_level, log_format=log_format, log_file=log_file)
|
||||
|
||||
console.print(
|
||||
Panel.fit(
|
||||
"PROMETHEUS — Prompt Evolution Engine",
|
||||
subtitle="No reference data required",
|
||||
)
|
||||
)
|
||||
|
||||
# 1. Load & validate config
|
||||
persistence = YamlPersistence()
|
||||
raw_config = persistence.read_config(input)
|
||||
|
||||
# CLI flags override config file values
|
||||
raw_config.setdefault("max_retries", max_retries)
|
||||
raw_config.setdefault("error_strategy", error_strategy)
|
||||
raw_config.setdefault("max_concurrency", max_concurrency)
|
||||
raw_config["output_path"] = output
|
||||
raw_config["verbose"] = verbose
|
||||
raw_config["debug"] = debug
|
||||
raw_config["log_format"] = log_format
|
||||
raw_config["log_file"] = log_file
|
||||
if eval_dataset:
|
||||
raw_config["eval_dataset_path"] = eval_dataset
|
||||
raw_config.setdefault("eval_metric", eval_metric)
|
||||
if checkpoint_dir:
|
||||
raw_config["checkpoint_dir"] = checkpoint_dir
|
||||
raw_config.setdefault("checkpoint_interval", checkpoint_interval)
|
||||
if resume:
|
||||
raw_config["resume"] = True
|
||||
raw_config.setdefault("population_size", population_size)
|
||||
raw_config.setdefault("crossover_rate", crossover_rate)
|
||||
raw_config.setdefault("mutation_rate", mutation_rate)
|
||||
raw_config.setdefault("validation_split", validation_split)
|
||||
raw_config.setdefault("early_stop_patience", early_stop_patience)
|
||||
if judge_criteria:
|
||||
raw_config["judge_criteria"] = judge_criteria
|
||||
|
||||
try:
|
||||
config = OptimizationConfig.model_validate(raw_config)
|
||||
except ValidationError as exc:
|
||||
console.print("[bold red]Configuration error:[/bold red]\n")
|
||||
for err in exc.errors():
|
||||
loc = " → ".join(str(l) for l in err["loc"])
|
||||
console.print(f" [red]• {loc}: {err['msg']}[/red]")
|
||||
raise typer.Exit(code=1) from exc
|
||||
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
|
||||
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
|
||||
|
||||
# 2. Create per-model DSPy LM instances
|
||||
def _model_lm_kwargs(
|
||||
model_api_base: str | None,
|
||||
model_api_key_env: str | None,
|
||||
) -> dict:
|
||||
"""Build kwargs for dspy.LM, using per-model overrides with global fallback."""
|
||||
kwargs: dict = {}
|
||||
api_base = model_api_base or config.api_base
|
||||
api_key_env = model_api_key_env or config.api_key_env
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
if api_key_env:
|
||||
kwargs["api_key"] = os.environ.get(api_key_env, "")
|
||||
return kwargs
|
||||
|
||||
task_lm = dspy.LM(
|
||||
config.task_model,
|
||||
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env),
|
||||
)
|
||||
judge_lm = dspy.LM(
|
||||
config.judge_model,
|
||||
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env),
|
||||
)
|
||||
proposer_lm = dspy.LM(
|
||||
config.proposer_model,
|
||||
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env),
|
||||
)
|
||||
synth_lm = dspy.LM(
|
||||
config.synth_model,
|
||||
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env),
|
||||
)
|
||||
|
||||
# 3. Build adapters (Dependency Injection — each gets its own LM + retry config)
|
||||
synth_adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||
llm_adapter = DSPyLLMAdapter(
|
||||
lm=task_lm,
|
||||
max_retries=config.max_retries,
|
||||
retry_delay_base=config.retry_delay_base,
|
||||
)
|
||||
judge_adapter = DSPyJudgeAdapter(
|
||||
lm=judge_lm,
|
||||
max_retries=config.max_retries,
|
||||
retry_delay_base=config.retry_delay_base,
|
||||
max_concurrency=config.max_concurrency,
|
||||
judge_criteria=config.judge_criteria,
|
||||
judge_dimensions=config.judge_dimensions,
|
||||
)
|
||||
proposer_adapter = DSPyProposerAdapter(
|
||||
lm=proposer_lm,
|
||||
max_retries=config.max_retries,
|
||||
retry_delay_base=config.retry_delay_base,
|
||||
)
|
||||
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
||||
evaluator = PromptEvaluator(
|
||||
executor=llm_adapter,
|
||||
judge=judge_adapter,
|
||||
max_concurrency=config.max_concurrency,
|
||||
)
|
||||
# Build checkpoint port if checkpoint_dir is configured
|
||||
checkpoint_port = None
|
||||
if config.checkpoint_dir:
|
||||
checkpoint_port = JsonCheckpointPersistence(checkpoint_dir=config.checkpoint_dir)
|
||||
|
||||
# Build crossover/mutation adapters for population-based evolution
|
||||
crossover_adapter = None
|
||||
mutation_adapter = None
|
||||
if config.population_size > 1:
|
||||
# Reuse proposer LM for crossover and mutation (same model, same role)
|
||||
crossover_adapter = DSPyCrossoverAdapter(
|
||||
lm=proposer_lm,
|
||||
max_retries=config.max_retries,
|
||||
retry_delay_base=config.retry_delay_base,
|
||||
)
|
||||
mutation_adapter = DSPyMutationAdapter(
|
||||
lm=proposer_lm,
|
||||
max_retries=config.max_retries,
|
||||
retry_delay_base=config.retry_delay_base,
|
||||
)
|
||||
|
||||
use_case = OptimizePromptUseCase(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer_adapter,
|
||||
bootstrap=bootstrap,
|
||||
checkpoint_port=checkpoint_port,
|
||||
crossover_port=crossover_adapter,
|
||||
mutation_port=mutation_adapter,
|
||||
)
|
||||
|
||||
# 4. Execute
|
||||
with console.status("[bold green]Evolving prompt..."):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
# 4b. Compute actual LLM call count from adapter counters
|
||||
actual_llm_calls = (
|
||||
llm_adapter.call_count
|
||||
+ judge_adapter.call_count
|
||||
+ proposer_adapter.call_count
|
||||
+ synth_adapter.call_count
|
||||
+ (crossover_adapter.call_count if crossover_adapter else 0)
|
||||
+ (mutation_adapter.call_count if mutation_adapter else 0)
|
||||
)
|
||||
result = OptimizationResult(
|
||||
optimized_prompt=result.optimized_prompt,
|
||||
initial_prompt=result.initial_prompt,
|
||||
iterations_used=result.iterations_used,
|
||||
total_llm_calls=actual_llm_calls,
|
||||
initial_score=result.initial_score,
|
||||
final_score=result.final_score,
|
||||
improvement=result.improvement,
|
||||
history=result.history,
|
||||
final_validation_score=result.final_validation_score,
|
||||
best_validation_score=result.best_validation_score,
|
||||
early_stopped=result.early_stopped,
|
||||
)
|
||||
|
||||
# 5. Display results
|
||||
_display_result(result)
|
||||
|
||||
# 6. Optional ground-truth evaluation on the optimized prompt
|
||||
if config.eval_dataset_path:
|
||||
dataset = FileDatasetLoader().load(config.eval_dataset_path)
|
||||
if config.eval_metric == "llm_judge":
|
||||
# llm_judge reuses the existing PromptEvaluator with the LLM judge
|
||||
from prometheus.domain.entities import SyntheticExample
|
||||
synth_dataset = [
|
||||
SyntheticExample(input_text=ex.input_text, id=ex.id) for ex in dataset
|
||||
]
|
||||
gt_eval = PromptEvaluator(
|
||||
executor=llm_adapter,
|
||||
judge=judge_adapter,
|
||||
max_concurrency=config.max_concurrency,
|
||||
)
|
||||
with console.status("[bold green]Running ground-truth evaluation (llm_judge)..."):
|
||||
gt_result = await gt_eval.evaluate(
|
||||
prompt=Prompt(text=result.optimized_prompt),
|
||||
minibatch=synth_dataset,
|
||||
task_description=config.task_description,
|
||||
)
|
||||
else:
|
||||
gt_evaluator = GroundTruthEvaluator(
|
||||
executor=llm_adapter,
|
||||
similarity=create_similarity_adapter(config.eval_metric),
|
||||
max_concurrency=config.max_concurrency,
|
||||
)
|
||||
with console.status("[bold green]Running ground-truth evaluation..."):
|
||||
gt_result = await gt_evaluator.evaluate(
|
||||
prompt=Prompt(text=result.optimized_prompt),
|
||||
dataset=dataset,
|
||||
)
|
||||
_display_ground_truth(gt_result, config.eval_metric, len(dataset))
|
||||
|
||||
# 7. Save
|
||||
_save_result(persistence, output, result)
|
||||
console.print(f"\n[green]Results saved to {output}[/green]")
|
||||
|
||||
|
||||
def _display_result(result: OptimizationResult) -> None:
|
||||
"""Display a Rich summary in the terminal."""
|
||||
console.print()
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold green]Optimized Prompt[/bold green]\n\n{result.optimized_prompt}",
|
||||
title="Result",
|
||||
)
|
||||
)
|
||||
table = Table(title="Metrics")
|
||||
table.add_column("Metric", style="cyan")
|
||||
table.add_column("Value", style="bold")
|
||||
table.add_row("Initial Score", f"{result.initial_score:.2f}")
|
||||
table.add_row("Final Score", f"{result.final_score:.2f}")
|
||||
table.add_row("Improvement", f"{result.improvement:+.2f}")
|
||||
if result.best_validation_score is not None:
|
||||
table.add_row("Best Validation Score", f"{result.best_validation_score:.4f}")
|
||||
if result.early_stopped:
|
||||
table.add_row("Early Stopped", "[yellow]Yes[/yellow]")
|
||||
table.add_row("Iterations", str(result.iterations_used))
|
||||
table.add_row("LLM Calls", str(result.total_llm_calls))
|
||||
console.print(table)
|
||||
|
||||
|
||||
def _save_result(
|
||||
persistence: YamlPersistence,
|
||||
path: str,
|
||||
result: OptimizationResult,
|
||||
) -> None:
|
||||
"""Save the result as YAML."""
|
||||
persistence.write_result(path, asdict(result))
|
||||
|
||||
|
||||
def _display_ground_truth(
|
||||
result: EvalResult, metric: str, dataset_size: int
|
||||
) -> None:
|
||||
"""Display ground-truth evaluation results."""
|
||||
console.print()
|
||||
table = Table(title=f"Ground-Truth Evaluation (metric: {metric})")
|
||||
table.add_column("Metric", style="cyan")
|
||||
table.add_column("Value", style="bold")
|
||||
table.add_row("Dataset Size", str(dataset_size))
|
||||
table.add_row("Mean Score", f"{result.mean_score:.4f}")
|
||||
table.add_row("Total Score", f"{result.total_score:.4f}")
|
||||
exact_matches = sum(1 for s in result.scores if s >= 0.99)
|
||||
table.add_row("Exact Matches", f"{exact_matches}/{dataset_size}")
|
||||
table.add_row("Accuracy", f"{exact_matches / dataset_size:.2%}")
|
||||
console.print(table)
|
||||
18
src/prometheus/cli/commands/version.py
Normal file
18
src/prometheus/cli/commands/version.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""prometheus version — print the current version."""
|
||||
from __future__ import annotations
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
|
||||
from prometheus import __version__
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def register(app: typer.Typer) -> None:
|
||||
"""Register the version command on the Typer app."""
|
||||
|
||||
@app.command()
|
||||
def version() -> None:
|
||||
"""Print the PROMETHEUS version."""
|
||||
console.print(f"PROMETHEUS {__version__}")
|
||||
96
src/prometheus/cli/logging_setup.py
Normal file
96
src/prometheus/cli/logging_setup.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Structured logging configuration for PROMETHEUS.
|
||||
|
||||
Supports text (human-readable) and JSON (machine-parseable) output,
|
||||
configurable log levels, and optional file output.
|
||||
|
||||
Fixes Bug #4: verbose mode now reliably produces output by configuring
|
||||
handlers explicitly instead of relying on ``logging.basicConfig``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from typing import TextIO
|
||||
|
||||
|
||||
_PROMETHEUS_LOGGER = "prometheus"
|
||||
|
||||
|
||||
class _JsonFormatter(logging.Formatter):
|
||||
"""Emit one JSON object per log line."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
message = record.getMessage()
|
||||
payload: dict = {
|
||||
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": message,
|
||||
}
|
||||
# Merge any extra structured fields the caller attached.
|
||||
if hasattr(record, "structured"):
|
||||
payload["structured"] = record.structured # type: ignore[attr-defined]
|
||||
if record.exc_info and record.exc_info[1] is not None:
|
||||
payload["exception"] = self.formatException(record.exc_info)
|
||||
return json.dumps(payload, default=str)
|
||||
|
||||
|
||||
class _TextFormatter(logging.Formatter):
|
||||
"""Human-readable format with structured extras appended."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
base = super().format(record)
|
||||
if hasattr(record, "structured") and record.structured:
|
||||
extras = " ".join(f"{k}={v}" for k, v in record.structured.items())
|
||||
base = f"{base} {extras}"
|
||||
return base
|
||||
|
||||
|
||||
def configure_logging(
|
||||
*,
|
||||
level: int = logging.WARNING,
|
||||
log_format: str = "text",
|
||||
log_file: str | None = None,
|
||||
) -> None:
|
||||
"""Configure the prometheus root logger.
|
||||
|
||||
Args:
|
||||
level: Logging level (e.g. logging.DEBUG, logging.INFO).
|
||||
log_format: ``"text"`` for human-readable or ``"json"`` for
|
||||
machine-parseable output.
|
||||
log_file: Optional path to also write logs to a file.
|
||||
"""
|
||||
prom_logger = logging.getLogger(_PROMETHEUS_LOGGER)
|
||||
prom_logger.setLevel(level)
|
||||
# Remove any stale handlers so re-configuration is idempotent.
|
||||
prom_logger.handlers.clear()
|
||||
|
||||
if log_format == "json":
|
||||
fmt = _JsonFormatter()
|
||||
else:
|
||||
fmt = _TextFormatter(
|
||||
fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
# Console handler (stderr so it doesn't mix with Rich stdout)
|
||||
console_handler = logging.StreamHandler(sys.stderr)
|
||||
console_handler.setFormatter(fmt)
|
||||
prom_logger.addHandler(console_handler)
|
||||
|
||||
# Optional file handler
|
||||
if log_file:
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(fmt)
|
||||
prom_logger.addHandler(file_handler)
|
||||
|
||||
# Prevent propagation to root logger to avoid duplicate output
|
||||
prom_logger.propagate = False
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Return a child logger under the prometheus namespace."""
|
||||
return logging.getLogger(f"{_PROMETHEUS_LOGGER}.{name}")
|
||||
@@ -1,12 +1,2 @@
|
||||
"""Application settings."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppSettings:
|
||||
"""Non-sensitive settings, hardcoded for the MVP."""
|
||||
|
||||
app_name: str = "prometheus"
|
||||
version: str = "0.1.0"
|
||||
|
||||
@@ -31,6 +31,15 @@ class SyntheticExample:
|
||||
id: int = 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GroundTruthExample:
|
||||
"""A ground-truth evaluation example with a known-good expected output."""
|
||||
|
||||
input_text: str
|
||||
expected_output: str
|
||||
id: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trajectory:
|
||||
"""Execution trace of a prompt on an input.
|
||||
@@ -85,3 +94,6 @@ class OptimizationState:
|
||||
synthetic_pool: list[SyntheticExample] = field(default_factory=list)
|
||||
history: list[dict[str, Any]] = field(default_factory=list)
|
||||
total_llm_calls: int = 0
|
||||
# Hold-out validation
|
||||
best_validation_score: float | None = None
|
||||
early_stopped: bool = False
|
||||
|
||||
@@ -8,7 +8,14 @@ from abc import ABC, abstractmethod
|
||||
|
||||
from typing import Any
|
||||
|
||||
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
GroundTruthExample,
|
||||
OptimizationState,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
Trajectory,
|
||||
)
|
||||
|
||||
|
||||
class LLMPort(ABC):
|
||||
@@ -18,7 +25,7 @@ class LLMPort(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, prompt: Prompt, input_text: str) -> str:
|
||||
async def execute(self, prompt: Prompt, input_text: str) -> str:
|
||||
"""Execute the prompt on the input, return the raw response."""
|
||||
...
|
||||
|
||||
@@ -31,7 +38,7 @@ class JudgePort(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def judge_batch(
|
||||
async def judge_batch(
|
||||
self,
|
||||
task_description: str,
|
||||
pairs: list[tuple[str, str]],
|
||||
@@ -50,7 +57,7 @@ class ProposerPort(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def propose(
|
||||
async def propose(
|
||||
self,
|
||||
current_prompt: Prompt,
|
||||
trajectories: list[Trajectory],
|
||||
@@ -73,6 +80,34 @@ class SyntheticGeneratorPort(ABC):
|
||||
...
|
||||
|
||||
|
||||
class CrossoverPort(ABC):
|
||||
"""Port for crossover — combining instructions from two parent candidates."""
|
||||
|
||||
@abstractmethod
|
||||
async def crossover(
|
||||
self,
|
||||
parent_a: Prompt,
|
||||
parent_b: Prompt,
|
||||
task_description: str,
|
||||
) -> Prompt:
|
||||
"""Combine instructions from two parents into a child prompt."""
|
||||
...
|
||||
|
||||
|
||||
class MutationPort(ABC):
|
||||
"""Port for mutating a prompt — paraphrase, constrain, generalize, specialize."""
|
||||
|
||||
@abstractmethod
|
||||
async def mutate(
|
||||
self,
|
||||
prompt: Prompt,
|
||||
task_description: str,
|
||||
mutation_type: str = "paraphrase",
|
||||
) -> Prompt:
|
||||
"""Apply a mutation to the prompt."""
|
||||
...
|
||||
|
||||
|
||||
class PersistencePort(ABC):
|
||||
"""Port for reading/writing files."""
|
||||
|
||||
@@ -83,3 +118,49 @@ class PersistencePort(ABC):
|
||||
@abstractmethod
|
||||
def write_result(self, path: str, data: dict[str, Any]) -> None:
|
||||
...
|
||||
|
||||
|
||||
class SimilarityPort(ABC):
|
||||
"""Port for computing similarity between a prediction and expected output.
|
||||
|
||||
Infrastructure provides concrete metrics (exact match, BLEU, ROUGE, cosine).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, prediction: str, expected: str) -> float:
|
||||
"""Compute similarity score in [0, 1]. 1.0 = perfect match."""
|
||||
...
|
||||
|
||||
|
||||
class DatasetLoaderPort(ABC):
|
||||
"""Port for loading ground-truth evaluation datasets."""
|
||||
|
||||
@abstractmethod
|
||||
def load(self, path: str) -> list[GroundTruthExample]:
|
||||
"""Load a dataset from a CSV or JSON file.
|
||||
|
||||
Each row must have 'input' and 'expected_output' fields.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class CheckpointPort(ABC):
|
||||
"""Port for saving and loading optimization checkpoints.
|
||||
|
||||
Enables resuming long-running optimizations after interruption.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, state: OptimizationState) -> None:
|
||||
"""Persist the current optimization state to disk."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> OptimizationState | None:
|
||||
"""Load the latest checkpoint. Returns None if no checkpoint exists."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def latest_exists(self) -> bool:
|
||||
"""Check if a checkpoint file is available for resuming."""
|
||||
...
|
||||
|
||||
@@ -19,3 +19,26 @@ def should_accept(
|
||||
def normalize_score(raw: float, min_val: float = 0.0, max_val: float = 1.0) -> float:
|
||||
"""Clamp a score within [min_val, max_val]."""
|
||||
return max(min_val, min(max_val, raw))
|
||||
|
||||
|
||||
def weighted_aggregate(
|
||||
scores: dict[str, float],
|
||||
weights: dict[str, float],
|
||||
) -> float:
|
||||
"""Compute a weighted average of per-dimension scores.
|
||||
|
||||
Args:
|
||||
scores: Mapping of dimension name → score (0.0–1.0).
|
||||
weights: Mapping of dimension name → weight (0.0–1.0).
|
||||
|
||||
Returns:
|
||||
Weighted average in [0.0, 1.0]. Returns 0.0 if inputs are empty.
|
||||
"""
|
||||
if not scores or not weights:
|
||||
return 0.0
|
||||
total_weight = sum(weights.get(name, 0.0) for name in scores)
|
||||
if total_weight == 0.0:
|
||||
return sum(scores.values()) / len(scores)
|
||||
return sum(
|
||||
scores.get(name, 0.0) * weights.get(name, 0.0) for name in scores
|
||||
) / total_weight
|
||||
|
||||
149
src/prometheus/infrastructure/checkpoint.py
Normal file
149
src/prometheus/infrastructure/checkpoint.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
JSON checkpoint persistence — save/load optimization state to disk.
|
||||
|
||||
Implements the CheckpointPort with JSON for human-readable, versionable snapshots.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
from prometheus.cli.logging_setup import get_logger
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
OptimizationState,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
)
|
||||
from prometheus.domain.ports import CheckpointPort
|
||||
|
||||
logger = get_logger("checkpoint")
|
||||
|
||||
_CHECKPOINT_FILE = "latest.json"
|
||||
|
||||
|
||||
class JsonCheckpointPersistence(CheckpointPort):
|
||||
"""Saves optimization state as JSON to a configurable directory."""
|
||||
|
||||
def __init__(self, checkpoint_dir: str | Path = ".prometheus/checkpoints") -> None:
|
||||
self._dir = Path(checkpoint_dir)
|
||||
|
||||
def save(self, state: OptimizationState) -> None:
|
||||
"""Persist the current optimization state to disk."""
|
||||
self._dir.mkdir(parents=True, exist_ok=True)
|
||||
path = self._dir / _CHECKPOINT_FILE
|
||||
data = _serialize_state(state)
|
||||
path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||
logger.info(
|
||||
"Checkpoint saved",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "checkpoint_saved",
|
||||
"path": str(path),
|
||||
"iteration": state.iteration,
|
||||
"total_llm_calls": state.total_llm_calls,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def load(self) -> OptimizationState | None:
|
||||
"""Load the latest checkpoint. Returns None if no checkpoint exists."""
|
||||
path = self._dir / _CHECKPOINT_FILE
|
||||
if not path.exists():
|
||||
return None
|
||||
raw = json.loads(path.read_text(encoding="utf-8"))
|
||||
state = _deserialize_state(raw)
|
||||
logger.info(
|
||||
"Checkpoint loaded",
|
||||
extra={
|
||||
"structured": {
|
||||
"event": "checkpoint_loaded",
|
||||
"path": str(path),
|
||||
"iteration": state.iteration,
|
||||
"total_llm_calls": state.total_llm_calls,
|
||||
},
|
||||
},
|
||||
)
|
||||
return state
|
||||
|
||||
def latest_exists(self) -> bool:
|
||||
"""Check if a checkpoint file is available for resuming."""
|
||||
return (self._dir / _CHECKPOINT_FILE).exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization helpers — keep the JSON format stable and self-describing.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
def _serialize_state(state: OptimizationState) -> dict:
|
||||
"""Convert OptimizationState to a JSON-safe dict."""
|
||||
return {
|
||||
"schema_version": _SCHEMA_VERSION,
|
||||
"iteration": state.iteration,
|
||||
"best_candidate": _serialize_candidate(state.best_candidate),
|
||||
"candidates": [_serialize_candidate(c) for c in state.candidates],
|
||||
"synthetic_pool": [
|
||||
{"input_text": ex.input_text, "category": ex.category, "id": ex.id}
|
||||
for ex in state.synthetic_pool
|
||||
],
|
||||
"history": state.history,
|
||||
"total_llm_calls": state.total_llm_calls,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_candidate(candidate: Candidate | None) -> dict | None:
|
||||
if candidate is None:
|
||||
return None
|
||||
return {
|
||||
"prompt_text": candidate.prompt.text,
|
||||
"prompt_metadata": candidate.prompt.metadata,
|
||||
"best_score": candidate.best_score,
|
||||
"generation": candidate.generation,
|
||||
"parent_id": candidate.parent_id,
|
||||
}
|
||||
|
||||
|
||||
def _deserialize_state(data: dict) -> OptimizationState:
|
||||
"""Reconstruct OptimizationState from a checkpoint dict."""
|
||||
version = data.get("schema_version", 1)
|
||||
# Future migration hooks go here: if version < 2: ...
|
||||
|
||||
best_raw = data.get("best_candidate")
|
||||
best_candidate = _deserialize_candidate(best_raw)
|
||||
|
||||
candidates = [_deserialize_candidate(c) for c in data.get("candidates", [])]
|
||||
|
||||
synthetic_pool = [
|
||||
SyntheticExample(
|
||||
input_text=ex["input_text"],
|
||||
category=ex.get("category", "default"),
|
||||
id=ex.get("id", 0),
|
||||
)
|
||||
for ex in data.get("synthetic_pool", [])
|
||||
]
|
||||
|
||||
state = OptimizationState(
|
||||
iteration=data.get("iteration", 0),
|
||||
best_candidate=best_candidate,
|
||||
candidates=candidates,
|
||||
synthetic_pool=synthetic_pool,
|
||||
history=data.get("history", []),
|
||||
total_llm_calls=data.get("total_llm_calls", 0),
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
def _deserialize_candidate(raw: dict | None) -> Candidate | None:
|
||||
if raw is None:
|
||||
return None
|
||||
return Candidate(
|
||||
prompt=Prompt(text=raw["prompt_text"], metadata=raw.get("prompt_metadata", {})),
|
||||
best_score=raw.get("best_score", 0.0),
|
||||
generation=raw.get("generation", 0),
|
||||
parent_id=raw.get("parent_id"),
|
||||
)
|
||||
63
src/prometheus/infrastructure/crossover_adapter.py
Normal file
63
src/prometheus/infrastructure/crossover_adapter.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Adapter: Instruction Crossover via DSPy.
|
||||
|
||||
Implements CrossoverPort — combines two parent prompts into a child.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import dspy
|
||||
|
||||
from prometheus.domain.entities import Prompt
|
||||
from prometheus.domain.ports import CrossoverPort
|
||||
from prometheus.infrastructure.dspy_modules import InstructionCrossover
|
||||
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||
|
||||
|
||||
class DSPyCrossoverAdapter(CrossoverPort):
|
||||
"""Uses DSPy to combine two parent instructions into a child."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lm: dspy.LM,
|
||||
max_retries: int = 3,
|
||||
retry_delay_base: float = 1.0,
|
||||
) -> None:
|
||||
self._lm = lm
|
||||
self._crossover = InstructionCrossover()
|
||||
self._max_retries = max_retries
|
||||
self._retry_delay_base = retry_delay_base
|
||||
self.call_count: int = 0
|
||||
|
||||
async def crossover(
|
||||
self,
|
||||
parent_a: Prompt,
|
||||
parent_b: Prompt,
|
||||
task_description: str,
|
||||
) -> Prompt:
|
||||
async def _call() -> Prompt:
|
||||
return await asyncio.to_thread(
|
||||
self._sync_crossover, parent_a, parent_b, task_description,
|
||||
)
|
||||
|
||||
return await async_retry_with_backoff(
|
||||
_call,
|
||||
max_retries=self._max_retries,
|
||||
retry_delay_base=self._retry_delay_base,
|
||||
)
|
||||
|
||||
def _sync_crossover(
|
||||
self,
|
||||
parent_a: Prompt,
|
||||
parent_b: Prompt,
|
||||
task_description: str,
|
||||
) -> Prompt:
|
||||
with dspy.context(lm=self._lm):
|
||||
pred = self._crossover(
|
||||
parent_a=parent_a.text,
|
||||
parent_b=parent_b.text,
|
||||
task_description=task_description,
|
||||
)
|
||||
self.call_count += 1
|
||||
return Prompt(text=pred.child_instruction)
|
||||
75
src/prometheus/infrastructure/dataset_loader.py
Normal file
75
src/prometheus/infrastructure/dataset_loader.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Dataset loader — loads ground-truth CSV/JSON datasets."""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from prometheus.domain.entities import GroundTruthExample
|
||||
from prometheus.domain.ports import DatasetLoaderPort
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileDatasetLoader(DatasetLoaderPort):
|
||||
"""Loads evaluation datasets from CSV or JSON files.
|
||||
|
||||
CSV files must have 'input' and 'expected_output' columns.
|
||||
JSON files must be an array of objects with 'input' and 'expected_output' keys.
|
||||
"""
|
||||
|
||||
def load(self, path: str) -> list[GroundTruthExample]:
|
||||
suffix = Path(path).suffix.lower()
|
||||
if suffix == ".csv":
|
||||
return self._load_csv(path)
|
||||
elif suffix in (".json", ".jsonl"):
|
||||
return self._load_json(path)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported dataset format '{suffix}'. Use .csv, .json, or .jsonl."
|
||||
)
|
||||
|
||||
def _load_csv(self, path: str) -> list[GroundTruthExample]:
|
||||
examples: list[GroundTruthExample] = []
|
||||
with open(path, newline="", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for i, row in enumerate(reader):
|
||||
input_text = row.get("input", "").strip()
|
||||
expected = row.get("expected_output", "").strip()
|
||||
if not input_text:
|
||||
logger.warning("Skipping CSV row %d: empty 'input' field", i + 1)
|
||||
continue
|
||||
examples.append(
|
||||
GroundTruthExample(
|
||||
input_text=input_text,
|
||||
expected_output=expected,
|
||||
id=i,
|
||||
)
|
||||
)
|
||||
logger.info("Loaded %d examples from CSV: %s", len(examples), path)
|
||||
return examples
|
||||
|
||||
def _load_json(self, path: str) -> list[GroundTruthExample]:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("JSON dataset must be an array of objects.")
|
||||
examples: list[GroundTruthExample] = []
|
||||
for i, item in enumerate(data):
|
||||
input_text = item.get("input", "").strip() if isinstance(item, dict) else ""
|
||||
expected = (
|
||||
item.get("expected_output", "").strip() if isinstance(item, dict) else ""
|
||||
)
|
||||
if not input_text:
|
||||
logger.warning("Skipping JSON item %d: empty 'input' field", i)
|
||||
continue
|
||||
examples.append(
|
||||
GroundTruthExample(
|
||||
input_text=input_text,
|
||||
expected_output=expected,
|
||||
id=i,
|
||||
)
|
||||
)
|
||||
logger.info("Loaded %d examples from JSON: %s", len(examples), path)
|
||||
return examples
|
||||
@@ -11,8 +11,10 @@ import re
|
||||
import dspy
|
||||
|
||||
from prometheus.infrastructure.dspy_signatures import (
|
||||
CrossoverInstructions,
|
||||
GenerateSyntheticInputs,
|
||||
JudgeOutput,
|
||||
MutateInstruction,
|
||||
ProposeInstruction,
|
||||
)
|
||||
|
||||
@@ -53,19 +55,30 @@ class OutputJudge(dspy.Module):
|
||||
self.judge = dspy.ChainOfThought(JudgeOutput)
|
||||
|
||||
def forward(
|
||||
self, task_description: str, input_text: str, output_text: str
|
||||
self,
|
||||
task_description: str,
|
||||
input_text: str,
|
||||
output_text: str,
|
||||
judge_criteria: str = "",
|
||||
dimension_names: str = "",
|
||||
) -> dspy.Prediction:
|
||||
result = self.judge(
|
||||
task_description=task_description,
|
||||
input_text=input_text,
|
||||
output_text=output_text,
|
||||
judge_criteria=judge_criteria,
|
||||
dimension_names=dimension_names,
|
||||
)
|
||||
try:
|
||||
score = float(result.score)
|
||||
except (ValueError, TypeError):
|
||||
score = 0.5 # neutral fallback
|
||||
score = max(0.0, min(1.0, score))
|
||||
return dspy.Prediction(score=score, feedback=result.feedback)
|
||||
return dspy.Prediction(
|
||||
score=score,
|
||||
feedback=result.feedback,
|
||||
dimension_scores=getattr(result, "dimension_scores", "{}"),
|
||||
)
|
||||
|
||||
|
||||
class InstructionProposer(dspy.Module):
|
||||
@@ -90,3 +103,45 @@ class InstructionProposer(dspy.Module):
|
||||
failure_examples=failure_examples,
|
||||
)
|
||||
return dspy.Prediction(new_instruction=result.new_instruction)
|
||||
|
||||
|
||||
class InstructionCrossover(dspy.Module):
|
||||
"""Crossover: combines two parent instructions into a child."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.crossover = dspy.ChainOfThought(CrossoverInstructions)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
parent_a: str,
|
||||
parent_b: str,
|
||||
task_description: str,
|
||||
) -> dspy.Prediction:
|
||||
result = self.crossover(
|
||||
parent_a=parent_a,
|
||||
parent_b=parent_b,
|
||||
task_description=task_description,
|
||||
)
|
||||
return dspy.Prediction(child_instruction=result.child_instruction)
|
||||
|
||||
|
||||
class InstructionMutator(dspy.Module):
|
||||
"""Mutator: applies a typed mutation to an instruction."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.mutate = dspy.ChainOfThought(MutateInstruction)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
current_instruction: str,
|
||||
task_description: str,
|
||||
mutation_type: str,
|
||||
) -> dspy.Prediction:
|
||||
result = self.mutate(
|
||||
current_instruction=current_instruction,
|
||||
task_description=task_description,
|
||||
mutation_type=mutation_type,
|
||||
)
|
||||
return dspy.Prediction(mutated_instruction=result.mutated_instruction)
|
||||
|
||||
@@ -44,6 +44,12 @@ class JudgeOutput(dspy.Signature):
|
||||
output_text: str = dspy.InputField(
|
||||
desc="The assistant's response to evaluate."
|
||||
)
|
||||
judge_criteria: str = dspy.InputField(
|
||||
desc="Custom evaluation rubric or criteria. Empty string = use default judging criteria."
|
||||
)
|
||||
dimension_names: str = dspy.InputField(
|
||||
desc="Comma-separated dimension names for multi-dimensional scoring. Empty string = single overall score."
|
||||
)
|
||||
score: float = dspy.OutputField(
|
||||
desc="Quality score from 0.0 (wrong) to 1.0 (perfect)."
|
||||
)
|
||||
@@ -53,6 +59,12 @@ class JudgeOutput(dspy.Signature):
|
||||
"with the output and how to improve it. Be critical."
|
||||
),
|
||||
)
|
||||
dimension_scores: str = dspy.OutputField(
|
||||
desc=(
|
||||
"JSON object mapping dimension names to scores (0.0-1.0). "
|
||||
'Empty object {} if no dimensions specified.'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ProposeInstruction(dspy.Signature):
|
||||
@@ -77,3 +89,52 @@ class ProposeInstruction(dspy.Signature):
|
||||
new_instruction: str = dspy.OutputField(
|
||||
desc="An improved version of the instruction."
|
||||
)
|
||||
|
||||
|
||||
class CrossoverInstructions(dspy.Signature):
|
||||
"""Combine two instruction prompts into a single improved instruction.
|
||||
|
||||
Take the strongest elements from each parent — structure, phrasing,
|
||||
constraints, examples — and merge them into a coherent child instruction
|
||||
that is strictly better than either parent alone.
|
||||
"""
|
||||
|
||||
parent_a: str = dspy.InputField(
|
||||
desc="First parent instruction."
|
||||
)
|
||||
parent_b: str = dspy.InputField(
|
||||
desc="Second parent instruction."
|
||||
)
|
||||
task_description: str = dspy.InputField(
|
||||
desc="Description of the task."
|
||||
)
|
||||
child_instruction: str = dspy.OutputField(
|
||||
desc=(
|
||||
"A combined instruction that takes the best elements from "
|
||||
"both parents into a single, coherent instruction."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class MutateInstruction(dspy.Signature):
|
||||
"""Apply a specific mutation to an instruction prompt.
|
||||
|
||||
The mutation_type determines the transformation:
|
||||
- paraphrase: restate the instruction in different words
|
||||
- constrain: add specificity, constraints, or guard-rails
|
||||
- generalize: broaden the instruction to cover more cases
|
||||
- specialize: narrow the instruction for better focus on the task
|
||||
"""
|
||||
|
||||
current_instruction: str = dspy.InputField(
|
||||
desc="The instruction to mutate."
|
||||
)
|
||||
task_description: str = dspy.InputField(
|
||||
desc="Description of the task."
|
||||
)
|
||||
mutation_type: str = dspy.InputField(
|
||||
desc="Type of mutation: paraphrase, constrain, generalize, or specialize."
|
||||
)
|
||||
mutated_instruction: str = dspy.OutputField(
|
||||
desc="The mutated instruction, preserving core intent but altered per the mutation type."
|
||||
)
|
||||
|
||||
@@ -5,30 +5,141 @@ Implements the JudgePort via the DSPy OutputJudge module.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import dspy
|
||||
|
||||
from prometheus.application.dto import JudgeDimension
|
||||
from prometheus.domain.ports import JudgePort
|
||||
from prometheus.domain.scoring import weighted_aggregate
|
||||
from prometheus.infrastructure.dspy_modules import OutputJudge
|
||||
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DSPyJudgeAdapter(JudgePort):
|
||||
"""Evaluates a batch of (input, output) pairs by calling the Judge for each.
|
||||
|
||||
Sequential for MVP. Future: parallelize via dspy.Parallel.
|
||||
Per-call isolation: a failure on one item returns a zero-score sentinel
|
||||
instead of crashing the whole batch.
|
||||
|
||||
Judge calls run in parallel (bounded by *max_concurrency*).
|
||||
|
||||
When *judge_criteria* or *judge_dimensions* are provided, the judge applies
|
||||
custom rubrics and/or multi-dimensional scoring with weighted aggregation.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
lm: dspy.LM,
|
||||
max_retries: int = 3,
|
||||
retry_delay_base: float = 1.0,
|
||||
max_concurrency: int = 5,
|
||||
judge_criteria: str | None = None,
|
||||
judge_dimensions: list[JudgeDimension] | None = None,
|
||||
) -> None:
|
||||
self._lm = lm
|
||||
self._judge = OutputJudge()
|
||||
self._max_retries = max_retries
|
||||
self._retry_delay_base = retry_delay_base
|
||||
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||
self._judge_criteria = judge_criteria or ""
|
||||
self._judge_dimensions = judge_dimensions or []
|
||||
self._dimension_names = (
|
||||
",".join(d.name for d in self._judge_dimensions)
|
||||
if self._judge_dimensions
|
||||
else ""
|
||||
)
|
||||
self._weights: dict[str, float] = (
|
||||
{d.name: d.weight for d in self._judge_dimensions}
|
||||
if self._judge_dimensions
|
||||
else {}
|
||||
)
|
||||
self.call_count: int = 0
|
||||
|
||||
def judge_batch(
|
||||
async def judge_batch(
|
||||
self,
|
||||
task_description: str,
|
||||
pairs: list[tuple[str, str]],
|
||||
) -> list[tuple[float, str]]:
|
||||
results: list[tuple[float, str]] = []
|
||||
for input_text, output_text in pairs:
|
||||
pred = self._judge(
|
||||
tasks = [
|
||||
self._judge_single_safe(task_description, input_text, output_text)
|
||||
for input_text, output_text in pairs
|
||||
]
|
||||
return list(await asyncio.gather(*tasks))
|
||||
|
||||
async def _judge_single_safe(
|
||||
self,
|
||||
task_description: str,
|
||||
input_text: str,
|
||||
output_text: str,
|
||||
) -> tuple[float, str]:
|
||||
async with self._semaphore:
|
||||
try:
|
||||
return await self._judge_single(task_description, input_text, output_text)
|
||||
except Exception as exc:
|
||||
logger.warning("Judge call failed for input '%s…': %s", input_text[:40], exc)
|
||||
return (0.0, f"[judge error: {exc}]")
|
||||
|
||||
async def _judge_single(
|
||||
self,
|
||||
task_description: str,
|
||||
input_text: str,
|
||||
output_text: str,
|
||||
) -> tuple[float, str]:
|
||||
async def _call() -> tuple[float, str]:
|
||||
pred = await asyncio.to_thread(
|
||||
self._sync_judge, task_description, input_text, output_text,
|
||||
)
|
||||
return self._aggregate_result(pred)
|
||||
|
||||
return await async_retry_with_backoff(
|
||||
_call,
|
||||
max_retries=self._max_retries,
|
||||
retry_delay_base=self._retry_delay_base,
|
||||
)
|
||||
|
||||
def _sync_judge(self, task_description: str, input_text: str, output_text: str):
|
||||
with dspy.context(lm=self._lm):
|
||||
result = self._judge(
|
||||
task_description=task_description,
|
||||
input_text=input_text,
|
||||
output_text=output_text,
|
||||
judge_criteria=self._judge_criteria,
|
||||
dimension_names=self._dimension_names,
|
||||
)
|
||||
results.append((pred.score, pred.feedback))
|
||||
return results
|
||||
self.call_count += 1
|
||||
return result
|
||||
|
||||
def _aggregate_result(self, pred: Any) -> tuple[float, str]:
|
||||
"""Compute weighted aggregate score from dimension scores if available."""
|
||||
if not self._judge_dimensions:
|
||||
return (pred.score, pred.feedback)
|
||||
|
||||
# Parse per-dimension scores from LLM output
|
||||
dim_scores: dict[str, float] = {}
|
||||
try:
|
||||
raw = json.loads(pred.dimension_scores)
|
||||
if isinstance(raw, dict):
|
||||
for name in self._weights:
|
||||
val = raw.get(name)
|
||||
if val is not None:
|
||||
dim_scores[name] = max(0.0, min(1.0, float(val)))
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
logger.debug("Failed to parse dimension_scores, falling back to overall score")
|
||||
|
||||
if not dim_scores:
|
||||
return (pred.score, pred.feedback)
|
||||
|
||||
aggregate = weighted_aggregate(dim_scores, self._weights)
|
||||
# Enrich feedback with per-dimension breakdown
|
||||
dim_breakdown = ", ".join(
|
||||
f"{name}={dim_scores.get(name, 0.0):.2f}"
|
||||
for name in self._weights
|
||||
)
|
||||
feedback = f"{pred.feedback} [{dim_breakdown}]"
|
||||
return (aggregate, feedback)
|
||||
|
||||
@@ -5,10 +5,13 @@ Implements the LLMPort via DSPy.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import dspy
|
||||
|
||||
from prometheus.domain.entities import Prompt
|
||||
from prometheus.domain.ports import LLMPort
|
||||
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||
|
||||
|
||||
class DSPyLLMAdapter(LLMPort):
|
||||
@@ -21,12 +24,34 @@ class DSPyLLMAdapter(LLMPort):
|
||||
input_text: str = dspy.InputField(desc="The input to process.")
|
||||
output: str = dspy.OutputField(desc="The response following the instruction.")
|
||||
|
||||
def __init__(self, model: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
lm: dspy.LM,
|
||||
max_retries: int = 3,
|
||||
retry_delay_base: float = 1.0,
|
||||
) -> None:
|
||||
self._lm = lm
|
||||
self._predictor = dspy.Predict(self._ExecuteSignature)
|
||||
self._max_retries = max_retries
|
||||
self._retry_delay_base = retry_delay_base
|
||||
self.call_count: int = 0
|
||||
|
||||
def execute(self, prompt: Prompt, input_text: str) -> str:
|
||||
result = self._predictor(
|
||||
instruction=prompt.text,
|
||||
input_text=input_text,
|
||||
async def execute(self, prompt: Prompt, input_text: str) -> str:
|
||||
async def _call() -> str:
|
||||
# DSPy is synchronous — run in a thread to avoid blocking the event loop.
|
||||
return await asyncio.to_thread(self._sync_execute, prompt, input_text)
|
||||
|
||||
return await async_retry_with_backoff(
|
||||
_call,
|
||||
max_retries=self._max_retries,
|
||||
retry_delay_base=self._retry_delay_base,
|
||||
)
|
||||
|
||||
def _sync_execute(self, prompt: Prompt, input_text: str) -> str:
|
||||
with dspy.context(lm=self._lm):
|
||||
result = self._predictor(
|
||||
instruction=prompt.text,
|
||||
input_text=input_text,
|
||||
)
|
||||
self.call_count += 1
|
||||
return str(result.output)
|
||||
|
||||
70
src/prometheus/infrastructure/mutation_adapter.py
Normal file
70
src/prometheus/infrastructure/mutation_adapter.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Adapter: Instruction Mutation via DSPy.
|
||||
|
||||
Implements MutationPort — applies typed mutations (paraphrase, constrain,
|
||||
generalize, specialize) to a prompt.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
import dspy
|
||||
|
||||
from prometheus.domain.entities import Prompt
|
||||
from prometheus.domain.ports import MutationPort
|
||||
from prometheus.infrastructure.dspy_modules import InstructionMutator
|
||||
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||
|
||||
_MUTATION_TYPES = ("paraphrase", "constrain", "generalize", "specialize")
|
||||
|
||||
|
||||
class DSPyMutationAdapter(MutationPort):
|
||||
"""Uses DSPy to apply typed mutations to an instruction."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lm: dspy.LM,
|
||||
max_retries: int = 3,
|
||||
retry_delay_base: float = 1.0,
|
||||
) -> None:
|
||||
self._lm = lm
|
||||
self._mutator = InstructionMutator()
|
||||
self._max_retries = max_retries
|
||||
self._retry_delay_base = retry_delay_base
|
||||
self.call_count: int = 0
|
||||
|
||||
async def mutate(
|
||||
self,
|
||||
prompt: Prompt,
|
||||
task_description: str,
|
||||
mutation_type: str = "paraphrase",
|
||||
) -> Prompt:
|
||||
if mutation_type not in _MUTATION_TYPES:
|
||||
mutation_type = random.choice(_MUTATION_TYPES)
|
||||
|
||||
async def _call() -> Prompt:
|
||||
return await asyncio.to_thread(
|
||||
self._sync_mutate, prompt, task_description, mutation_type,
|
||||
)
|
||||
|
||||
return await async_retry_with_backoff(
|
||||
_call,
|
||||
max_retries=self._max_retries,
|
||||
retry_delay_base=self._retry_delay_base,
|
||||
)
|
||||
|
||||
def _sync_mutate(
|
||||
self,
|
||||
prompt: Prompt,
|
||||
task_description: str,
|
||||
mutation_type: str,
|
||||
) -> Prompt:
|
||||
with dspy.context(lm=self._lm):
|
||||
pred = self._mutator(
|
||||
current_instruction=prompt.text,
|
||||
task_description=task_description,
|
||||
mutation_type=mutation_type,
|
||||
)
|
||||
self.call_count += 1
|
||||
return Prompt(text=pred.mutated_instruction)
|
||||
@@ -6,29 +6,58 @@ Converts trajectories into readable format for the LLM proposer.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import dspy
|
||||
|
||||
from prometheus.domain.entities import Prompt, Trajectory
|
||||
from prometheus.domain.ports import ProposerPort
|
||||
from prometheus.infrastructure.dspy_modules import InstructionProposer
|
||||
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||
|
||||
|
||||
class DSPyProposerAdapter(ProposerPort):
|
||||
"""Uses evaluation trajectories to build a failure report and propose a new prompt."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
lm: dspy.LM,
|
||||
max_retries: int = 3,
|
||||
retry_delay_base: float = 1.0,
|
||||
) -> None:
|
||||
self._lm = lm
|
||||
self._proposer = InstructionProposer()
|
||||
self._max_retries = max_retries
|
||||
self._retry_delay_base = retry_delay_base
|
||||
self.call_count: int = 0
|
||||
|
||||
def propose(
|
||||
async def propose(
|
||||
self,
|
||||
current_prompt: Prompt,
|
||||
trajectories: list[Trajectory],
|
||||
task_description: str,
|
||||
) -> Prompt:
|
||||
failure_examples = self._format_failures(trajectories)
|
||||
pred = self._proposer(
|
||||
current_instruction=current_prompt.text,
|
||||
task_description=task_description,
|
||||
failure_examples=failure_examples,
|
||||
|
||||
async def _call() -> Prompt:
|
||||
return await asyncio.to_thread(
|
||||
self._sync_propose, current_prompt, task_description, failure_examples,
|
||||
)
|
||||
|
||||
return await async_retry_with_backoff(
|
||||
_call,
|
||||
max_retries=self._max_retries,
|
||||
retry_delay_base=self._retry_delay_base,
|
||||
)
|
||||
|
||||
def _sync_propose(self, current_prompt: Prompt, task_description: str, failure_examples: str) -> Prompt:
|
||||
with dspy.context(lm=self._lm):
|
||||
pred = self._proposer(
|
||||
current_instruction=current_prompt.text,
|
||||
task_description=task_description,
|
||||
failure_examples=failure_examples,
|
||||
)
|
||||
self.call_count += 1
|
||||
return Prompt(text=pred.new_instruction)
|
||||
|
||||
@staticmethod
|
||||
|
||||
102
src/prometheus/infrastructure/retry.py
Normal file
102
src/prometheus/infrastructure/retry.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Retry with exponential backoff for transient LLM errors."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Status codes / keywords that indicate a transient error worth retrying.
|
||||
_TRANSIENT_PATTERNS = (
|
||||
"429",
|
||||
"rate limit",
|
||||
"rate_limit",
|
||||
"too many requests",
|
||||
"500",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
"timeout",
|
||||
"timed out",
|
||||
"connection error",
|
||||
"connection refused",
|
||||
"overloaded",
|
||||
)
|
||||
|
||||
|
||||
def is_transient_error(exc: Exception) -> bool:
|
||||
"""Return True if the exception looks like a transient LLM/API error."""
|
||||
msg = str(exc).lower()
|
||||
if any(p in msg for p in _TRANSIENT_PATTERNS):
|
||||
return True
|
||||
if isinstance(exc, (ConnectionError, TimeoutError, OSError)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class TransientError(RuntimeError):
|
||||
"""Raised when all retry attempts are exhausted for a transient error."""
|
||||
|
||||
|
||||
def retry_with_backoff(
|
||||
fn: Callable[..., T],
|
||||
*args: Any,
|
||||
max_retries: int = 3,
|
||||
retry_delay_base: float = 1.0,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
"""Call *fn* with exponential-backoff retry on transient errors.
|
||||
|
||||
Delay per attempt: ``retry_delay_base * 2 ** attempt`` seconds.
|
||||
"""
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if not is_transient_error(exc) or attempt == max_retries:
|
||||
raise
|
||||
delay = retry_delay_base * (2 ** attempt)
|
||||
logger.warning(
|
||||
"Transient error (attempt %d/%d): %s — retrying in %.1fs",
|
||||
attempt + 1,
|
||||
max_retries + 1,
|
||||
exc,
|
||||
delay,
|
||||
)
|
||||
time.sleep(delay)
|
||||
# Should not reach here, but satisfy type-checker.
|
||||
raise TransientError(str(last_exc)) from last_exc
|
||||
|
||||
|
||||
async def async_retry_with_backoff(
|
||||
fn: Callable[..., Coroutine[Any, Any, T]],
|
||||
*args: Any,
|
||||
max_retries: int = 3,
|
||||
retry_delay_base: float = 1.0,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
"""Async version of retry_with_backoff — uses asyncio.sleep instead of time.sleep."""
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return await fn(*args, **kwargs)
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if not is_transient_error(exc) or attempt == max_retries:
|
||||
raise
|
||||
delay = retry_delay_base * (2 ** attempt)
|
||||
logger.warning(
|
||||
"Transient error (attempt %d/%d): %s — retrying in %.1fs",
|
||||
attempt + 1,
|
||||
max_retries + 1,
|
||||
exc,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
raise TransientError(str(last_exc)) from last_exc
|
||||
153
src/prometheus/infrastructure/similarity.py
Normal file
153
src/prometheus/infrastructure/similarity.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Similarity adapters — concrete metrics for comparing prediction vs expected."""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
from prometheus.domain.ports import SimilarityPort
|
||||
|
||||
|
||||
class ExactMatchSimilarity(SimilarityPort):
|
||||
"""Case-insensitive exact string match. Returns 1.0 or 0.0."""
|
||||
|
||||
def compute(self, prediction: str, expected: str) -> float:
|
||||
return 1.0 if prediction.strip().lower() == expected.strip().lower() else 0.0
|
||||
|
||||
|
||||
class BleuSimilarity(SimilarityPort):
|
||||
"""BLEU-style n-gram precision (up to 4-grams).
|
||||
|
||||
Simplified implementation using sentence-level BLEU with brevity penalty.
|
||||
Returns a score in [0, 1].
|
||||
"""
|
||||
|
||||
def __init__(self, max_n: int = 4):
|
||||
self._max_n = max_n
|
||||
|
||||
def compute(self, prediction: str, expected: str) -> float:
|
||||
pred_tokens = _tokenize(prediction)
|
||||
ref_tokens = _tokenize(expected)
|
||||
if not pred_tokens or not ref_tokens:
|
||||
return 0.0 if not ref_tokens else 0.0
|
||||
|
||||
# Modified precision for each n-gram
|
||||
precisions: list[float] = []
|
||||
for n in range(1, self._max_n + 1):
|
||||
pred_ngrams = _ngrams(pred_tokens, n)
|
||||
ref_ngrams = _ngrams(ref_tokens, n)
|
||||
if not pred_ngrams:
|
||||
break
|
||||
clipped = sum(min(pred_ngrams[ng], ref_ngrams.get(ng, 0)) for ng in pred_ngrams)
|
||||
total = sum(pred_ngrams.values())
|
||||
precisions.append(clipped / total if total > 0 else 0.0)
|
||||
|
||||
if not precisions:
|
||||
return 0.0
|
||||
|
||||
# Geometric mean of precisions
|
||||
log_avg = sum(math.log(p) for p in precisions if p > 0)
|
||||
n_nonzero = sum(1 for p in precisions if p > 0)
|
||||
if n_nonzero == 0:
|
||||
return 0.0
|
||||
geo_mean = math.exp(log_avg / n_nonzero)
|
||||
|
||||
# Brevity penalty
|
||||
bp = 1.0
|
||||
if len(pred_tokens) < len(ref_tokens):
|
||||
bp = math.exp(1 - len(ref_tokens) / len(pred_tokens))
|
||||
|
||||
return min(bp * geo_mean, 1.0)
|
||||
|
||||
|
||||
class RougeLSimilarity(SimilarityPort):
|
||||
"""ROUGE-L using Longest Common Subsequence.
|
||||
|
||||
Returns F1 score combining precision and recall in [0, 1].
|
||||
"""
|
||||
|
||||
def compute(self, prediction: str, expected: str) -> float:
|
||||
pred_tokens = _tokenize(prediction)
|
||||
ref_tokens = _tokenize(expected)
|
||||
if not pred_tokens or not ref_tokens:
|
||||
return 0.0
|
||||
|
||||
lcs_len = _lcs_length(pred_tokens, ref_tokens)
|
||||
precision = lcs_len / len(pred_tokens)
|
||||
recall = lcs_len / len(ref_tokens)
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
f1 = 2 * precision * recall / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
class CosineSimilarity(SimilarityPort):
|
||||
"""TF-IDF cosine similarity between bag-of-words vectors.
|
||||
|
||||
Lightweight semantic similarity without external embedding models.
|
||||
"""
|
||||
|
||||
def compute(self, prediction: str, expected: str) -> float:
|
||||
pred_tokens = _tokenize(prediction)
|
||||
ref_tokens = _tokenize(expected)
|
||||
if not pred_tokens or not ref_tokens:
|
||||
return 0.0
|
||||
|
||||
pred_counts = Counter(pred_tokens)
|
||||
ref_counts = Counter(ref_tokens)
|
||||
all_tokens = set(pred_counts) | set(ref_counts)
|
||||
|
||||
dot = sum(pred_counts.get(t, 0) * ref_counts.get(t, 0) for t in all_tokens)
|
||||
norm_pred = math.sqrt(sum(v * v for v in pred_counts.values()))
|
||||
norm_ref = math.sqrt(sum(v * v for v in ref_counts.values()))
|
||||
|
||||
if norm_pred == 0 or norm_ref == 0:
|
||||
return 0.0
|
||||
return dot / (norm_pred * norm_ref)
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
|
||||
def _tokenize(text: str) -> list[str]:
|
||||
"""Simple whitespace + punctuation tokenizer."""
|
||||
return re.findall(r"\w+", text.lower())
|
||||
|
||||
|
||||
def _ngrams(tokens: list[str], n: int) -> Counter:
|
||||
"""Count n-grams in a token list."""
|
||||
return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))
|
||||
|
||||
|
||||
def _lcs_length(a: list[str], b: list[str]) -> int:
|
||||
"""Compute length of the Longest Common Subsequence."""
|
||||
m, n = len(a), len(b)
|
||||
prev = [0] * (n + 1)
|
||||
for i in range(1, m + 1):
|
||||
curr = [0] * (n + 1)
|
||||
for j in range(1, n + 1):
|
||||
if a[i - 1] == b[j - 1]:
|
||||
curr[j] = prev[j - 1] + 1
|
||||
else:
|
||||
curr[j] = max(prev[j], curr[j - 1])
|
||||
prev = curr
|
||||
return prev[n]
|
||||
|
||||
|
||||
def create_similarity_adapter(metric: str) -> SimilarityPort:
|
||||
"""Factory: create a SimilarityPort by metric name.
|
||||
|
||||
Supported metrics: exact, bleu, rouge_l, cosine.
|
||||
"""
|
||||
adapters = {
|
||||
"exact": ExactMatchSimilarity,
|
||||
"bleu": BleuSimilarity,
|
||||
"rouge_l": RougeLSimilarity,
|
||||
"cosine": CosineSimilarity,
|
||||
}
|
||||
cls = adapters.get(metric)
|
||||
if cls is None:
|
||||
raise ValueError(
|
||||
f"Unknown eval metric '{metric}'. Choose from: {sorted(adapters)}"
|
||||
)
|
||||
return cls()
|
||||
@@ -5,6 +5,8 @@ Implements the SyntheticGeneratorPort via DSPy.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import dspy
|
||||
|
||||
from prometheus.domain.entities import SyntheticExample
|
||||
from prometheus.domain.ports import SyntheticGeneratorPort
|
||||
from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator
|
||||
@@ -13,18 +15,22 @@ from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator
|
||||
class DSPySyntheticAdapter(SyntheticGeneratorPort):
|
||||
"""Generates synthetic inputs in a single batch call via DSPy."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, lm: dspy.LM) -> None:
|
||||
self._lm = lm
|
||||
self._generator = SyntheticInputGenerator()
|
||||
self.call_count: int = 0
|
||||
|
||||
def generate_inputs(
|
||||
self,
|
||||
task_description: str,
|
||||
n_examples: int,
|
||||
) -> list[SyntheticExample]:
|
||||
pred = self._generator(
|
||||
task_description=task_description,
|
||||
n_examples=n_examples,
|
||||
)
|
||||
with dspy.context(lm=self._lm):
|
||||
pred = self._generator(
|
||||
task_description=task_description,
|
||||
n_examples=n_examples,
|
||||
)
|
||||
self.call_count += 1
|
||||
return [
|
||||
SyntheticExample(
|
||||
input_text=text,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Shared test fixtures."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -66,17 +66,17 @@ def mock_eval_result() -> EvalResult:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_port() -> MagicMock:
|
||||
def mock_llm_port() -> AsyncMock:
|
||||
"""Mock LLMPort that returns canned responses."""
|
||||
port = MagicMock()
|
||||
port = AsyncMock()
|
||||
port.execute.return_value = "This is a mock response."
|
||||
return port
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_judge_port() -> MagicMock:
|
||||
def mock_judge_port() -> AsyncMock:
|
||||
"""Mock JudgePort that returns moderate scores."""
|
||||
port = MagicMock()
|
||||
port = AsyncMock()
|
||||
port.judge_batch.return_value = [
|
||||
(0.5, "Moderate quality, needs improvement."),
|
||||
] * 5
|
||||
@@ -84,10 +84,34 @@ def mock_judge_port() -> MagicMock:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_proposer_port() -> MagicMock:
|
||||
def mock_proposer_port() -> AsyncMock:
|
||||
"""Mock ProposerPort that returns a slightly modified prompt."""
|
||||
port = MagicMock()
|
||||
port = AsyncMock()
|
||||
port.propose.return_value = Prompt(
|
||||
text="You are a very helpful assistant. Answer the question precisely."
|
||||
)
|
||||
return port
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_crossover_port() -> AsyncMock:
|
||||
"""Mock CrossoverPort that combines two parent prompts."""
|
||||
port = AsyncMock()
|
||||
|
||||
async def _crossover(parent_a: Prompt, parent_b: Prompt, task_description: str) -> Prompt:
|
||||
return Prompt(text=f"{parent_a.text} Also, {parent_b.text.lower()}")
|
||||
|
||||
port.crossover = AsyncMock(side_effect=_crossover)
|
||||
return port
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mutation_port() -> AsyncMock:
|
||||
"""Mock MutationPort that paraphrases a prompt."""
|
||||
port = AsyncMock()
|
||||
|
||||
async def _mutate(prompt: Prompt, task_description: str, mutation_type: str = "paraphrase") -> Prompt:
|
||||
return Prompt(text=f"[{mutation_type}] {prompt.text}")
|
||||
|
||||
port.mutate = AsyncMock(side_effect=_mutate)
|
||||
return port
|
||||
|
||||
@@ -16,14 +16,14 @@ def mock_lm() -> dspy.LM:
|
||||
{"output": "Mock output response"},
|
||||
]
|
||||
)
|
||||
dspy.configure(lm=lm)
|
||||
return lm
|
||||
|
||||
|
||||
class TestDSPyLLMAdapter:
|
||||
def test_execute_returns_response(self, mock_lm: dspy.LM) -> None:
|
||||
adapter = DSPyLLMAdapter(model="openai/gpt-4o-mini")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_response(self, mock_lm: dspy.LM) -> None:
|
||||
adapter = DSPyLLMAdapter(lm=mock_lm)
|
||||
prompt = Prompt(text="Answer the question.")
|
||||
result = adapter.execute(prompt, "What is 2+2?")
|
||||
result = await adapter.execute(prompt, "What is 2+2?")
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
||||
300
tests/integration/test_evolution_integration.py
Normal file
300
tests/integration/test_evolution_integration.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Integration tests for multi-iteration evolution with mixed accept/reject."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory
|
||||
from prometheus.domain.ports import JudgePort, LLMPort, ProposerPort
|
||||
|
||||
|
||||
def _make_eval(scores: list[float]) -> EvalResult:
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=["feedback"] * len(scores),
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", s, "feedback", "prompt")
|
||||
for i, s in enumerate(scores)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestMultiIterationEvolution:
|
||||
"""Tests for the evolution loop across multiple iterations."""
|
||||
|
||||
@pytest.fixture
|
||||
def seed_prompt(self) -> Prompt:
|
||||
return Prompt(text="You are a helpful assistant.")
|
||||
|
||||
@pytest.fixture
|
||||
def task_description(self) -> str:
|
||||
return "Answer factual questions."
|
||||
|
||||
@pytest.fixture
|
||||
def synthetic_pool(self) -> list[SyntheticExample]:
|
||||
return [SyntheticExample(input_text=f"input {i}", id=i) for i in range(20)]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_accept_reject(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
task_description: str,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
) -> None:
|
||||
"""Iteration 1: accept, iteration 2: reject, iteration 3: accept."""
|
||||
mock_llm = MagicMock(spec=LLMPort)
|
||||
mock_judge = MagicMock(spec=JudgePort)
|
||||
mock_proposer = MagicMock(spec=ProposerPort)
|
||||
|
||||
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:3]
|
||||
|
||||
# Build eval sequence: initial, then per-iteration (current, new)
|
||||
evals = [
|
||||
_make_eval([0.3, 0.3, 0.3]), # initial seed eval
|
||||
# Iter 1: accept (old=0.4, new=0.8)
|
||||
_make_eval([0.4, 0.4, 0.4]),
|
||||
_make_eval([0.8, 0.8, 0.8]),
|
||||
# Iter 2: reject (old=0.7, new=0.2)
|
||||
_make_eval([0.7, 0.7, 0.7]),
|
||||
_make_eval([0.2, 0.2, 0.2]),
|
||||
# Iter 3: accept (old=0.5, new=0.9)
|
||||
_make_eval([0.5, 0.5, 0.5]),
|
||||
_make_eval([0.9, 0.9, 0.9]),
|
||||
]
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
|
||||
mock_proposer.propose.side_effect = [
|
||||
Prompt(text="Better prompt v1"),
|
||||
Prompt(text="Worse prompt v2"),
|
||||
Prompt(text="Best prompt v3"),
|
||||
]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=3,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert state.iteration == 3
|
||||
assert state.best_candidate is not None
|
||||
assert state.best_candidate.best_score == pytest.approx(2.7) # 0.9*3
|
||||
assert len(state.history) == 3
|
||||
assert state.history[0]["event"] == "accepted"
|
||||
assert state.history[1]["event"] == "rejected"
|
||||
assert state.history[2]["event"] == "accepted"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_rejected_keeps_seed(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
task_description: str,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
) -> None:
|
||||
"""When all proposals are rejected, the seed prompt stays as best."""
|
||||
mock_llm = MagicMock(spec=LLMPort)
|
||||
mock_judge = MagicMock(spec=JudgePort)
|
||||
mock_proposer = MagicMock(spec=ProposerPort)
|
||||
|
||||
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:3]
|
||||
|
||||
evals = [
|
||||
_make_eval([0.5, 0.5, 0.5]), # initial
|
||||
]
|
||||
for _ in range(3):
|
||||
evals.append(_make_eval([0.5, 0.5, 0.5])) # current
|
||||
evals.append(_make_eval([0.1, 0.1, 0.1])) # worse proposal
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
|
||||
mock_proposer.propose.side_effect = [
|
||||
Prompt(text="bad v1"),
|
||||
Prompt(text="bad v2"),
|
||||
Prompt(text="bad v3"),
|
||||
]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=3,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert state.best_candidate.prompt.text == seed_prompt.text
|
||||
assert state.best_candidate.best_score == pytest.approx(1.5) # 0.5*3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_accepted_chain(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
task_description: str,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
) -> None:
|
||||
"""All iterations accept, forming an improvement chain."""
|
||||
mock_llm = MagicMock(spec=LLMPort)
|
||||
mock_judge = MagicMock(spec=JudgePort)
|
||||
mock_proposer = MagicMock(spec=ProposerPort)
|
||||
|
||||
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:2]
|
||||
|
||||
evals = [
|
||||
_make_eval([0.2, 0.2]), # initial
|
||||
]
|
||||
for i in range(1, 5):
|
||||
score = 0.2 + i * 0.15
|
||||
evals.append(_make_eval([score, score])) # current
|
||||
evals.append(_make_eval([score + 0.1, score + 0.1])) # new (accepted)
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
|
||||
mock_proposer.propose.side_effect = [
|
||||
Prompt(text=f"Improved v{i}") for i in range(4)
|
||||
]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=4,
|
||||
minibatch_size=2,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert len(state.candidates) == 5 # seed + 4 accepted
|
||||
assert all(h["event"] == "accepted" for h in state.history)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_recovery_continues_loop(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
task_description: str,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
) -> None:
|
||||
"""When an iteration errors, the loop continues."""
|
||||
mock_llm = MagicMock(spec=LLMPort)
|
||||
mock_judge = MagicMock(spec=JudgePort)
|
||||
mock_proposer = MagicMock(spec=ProposerPort)
|
||||
|
||||
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:2]
|
||||
|
||||
# Eval sequence for 3 iterations:
|
||||
# - iter 1: evaluate current → propose → evaluate new (accepted)
|
||||
# - iter 2: evaluate current → propose (ERROR, no new eval)
|
||||
# - iter 3: evaluate current → propose → evaluate new (accepted)
|
||||
evals = [
|
||||
_make_eval([0.3, 0.3]), # initial
|
||||
_make_eval([0.5, 0.5]), # iter 1 current
|
||||
_make_eval([0.9, 0.9]), # iter 1 new (accepted)
|
||||
_make_eval([0.5, 0.5]), # iter 2 current (proposer errors after this)
|
||||
_make_eval([0.5, 0.5]), # iter 3 current
|
||||
_make_eval([0.8, 0.8]), # iter 3 new (accepted)
|
||||
]
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
|
||||
# Proposer raises on iter 2
|
||||
mock_proposer.propose.side_effect = [
|
||||
Prompt(text="good v1"),
|
||||
RuntimeError("LLM timeout"),
|
||||
Prompt(text="good v3"),
|
||||
]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=2,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert state.iteration == 3
|
||||
assert state.history[1]["event"] == "error"
|
||||
assert "LLM timeout" in state.history[1]["error"]
|
||||
assert state.history[0]["event"] == "accepted"
|
||||
assert state.history[2]["event"] == "accepted"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perfect_score_skips_proposer(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
task_description: str,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
) -> None:
|
||||
"""When all scores are perfect, no proposition is made."""
|
||||
mock_llm = MagicMock(spec=LLMPort)
|
||||
mock_judge = MagicMock(spec=JudgePort)
|
||||
mock_proposer = MagicMock(spec=ProposerPort)
|
||||
|
||||
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:2]
|
||||
|
||||
perfect_eval = _make_eval([1.0, 1.0])
|
||||
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=5,
|
||||
minibatch_size=2,
|
||||
perfect_score=1.0,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
mock_proposer.propose.assert_not_called()
|
||||
assert all(h["event"] == "skip_perfect" for h in state.history)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_call_counting(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
task_description: str,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
) -> None:
|
||||
"""Verify LLM call counting: 2*N per eval (execute + judge) + 1 per propose."""
|
||||
mock_llm = MagicMock(spec=LLMPort)
|
||||
mock_judge = MagicMock(spec=JudgePort)
|
||||
mock_proposer = MagicMock(spec=ProposerPort)
|
||||
|
||||
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:3]
|
||||
|
||||
evals = [_make_eval([0.3, 0.3, 0.3])] # initial
|
||||
for _ in range(2):
|
||||
evals.append(_make_eval([0.4, 0.4, 0.4]))
|
||||
evals.append(_make_eval([0.6, 0.6, 0.6]))
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
|
||||
mock_proposer.propose.side_effect = [
|
||||
Prompt(text="v1"),
|
||||
Prompt(text="v2"),
|
||||
]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=2,
|
||||
minibatch_size=3,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
# Initial: 2*3=6, Iter1: 2*3 + 1 + 2*3 = 13, Iter2: same = 13
|
||||
# Total: 6 + 13 + 13 = 32
|
||||
assert state.total_llm_calls == 32
|
||||
@@ -1,7 +1,9 @@
|
||||
"""End-to-end pipeline test with mocked LLM calls."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
@@ -23,9 +25,10 @@ def _make_eval(scores: list[float]) -> EvalResult:
|
||||
|
||||
|
||||
class TestFullPipeline:
|
||||
def test_pipeline_produces_result(self) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_produces_result(self) -> None:
|
||||
"""Full pipeline with mocked ports produces an OptimizationResult."""
|
||||
mock_llm = MagicMock(spec=LLMPort)
|
||||
mock_llm = AsyncMock(spec=LLMPort)
|
||||
mock_llm.execute.return_value = "mock response"
|
||||
|
||||
mock_judge = MagicMock(spec=JudgePort)
|
||||
@@ -38,11 +41,11 @@ class TestFullPipeline:
|
||||
eval_sequence.append(_make_eval([0.6, 0.6, 0.6, 0.6, 0.6])) # new eval (accepted)
|
||||
mock_judge.judge_batch.return_value = [(0.5, "ok")] * 5
|
||||
|
||||
mock_proposer = MagicMock(spec=ProposerPort)
|
||||
mock_proposer = AsyncMock(spec=ProposerPort)
|
||||
mock_proposer.propose.return_value = Prompt(text="Improved prompt")
|
||||
|
||||
evaluator = PromptEvaluator(mock_llm, mock_judge)
|
||||
evaluator.evaluate = MagicMock(side_effect=eval_sequence)
|
||||
evaluator.evaluate = AsyncMock(side_effect=eval_sequence)
|
||||
|
||||
mock_gen = MagicMock()
|
||||
mock_gen.generate_inputs.return_value = [
|
||||
@@ -65,7 +68,7 @@ class TestFullPipeline:
|
||||
seed=42,
|
||||
)
|
||||
|
||||
result = use_case.execute(config)
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert result.initial_prompt == "Answer questions."
|
||||
assert result.optimized_prompt == "Improved prompt"
|
||||
|
||||
199
tests/integration/test_ground_truth_eval.py
Normal file
199
tests/integration/test_ground_truth_eval.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Integration test — ground-truth evaluation end-to-end with real similarity metrics."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from prometheus.application.ground_truth_evaluator import GroundTruthEvaluator
|
||||
from prometheus.domain.entities import GroundTruthExample, Prompt
|
||||
from prometheus.domain.ports import LLMPort
|
||||
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
|
||||
from prometheus.infrastructure.similarity import (
|
||||
BleuSimilarity,
|
||||
CosineSimilarity,
|
||||
ExactMatchSimilarity,
|
||||
RougeLSimilarity,
|
||||
create_similarity_adapter,
|
||||
)
|
||||
|
||||
|
||||
def _make_dataset(items: list[tuple[str, str]]) -> list[GroundTruthExample]:
|
||||
return [
|
||||
GroundTruthExample(input_text=inp, expected_output=exp, id=i)
|
||||
for i, (inp, exp) in enumerate(items)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qa_dataset():
|
||||
return _make_dataset([
|
||||
("What is the capital of France?", "Paris"),
|
||||
("What is 2+2?", "4"),
|
||||
("What color is the sky?", "blue"),
|
||||
])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt():
|
||||
return Prompt(text="Answer the following question concisely.")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_executor():
|
||||
"""Returns responses that partially match the ground truth."""
|
||||
port = AsyncMock(spec=LLMPort)
|
||||
port.execute.side_effect = [
|
||||
"Paris is the capital of France.",
|
||||
"The answer is 4.",
|
||||
"The sky is blue.",
|
||||
]
|
||||
return port
|
||||
|
||||
|
||||
class TestGroundTruthIntegrationWithExactMatch:
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_match_on_qa(self, mock_executor, qa_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor,
|
||||
similarity=ExactMatchSimilarity(),
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||
# None of the outputs are exact matches with expected outputs
|
||||
assert all(s == 0.0 for s in result.scores)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_match_with_exact_outputs(self, qa_dataset, prompt):
|
||||
exact_executor = AsyncMock(spec=LLMPort)
|
||||
exact_executor.execute.side_effect = ["Paris", "4", "blue"]
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=exact_executor,
|
||||
similarity=ExactMatchSimilarity(),
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||
assert all(s == 1.0 for s in result.scores)
|
||||
|
||||
|
||||
class TestGroundTruthIntegrationWithBleu:
|
||||
@pytest.mark.asyncio
|
||||
async def test_bleu_scores_partial_match(self, mock_executor, qa_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor,
|
||||
similarity=BleuSimilarity(),
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||
assert all(0.0 < s < 1.0 for s in result.scores)
|
||||
assert result.mean_score > 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bleu_perfect_match(self, qa_dataset, prompt):
|
||||
perfect_executor = AsyncMock(spec=LLMPort)
|
||||
perfect_executor.execute.side_effect = ["Paris", "4", "blue"]
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=perfect_executor,
|
||||
similarity=BleuSimilarity(),
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||
assert all(s > 0.0 for s in result.scores)
|
||||
|
||||
|
||||
class TestGroundTruthIntegrationWithRouge:
|
||||
@pytest.mark.asyncio
|
||||
async def test_rouge_l_scores(self, mock_executor, qa_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor,
|
||||
similarity=RougeLSimilarity(),
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||
assert all(s > 0.0 for s in result.scores)
|
||||
|
||||
|
||||
class TestGroundTruthIntegrationWithCosine:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cosine_scores(self, mock_executor, qa_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor,
|
||||
similarity=CosineSimilarity(),
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||
assert all(s > 0.0 for s in result.scores)
|
||||
|
||||
|
||||
class TestDatasetLoaderIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_csv_and_evaluate(self, tmp_path, prompt):
|
||||
csv_file = tmp_path / "eval.csv"
|
||||
csv_file.write_text("input,expected_output\nWhat is 2+2?,4\nWhat color is grass?,green\n")
|
||||
|
||||
loader = FileDatasetLoader()
|
||||
dataset = loader.load(str(csv_file))
|
||||
assert len(dataset) == 2
|
||||
|
||||
executor = AsyncMock(spec=LLMPort)
|
||||
executor.execute.side_effect = ["4", "green"]
|
||||
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=executor,
|
||||
similarity=ExactMatchSimilarity(),
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, dataset)
|
||||
assert all(s == 1.0 for s in result.scores)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_json_and_evaluate(self, tmp_path, prompt):
|
||||
json_file = tmp_path / "eval.json"
|
||||
data = [
|
||||
{"input": "What is 2+2?", "expected_output": "4"},
|
||||
{"input": "What color is grass?", "expected_output": "green"},
|
||||
]
|
||||
json_file.write_text(json.dumps(data))
|
||||
|
||||
loader = FileDatasetLoader()
|
||||
dataset = loader.load(str(json_file))
|
||||
assert len(dataset) == 2
|
||||
|
||||
executor = AsyncMock(spec=LLMPort)
|
||||
executor.execute.side_effect = ["4", "not green"]
|
||||
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=executor,
|
||||
similarity=create_similarity_adapter("bleu"),
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, dataset)
|
||||
# First item should score well, second poorly
|
||||
assert result.scores[0] > result.scores[1]
|
||||
|
||||
|
||||
class TestMetricComparison:
|
||||
"""Compare different metrics on the same outputs to ensure they behave differently."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_give_different_scores(self, qa_dataset, prompt):
|
||||
results = {}
|
||||
for metric_name, metric_cls in [
|
||||
("exact", ExactMatchSimilarity),
|
||||
("bleu", BleuSimilarity),
|
||||
("rouge_l", RougeLSimilarity),
|
||||
("cosine", CosineSimilarity),
|
||||
]:
|
||||
executor = AsyncMock(spec=LLMPort)
|
||||
executor.execute.side_effect = [
|
||||
"Paris is the capital of France.",
|
||||
"The answer is 4.",
|
||||
"The sky is blue.",
|
||||
]
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=executor,
|
||||
similarity=metric_cls(),
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, qa_dataset)
|
||||
results[metric_name] = result.mean_score
|
||||
|
||||
# Exact match should be 0 (no exact matches)
|
||||
assert results["exact"] == 0.0
|
||||
# All other metrics should give partial credit
|
||||
assert results["bleu"] > 0.0
|
||||
assert results["rouge_l"] > 0.0
|
||||
assert results["cosine"] > 0.0
|
||||
213
tests/unit/test_adapter_config.py
Normal file
213
tests/unit/test_adapter_config.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Unit tests for multi-model adapter configuration.
|
||||
|
||||
Verifies that each adapter uses its own dspy.LM instance and
|
||||
that per-model api_base/api_key_env overrides are wired correctly.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import dspy
|
||||
import pytest
|
||||
|
||||
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory
|
||||
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
||||
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
||||
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_lm() -> dspy.LM:
|
||||
"""Dummy LM for task execution."""
|
||||
return dspy.utils.DummyLM([{"output": "task model output"}])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def judge_lm() -> dspy.LM:
|
||||
"""Dummy LM for judging (ChainOfThought requires reasoning field)."""
|
||||
return dspy.utils.DummyLM(
|
||||
[
|
||||
{"reasoning": "Evaluating output.", "score": "0.8", "feedback": "Good response.", "dimension_scores": "{}"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def proposer_lm() -> dspy.LM:
|
||||
"""Dummy LM for proposing (ChainOfThought requires reasoning field)."""
|
||||
return dspy.utils.DummyLM(
|
||||
[
|
||||
{"reasoning": "Analyzing failures.", "new_instruction": "Improved prompt: be more specific."},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def synth_lm() -> dspy.LM:
|
||||
"""Dummy LM for synthetic generation (ChainOfThought requires reasoning field)."""
|
||||
return dspy.utils.DummyLM(
|
||||
[
|
||||
{"reasoning": "Generating examples.", "examples": json.dumps(["input 1", "input 2", "input 3"])},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestDSPyLLMAdapterOwnLM:
|
||||
"""Bug #2 fix: DSPyLLMAdapter must use the LM it receives, not the global one."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_provided_lm_not_global(self) -> None:
|
||||
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
|
||||
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
|
||||
dspy.configure(lm=global_lm)
|
||||
|
||||
adapter = DSPyLLMAdapter(lm=local_lm)
|
||||
result = await adapter.execute(Prompt(text="test"), "input")
|
||||
|
||||
assert result == "local response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_affect_global_lm(self) -> None:
|
||||
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
|
||||
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
|
||||
dspy.configure(lm=global_lm)
|
||||
|
||||
adapter = DSPyLLMAdapter(lm=local_lm)
|
||||
await adapter.execute(Prompt(text="test"), "input")
|
||||
|
||||
# Global LM should still be the same
|
||||
assert dspy.settings.lm is global_lm
|
||||
|
||||
|
||||
class TestDSPyJudgeAdapterOwnLM:
|
||||
"""DSPyJudgeAdapter must use its own LM instance."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_provided_lm(self, judge_lm: dspy.LM) -> None:
|
||||
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||
results = await adapter.judge_batch(
|
||||
task_description="Test task",
|
||||
pairs=[("input 1", "output 1")],
|
||||
)
|
||||
assert len(results) == 1
|
||||
score, feedback = results[0]
|
||||
assert score == 0.8
|
||||
assert feedback == "Good response."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_use_global_lm(self) -> None:
|
||||
judge_lm = dspy.utils.DummyLM(
|
||||
[{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response", "dimension_scores": "{}"}]
|
||||
)
|
||||
global_lm = dspy.utils.DummyLM([{"reasoning": "no", "score": "0.1", "feedback": "Wrong LM!", "dimension_scores": "{}"}])
|
||||
dspy.configure(lm=global_lm)
|
||||
|
||||
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||
results = await adapter.judge_batch("task", [("in", "out")])
|
||||
assert results[0][0] == 0.9
|
||||
|
||||
|
||||
class TestDSPyProposerAdapterOwnLM:
|
||||
"""DSPyProposerAdapter must use its own LM instance."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_provided_lm(self, proposer_lm: dspy.LM) -> None:
|
||||
adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||
trajectories = [
|
||||
Trajectory(
|
||||
input_text="test input",
|
||||
output_text="test output",
|
||||
score=0.3,
|
||||
feedback="bad",
|
||||
prompt_used="old prompt",
|
||||
)
|
||||
]
|
||||
result = await adapter.propose(
|
||||
current_prompt=Prompt(text="old prompt"),
|
||||
trajectories=trajectories,
|
||||
task_description="Test task",
|
||||
)
|
||||
assert "Improved prompt" in result.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_use_global_lm(self) -> None:
|
||||
proposer_lm = dspy.utils.DummyLM(
|
||||
[{"reasoning": "ok", "new_instruction": "proposer-specific"}]
|
||||
)
|
||||
global_lm = dspy.utils.DummyLM(
|
||||
[{"reasoning": "no", "new_instruction": "wrong-global"}]
|
||||
)
|
||||
dspy.configure(lm=global_lm)
|
||||
|
||||
adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||
result = await adapter.propose(
|
||||
current_prompt=Prompt(text="test"),
|
||||
trajectories=[],
|
||||
task_description="task",
|
||||
)
|
||||
assert result.text == "proposer-specific"
|
||||
|
||||
|
||||
class TestDSPySyntheticAdapterOwnLM:
|
||||
"""DSPySyntheticAdapter must use its own LM instance."""
|
||||
|
||||
def test_uses_provided_lm(self, synth_lm: dspy.LM) -> None:
|
||||
adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||
results = adapter.generate_inputs("Test task", 3)
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(ex, SyntheticExample) for ex in results)
|
||||
|
||||
def test_does_not_use_global_lm(self) -> None:
|
||||
synth_lm = dspy.utils.DummyLM(
|
||||
[{"reasoning": "ok", "examples": json.dumps(["synth-specific"])}]
|
||||
)
|
||||
global_lm = dspy.utils.DummyLM(
|
||||
[{"reasoning": "no", "examples": json.dumps(["wrong-global"])}]
|
||||
)
|
||||
dspy.configure(lm=global_lm)
|
||||
|
||||
adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||
results = adapter.generate_inputs("task", 1)
|
||||
assert results[0].input_text == "synth-specific"
|
||||
|
||||
|
||||
class TestPerModelOverrides:
|
||||
"""Verify that per-model api_base/api_key_env are passed through to dspy.LM."""
|
||||
|
||||
@patch("prometheus.cli.commands.optimize.dspy.LM")
|
||||
def test_per_model_api_base_override(self, mock_lm_cls: MagicMock) -> None:
|
||||
"""Per-model api_base should be used instead of global."""
|
||||
mock_lm_cls.return_value = MagicMock()
|
||||
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="test",
|
||||
task_description="test",
|
||||
task_model="openai/gpt-4o-mini",
|
||||
judge_model="openai/gpt-4o",
|
||||
proposer_model="openai/gpt-4o",
|
||||
synth_model="openai/gpt-4o",
|
||||
judge_api_base="https://judge.example.com/v1",
|
||||
judge_api_key_env="JUDGE_API_KEY",
|
||||
)
|
||||
|
||||
# Verify config carries the overrides
|
||||
assert config.judge_api_base == "https://judge.example.com/v1"
|
||||
assert config.judge_api_key_env == "JUDGE_API_KEY"
|
||||
assert config.task_api_base is None
|
||||
|
||||
def test_config_defaults_to_none(self) -> None:
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
|
||||
config = OptimizationConfig(seed_prompt="test", task_description="test")
|
||||
assert config.task_api_base is None
|
||||
assert config.task_api_key_env is None
|
||||
assert config.judge_api_base is None
|
||||
assert config.judge_api_key_env is None
|
||||
assert config.proposer_api_base is None
|
||||
assert config.proposer_api_key_env is None
|
||||
assert config.synth_api_base is None
|
||||
assert config.synth_api_key_env is None
|
||||
294
tests/unit/test_adapters.py
Normal file
294
tests/unit/test_adapters.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Unit tests for infrastructure adapters — LLM, Judge, Proposer, Synthetic.
|
||||
|
||||
Uses mocked DSPy modules to isolate adapter logic from LLM calls.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import dspy
|
||||
import pytest
|
||||
|
||||
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory
|
||||
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
||||
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
||||
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
||||
|
||||
|
||||
# --- LLM Adapter ---
|
||||
|
||||
|
||||
class TestDSPyLLMAdapter:
|
||||
"""Tests for DSPyLLMAdapter.execute()."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lm(self) -> MagicMock:
|
||||
return MagicMock(spec=dspy.LM)
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self, mock_lm: MagicMock) -> DSPyLLMAdapter:
|
||||
return DSPyLLMAdapter(lm=mock_lm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_output_string(
|
||||
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
mock_predictor = MagicMock()
|
||||
mock_predictor.return_value = MagicMock(output="Hello response")
|
||||
adapter._predictor = mock_predictor
|
||||
|
||||
prompt = Prompt(text="Say hello.")
|
||||
result = await adapter.execute(prompt, "Hi there")
|
||||
|
||||
assert result == "Hello response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_passes_prompt_text_and_input(
|
||||
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
mock_predictor = MagicMock()
|
||||
mock_predictor.return_value = MagicMock(output="response")
|
||||
adapter._predictor = mock_predictor
|
||||
|
||||
prompt = Prompt(text="Translate this.")
|
||||
await adapter.execute(prompt, "Hello world")
|
||||
|
||||
mock_predictor.assert_called_once_with(
|
||||
instruction="Translate this.",
|
||||
input_text="Hello world",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_uses_dspy_context(
|
||||
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
mock_predictor = MagicMock()
|
||||
mock_predictor.return_value = MagicMock(output="ok")
|
||||
adapter._predictor = mock_predictor
|
||||
|
||||
with patch("prometheus.infrastructure.llm_adapter.dspy.context") as mock_ctx:
|
||||
await adapter.execute(Prompt(text="test"), "input")
|
||||
mock_ctx.assert_called_once_with(lm=mock_lm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_converts_output_to_str(
|
||||
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
mock_predictor = MagicMock()
|
||||
mock_predictor.return_value = MagicMock(output=42)
|
||||
adapter._predictor = mock_predictor
|
||||
|
||||
result = await adapter.execute(Prompt(text="test"), "input")
|
||||
assert isinstance(result, str)
|
||||
assert result == "42"
|
||||
|
||||
|
||||
# --- Judge Adapter ---
|
||||
|
||||
|
||||
class TestDSPyJudgeAdapter:
|
||||
"""Tests for DSPyJudgeAdapter.judge_batch()."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lm(self) -> MagicMock:
|
||||
return MagicMock(spec=dspy.LM)
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self, mock_lm: MagicMock) -> DSPyJudgeAdapter:
|
||||
return DSPyJudgeAdapter(lm=mock_lm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_batch_returns_scores_and_feedback(
|
||||
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._judge = MagicMock()
|
||||
adapter._judge.side_effect = [
|
||||
MagicMock(score=0.9, feedback="Excellent."),
|
||||
MagicMock(score=0.4, feedback="Incomplete."),
|
||||
]
|
||||
|
||||
pairs = [("What is 2+2?", "4"), ("Capital of France?", "London")]
|
||||
result = await adapter.judge_batch("math and geography", pairs)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0] == (0.9, "Excellent.")
|
||||
assert result[1] == (0.4, "Incomplete.")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_batch_empty_pairs(
|
||||
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
result = await adapter.judge_batch("task", [])
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_batch_uses_dspy_context(
|
||||
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._judge = MagicMock()
|
||||
adapter._judge.return_value = MagicMock(score=0.5, feedback="ok")
|
||||
|
||||
with patch("prometheus.infrastructure.judge_adapter.dspy.context") as mock_ctx:
|
||||
await adapter.judge_batch("task", [("in", "out")])
|
||||
mock_ctx.assert_called_once_with(lm=mock_lm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_batch_returns_all_results(
|
||||
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
"""Judge calls run in parallel but all results are returned."""
|
||||
adapter._judge = MagicMock()
|
||||
adapter._judge.side_effect = [
|
||||
MagicMock(score=0.5, feedback="ok"),
|
||||
MagicMock(score=0.7, feedback="better"),
|
||||
MagicMock(score=0.3, feedback="worse"),
|
||||
]
|
||||
|
||||
pairs = [("first", "out1"), ("second", "out2"), ("third", "out3")]
|
||||
results = await adapter.judge_batch("task", pairs)
|
||||
|
||||
assert len(results) == 3
|
||||
scores = [r[0] for r in results]
|
||||
assert 0.5 in scores
|
||||
assert 0.7 in scores
|
||||
assert 0.3 in scores
|
||||
|
||||
|
||||
# --- Proposer Adapter ---
|
||||
|
||||
|
||||
class TestDSPyProposerAdapter:
|
||||
"""Tests for DSPyProposerAdapter.propose()."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lm(self) -> MagicMock:
|
||||
return MagicMock(spec=dspy.LM)
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self, mock_lm: MagicMock) -> DSPyProposerAdapter:
|
||||
return DSPyProposerAdapter(lm=mock_lm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_propose_returns_new_prompt(
|
||||
self, adapter: DSPyProposerAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._proposer = MagicMock()
|
||||
adapter._proposer.return_value = MagicMock(
|
||||
new_instruction="Be concise and accurate."
|
||||
)
|
||||
|
||||
current = Prompt(text="Answer questions.")
|
||||
trajectories = [
|
||||
Trajectory("in", "out", 0.3, "too verbose", "Answer questions.")
|
||||
]
|
||||
result = await adapter.propose(current, trajectories, "Q&A task")
|
||||
|
||||
assert isinstance(result, Prompt)
|
||||
assert result.text == "Be concise and accurate."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_propose_uses_dspy_context(
|
||||
self, adapter: DSPyProposerAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._proposer = MagicMock()
|
||||
adapter._proposer.return_value = MagicMock(new_instruction="improved")
|
||||
|
||||
with patch("prometheus.infrastructure.proposer_adapter.dspy.context") as mock_ctx:
|
||||
await adapter.propose(Prompt(text="t"), [], "task")
|
||||
mock_ctx.assert_called_once_with(lm=mock_lm)
|
||||
|
||||
def test_format_failures_single_trajectory(self) -> None:
|
||||
trajectories = [
|
||||
Trajectory("What is AI?", "A type of robot.", 0.3, "Incomplete definition.", "prompt")
|
||||
]
|
||||
result = DSPyProposerAdapter._format_failures(trajectories)
|
||||
|
||||
assert "What is AI?" in result
|
||||
assert "A type of robot." in result
|
||||
assert "0.30" in result
|
||||
assert "Incomplete definition." in result
|
||||
assert "# Example 1" in result
|
||||
|
||||
def test_format_failures_multiple_trajectories(self) -> None:
|
||||
trajectories = [
|
||||
Trajectory("input1", "output1", 0.4, "bad", "prompt"),
|
||||
Trajectory("input2", "output2", 0.2, "worse", "prompt"),
|
||||
]
|
||||
result = DSPyProposerAdapter._format_failures(trajectories)
|
||||
|
||||
assert "# Example 1" in result
|
||||
assert "# Example 2" in result
|
||||
assert "---" in result
|
||||
assert "input1" in result
|
||||
assert "input2" in result
|
||||
|
||||
def test_format_failures_empty_list(self) -> None:
|
||||
result = DSPyProposerAdapter._format_failures([])
|
||||
assert result == ""
|
||||
|
||||
|
||||
# --- Synthetic Adapter ---
|
||||
|
||||
|
||||
class TestDSPySyntheticAdapter:
|
||||
"""Tests for DSPySyntheticAdapter.generate_inputs()."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lm(self) -> MagicMock:
|
||||
return MagicMock(spec=dspy.LM)
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self, mock_lm: MagicMock) -> DSPySyntheticAdapter:
|
||||
return DSPySyntheticAdapter(lm=mock_lm)
|
||||
|
||||
def test_generate_inputs_returns_examples(
|
||||
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._generator = MagicMock()
|
||||
adapter._generator.return_value = MagicMock(
|
||||
examples=["What is AI?", "Explain ML.", "What is NLP?"]
|
||||
)
|
||||
|
||||
result = adapter.generate_inputs("AI task", 3)
|
||||
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(ex, SyntheticExample) for ex in result)
|
||||
assert result[0].input_text == "What is AI?"
|
||||
assert result[0].id == 0
|
||||
assert result[1].id == 1
|
||||
|
||||
def test_generate_inputs_truncates_to_n(
|
||||
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._generator = MagicMock()
|
||||
adapter._generator.return_value = MagicMock(
|
||||
examples=["q1", "q2", "q3", "q4", "q5"]
|
||||
)
|
||||
|
||||
result = adapter.generate_inputs("task", 3)
|
||||
|
||||
assert len(result) == 3
|
||||
|
||||
def test_generate_inputs_passes_correct_args(
|
||||
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._generator = MagicMock()
|
||||
adapter._generator.return_value = MagicMock(examples=["q1"])
|
||||
|
||||
adapter.generate_inputs("my task", 5)
|
||||
|
||||
adapter._generator.assert_called_once_with(
|
||||
task_description="my task",
|
||||
n_examples=5,
|
||||
)
|
||||
|
||||
def test_generate_inputs_empty_list(
|
||||
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._generator = MagicMock()
|
||||
adapter._generator.return_value = MagicMock(examples=[])
|
||||
|
||||
result = adapter.generate_inputs("task", 0)
|
||||
|
||||
assert result == []
|
||||
333
tests/unit/test_checkpoint.py
Normal file
333
tests/unit/test_checkpoint.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Unit tests for checkpoint & resume functionality."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
EvalResult,
|
||||
OptimizationState,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
Trajectory,
|
||||
)
|
||||
from prometheus.infrastructure.checkpoint import JsonCheckpointPersistence
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JsonCheckpointPersistence — save/load round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJsonCheckpointPersistence:
|
||||
def test_roundtrip_full_state(self, tmp_path: Path) -> None:
|
||||
"""Saving and loading preserves all fields."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
|
||||
|
||||
state = OptimizationState(
|
||||
iteration=7,
|
||||
best_candidate=Candidate(
|
||||
prompt=Prompt(text="best prompt", metadata={"source": "test"}),
|
||||
best_score=0.92,
|
||||
generation=5,
|
||||
),
|
||||
candidates=[
|
||||
Candidate(prompt=Prompt(text="p1"), best_score=0.5, generation=0),
|
||||
Candidate(prompt=Prompt(text="p2"), best_score=0.92, generation=5),
|
||||
],
|
||||
synthetic_pool=[
|
||||
SyntheticExample(input_text="q1", category="cat_a", id=0),
|
||||
SyntheticExample(input_text="q2", category="cat_b", id=1),
|
||||
],
|
||||
history=[{"iteration": 1, "event": "accepted", "old_score": 0.5, "new_score": 0.7}],
|
||||
total_llm_calls=42,
|
||||
)
|
||||
|
||||
ckpt.save(state)
|
||||
assert ckpt.latest_exists()
|
||||
|
||||
loaded = ckpt.load()
|
||||
assert loaded is not None
|
||||
assert loaded.iteration == 7
|
||||
assert loaded.total_llm_calls == 42
|
||||
assert loaded.best_candidate is not None
|
||||
assert loaded.best_candidate.prompt.text == "best prompt"
|
||||
assert loaded.best_candidate.prompt.metadata == {"source": "test"}
|
||||
assert loaded.best_candidate.best_score == 0.92
|
||||
assert len(loaded.candidates) == 2
|
||||
assert len(loaded.synthetic_pool) == 2
|
||||
assert loaded.synthetic_pool[0].input_text == "q1"
|
||||
assert loaded.synthetic_pool[1].category == "cat_b"
|
||||
assert loaded.history[0]["event"] == "accepted"
|
||||
|
||||
def test_load_returns_none_when_no_checkpoint(self, tmp_path: Path) -> None:
|
||||
"""Loading from empty dir returns None."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "nope")
|
||||
assert ckpt.load() is None
|
||||
assert not ckpt.latest_exists()
|
||||
|
||||
def test_creates_directory_on_save(self, tmp_path: Path) -> None:
|
||||
"""Save creates the directory tree if it doesn't exist."""
|
||||
deep_dir = tmp_path / "a" / "b" / "c"
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=deep_dir)
|
||||
state = OptimizationState(iteration=1)
|
||||
ckpt.save(state)
|
||||
assert (deep_dir / "latest.json").exists()
|
||||
|
||||
def test_overwrites_previous_checkpoint(self, tmp_path: Path) -> None:
|
||||
"""Second save overwrites the first."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
|
||||
|
||||
ckpt.save(OptimizationState(iteration=1, total_llm_calls=10))
|
||||
ckpt.save(OptimizationState(iteration=5, total_llm_calls=50))
|
||||
|
||||
loaded = ckpt.load()
|
||||
assert loaded is not None
|
||||
assert loaded.iteration == 5
|
||||
assert loaded.total_llm_calls == 50
|
||||
|
||||
def test_json_is_human_readable(self, tmp_path: Path) -> None:
|
||||
"""Checkpoint file is valid, pretty-printed JSON."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
|
||||
state = OptimizationState(
|
||||
iteration=3,
|
||||
best_candidate=Candidate(prompt=Prompt(text="hello"), best_score=0.8),
|
||||
)
|
||||
ckpt.save(state)
|
||||
|
||||
raw = json.loads((tmp_path / "latest.json").read_text())
|
||||
assert raw["schema_version"] == 1
|
||||
assert raw["iteration"] == 3
|
||||
assert raw["best_candidate"]["prompt_text"] == "hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EvolutionLoop — checkpoint integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEvolutionCheckpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint_saved_on_interval(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""Checkpoint is saved every checkpoint_interval iterations."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
# All iterations accepted so checkpoint triggers
|
||||
good_eval = EvalResult(
|
||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"out{i}", s, "ok", "p")
|
||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||
],
|
||||
)
|
||||
better_eval = EvalResult(
|
||||
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
# initial_eval + 5 iterations (each needs old_eval + new_eval)
|
||||
evaluator.evaluate = AsyncMock(
|
||||
side_effect=[good_eval] # initial
|
||||
+ [good_eval, better_eval] * 5 # 5 iterations
|
||||
)
|
||||
|
||||
proposer = AsyncMock()
|
||||
proposer.propose.return_value = Prompt(text="improved prompt")
|
||||
|
||||
checkpoint_port = MagicMock()
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=5,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=checkpoint_port,
|
||||
checkpoint_interval=2,
|
||||
)
|
||||
|
||||
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
# Checkpoint at iterations 2, 4 (every 2nd)
|
||||
save_calls = checkpoint_port.save.call_count
|
||||
assert save_calls >= 2 # at least at iters 2 and 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_checkpoint_without_port(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""No checkpointing happens when checkpoint_port is None (default)."""
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
perfect_eval = EvalResult(
|
||||
scores=[1.0] * 5,
|
||||
feedbacks=["perfect"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", 1.0, "perfect", "p")
|
||||
for i in range(5)
|
||||
],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=AsyncMock(),
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=None,
|
||||
)
|
||||
# Should run without error — no checkpoint port, no crash
|
||||
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_skips_seed_evaluation(
|
||||
self,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""When initial_state is provided, seed eval is skipped and loop starts from saved iteration."""
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
proposer = AsyncMock()
|
||||
proposer.propose.return_value = Prompt(text="new prompt")
|
||||
|
||||
# Only return evaluations for resumed iterations (1 iter: old_eval + new_eval)
|
||||
old_eval = EvalResult(
|
||||
scores=[0.5] * 5,
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", 0.5, "ok", "p") for i in range(5)
|
||||
],
|
||||
)
|
||||
new_eval = EvalResult(
|
||||
scores=[0.8] * 5,
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
|
||||
|
||||
# Create a state simulating checkpoint at iteration 4
|
||||
initial_state = OptimizationState(
|
||||
iteration=4,
|
||||
best_candidate=Candidate(
|
||||
prompt=Prompt(text="checkpoint prompt"), best_score=2.5, generation=4
|
||||
),
|
||||
candidates=[Candidate(prompt=Prompt(text="checkpoint prompt"), best_score=2.5)],
|
||||
total_llm_calls=40,
|
||||
)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=5, # only iteration 5 remains
|
||||
minibatch_size=5,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt=Prompt(text="seed"),
|
||||
synthetic_pool=synthetic_pool,
|
||||
task_description=task_description,
|
||||
initial_state=initial_state,
|
||||
)
|
||||
|
||||
# Should have run only 1 iteration (iter 5)
|
||||
assert state.iteration == 5
|
||||
# total_llm_calls should include the 40 from checkpoint + new calls
|
||||
assert state.total_llm_calls > 40
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_save_and_resume_roundtrip(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""End-to-end: run a few iterations, checkpoint, resume, finish."""
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
old_eval = EvalResult(
|
||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", s, "ok", "p")
|
||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||
],
|
||||
)
|
||||
new_eval = EvalResult(
|
||||
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(
|
||||
side_effect=[old_eval, old_eval, new_eval, old_eval, new_eval]
|
||||
)
|
||||
proposer = AsyncMock()
|
||||
proposer.propose.return_value = Prompt(text="improved prompt")
|
||||
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=2,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=ckpt,
|
||||
checkpoint_interval=1,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
assert state.iteration == 2
|
||||
assert ckpt.latest_exists()
|
||||
|
||||
# Capture the checkpoint state *before* resume (state is mutated in-place)
|
||||
loaded = ckpt.load()
|
||||
assert loaded is not None
|
||||
saved_llm_calls = loaded.total_llm_calls
|
||||
saved_iteration = loaded.iteration
|
||||
|
||||
# Set up evaluator for resumed run (just 1 more iteration)
|
||||
evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
|
||||
proposer.propose.return_value = Prompt(text="even better prompt")
|
||||
|
||||
loop2 = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=ckpt,
|
||||
checkpoint_interval=1,
|
||||
)
|
||||
resumed = await loop2.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
initial_state=loaded,
|
||||
)
|
||||
assert resumed.iteration == 3
|
||||
assert resumed.total_llm_calls > saved_llm_calls
|
||||
assert resumed.iteration > saved_iteration
|
||||
278
tests/unit/test_cli.py
Normal file
278
tests/unit/test_cli.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Tests for the CLI interface — prometheus optimize, version, etc.
|
||||
|
||||
Uses Typer's CliRunner for isolated command testing.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from prometheus.application.dto import OptimizationResult
|
||||
from prometheus.cli.app import app
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
class TestCLIOptimize:
|
||||
"""Tests for the `prometheus optimize` command."""
|
||||
|
||||
def _write_config(self, tmp_path: Path, **overrides: object) -> Path:
|
||||
"""Write a minimal valid config YAML and return its path."""
|
||||
data = {
|
||||
"seed_prompt": "You are a helpful assistant.",
|
||||
"task_description": "Answer factual questions accurately.",
|
||||
}
|
||||
data.update(overrides)
|
||||
config_file = tmp_path / "config.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(data, f)
|
||||
return config_file
|
||||
|
||||
def test_optimize_with_valid_config(self, tmp_path: Path) -> None:
|
||||
config_file = self._write_config(tmp_path)
|
||||
output_file = tmp_path / "output.yaml"
|
||||
|
||||
mock_result = OptimizationResult(
|
||||
optimized_prompt="Improved prompt",
|
||||
initial_prompt="You are a helpful assistant.",
|
||||
iterations_used=5,
|
||||
total_llm_calls=50,
|
||||
initial_score=0.3,
|
||||
final_score=0.9,
|
||||
improvement=0.6,
|
||||
history=[],
|
||||
)
|
||||
|
||||
mock_uc = AsyncMock()
|
||||
mock_uc.execute.return_value = mock_result
|
||||
|
||||
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||
mock_llm_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||
mock_judge_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||
mock_prop_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"optimize",
|
||||
"-i",
|
||||
str(config_file),
|
||||
"-o",
|
||||
str(output_file),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Optimized Prompt" in result.output
|
||||
|
||||
def test_optimize_missing_input_file(self) -> None:
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["optimize", "-i", "/nonexistent/config.yaml"],
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
|
||||
def test_optimize_with_verbose_flag(self, tmp_path: Path) -> None:
|
||||
config_file = self._write_config(tmp_path)
|
||||
output_file = tmp_path / "output.yaml"
|
||||
|
||||
mock_result = OptimizationResult(
|
||||
optimized_prompt="Improved",
|
||||
initial_prompt="test",
|
||||
iterations_used=1,
|
||||
total_llm_calls=10,
|
||||
initial_score=0.3,
|
||||
final_score=0.8,
|
||||
improvement=0.5,
|
||||
history=[],
|
||||
)
|
||||
|
||||
mock_uc = AsyncMock()
|
||||
mock_uc.execute.return_value = mock_result
|
||||
|
||||
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||
mock_llm_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||
mock_judge_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||
mock_prop_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"optimize",
|
||||
"-i",
|
||||
str(config_file),
|
||||
"-o",
|
||||
str(output_file),
|
||||
"-v",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_optimize_displays_metrics(self, tmp_path: Path) -> None:
|
||||
config_file = self._write_config(tmp_path)
|
||||
output_file = tmp_path / "output.yaml"
|
||||
|
||||
mock_result = OptimizationResult(
|
||||
optimized_prompt="Better prompt",
|
||||
initial_prompt="test",
|
||||
iterations_used=3,
|
||||
total_llm_calls=30,
|
||||
initial_score=0.40,
|
||||
final_score=0.85,
|
||||
improvement=0.45,
|
||||
history=[],
|
||||
)
|
||||
|
||||
mock_uc = AsyncMock()
|
||||
mock_uc.execute.return_value = mock_result
|
||||
|
||||
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||
mock_llm_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||
mock_judge_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||
mock_prop_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"optimize",
|
||||
"-i",
|
||||
str(config_file),
|
||||
"-o",
|
||||
str(output_file),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "0.40" in result.output
|
||||
assert "0.85" in result.output
|
||||
assert "+0.45" in result.output
|
||||
|
||||
def test_optimize_with_max_concurrency_flag(self, tmp_path: Path) -> None:
|
||||
config_file = self._write_config(tmp_path)
|
||||
output_file = tmp_path / "output.yaml"
|
||||
|
||||
mock_result = OptimizationResult(
|
||||
optimized_prompt="Better prompt",
|
||||
initial_prompt="test",
|
||||
iterations_used=1,
|
||||
total_llm_calls=10,
|
||||
initial_score=0.3,
|
||||
final_score=0.8,
|
||||
improvement=0.5,
|
||||
history=[],
|
||||
)
|
||||
|
||||
mock_uc = AsyncMock()
|
||||
mock_uc.execute.return_value = mock_result
|
||||
|
||||
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||
mock_llm_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||
mock_judge_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||
mock_prop_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"optimize",
|
||||
"-i",
|
||||
str(config_file),
|
||||
"-o",
|
||||
str(output_file),
|
||||
"--max-concurrency",
|
||||
"10",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestCLIHelp:
|
||||
"""Tests for CLI help and no-args behavior."""
|
||||
|
||||
def test_no_args_shows_help(self) -> None:
|
||||
result = runner.invoke(app, [])
|
||||
# Typer uses exit code 2 when no_args_is_help=True
|
||||
assert result.exit_code in (0, 2)
|
||||
assert "PROMETHEUS" in result.output or "Usage" in result.output
|
||||
|
||||
def test_optimize_help(self) -> None:
|
||||
result = runner.invoke(app, ["optimize", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "input" in result.output.lower() or "INPUT" in result.output
|
||||
|
||||
def test_version_help(self) -> None:
|
||||
result = runner.invoke(app, ["version", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_init_help(self) -> None:
|
||||
result = runner.invoke(app, ["init", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_list_help(self) -> None:
|
||||
result = runner.invoke(app, ["list", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestCLIVersion:
|
||||
"""Tests for the `prometheus version` command."""
|
||||
|
||||
def test_version_prints_version(self) -> None:
|
||||
result = runner.invoke(app, ["version"])
|
||||
assert result.exit_code == 0
|
||||
assert "PROMETHEUS" in result.output
|
||||
assert "0.1.0" in result.output
|
||||
|
||||
|
||||
class TestCLIList:
|
||||
"""Tests for the `prometheus list` command."""
|
||||
|
||||
def test_list_no_runs(self, tmp_path: Path) -> None:
|
||||
result = runner.invoke(app, ["list", "-d", str(tmp_path)])
|
||||
assert result.exit_code == 0
|
||||
assert "No optimization runs found" in result.output
|
||||
|
||||
def test_list_with_result(self, tmp_path: Path) -> None:
|
||||
result_data = {
|
||||
"optimized_prompt": "Better prompt for testing",
|
||||
"initial_prompt": "test",
|
||||
"iterations_used": 5,
|
||||
"total_llm_calls": 50,
|
||||
"initial_score": 0.30,
|
||||
"final_score": 0.90,
|
||||
"improvement": 0.60,
|
||||
"history": [],
|
||||
}
|
||||
result_file = tmp_path / "output.yaml"
|
||||
import yaml as _yaml
|
||||
with open(result_file, "w") as f:
|
||||
_yaml.dump(result_data, f)
|
||||
|
||||
result = runner.invoke(app, ["list", "-d", str(tmp_path)])
|
||||
assert result.exit_code == 0
|
||||
assert "0.30" in result.output
|
||||
assert "0.90" in result.output
|
||||
|
||||
def test_list_nonexistent_directory(self) -> None:
|
||||
result = runner.invoke(app, ["list", "-d", "/nonexistent/dir"])
|
||||
assert result.exit_code == 1
|
||||
332
tests/unit/test_config.py
Normal file
332
tests/unit/test_config.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""Unit tests for config and config loading via CLI.
|
||||
|
||||
Tests config validation scenarios: missing fields, wrong types, defaults.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
from prometheus.infrastructure.file_io import YamlPersistence
|
||||
|
||||
|
||||
class TestOptimizationConfig:
|
||||
"""Tests for OptimizationConfig Pydantic model defaults."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="test prompt",
|
||||
task_description="test task",
|
||||
)
|
||||
assert config.task_model == "openai/gpt-4o-mini"
|
||||
assert config.judge_model == "openai/gpt-4o"
|
||||
assert config.proposer_model == "openai/gpt-4o"
|
||||
assert config.synth_model == "openai/gpt-4o"
|
||||
assert config.max_iterations == 30
|
||||
assert config.n_synthetic_inputs == 20
|
||||
assert config.minibatch_size == 5
|
||||
assert config.perfect_score == 1.0
|
||||
assert config.seed == 42
|
||||
assert config.output_path == "output.yaml"
|
||||
assert config.verbose is False
|
||||
assert config.error_strategy == "retry"
|
||||
assert config.max_retries == 3
|
||||
assert config.max_concurrency == 5
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="custom prompt",
|
||||
task_description="custom task",
|
||||
max_iterations=100,
|
||||
minibatch_size=10,
|
||||
seed=123,
|
||||
verbose=True,
|
||||
)
|
||||
assert config.max_iterations == 100
|
||||
assert config.minibatch_size == 10
|
||||
assert config.seed == 123
|
||||
assert config.verbose is True
|
||||
|
||||
def test_roundtrip_to_dict(self) -> None:
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="test",
|
||||
task_description="task",
|
||||
)
|
||||
d = config.model_dump()
|
||||
assert d["seed_prompt"] == "test"
|
||||
assert d["task_description"] == "task"
|
||||
assert "history" not in d # OptimizationResult has history, not config
|
||||
|
||||
def test_config_version_defaults(self) -> None:
|
||||
config = OptimizationConfig(seed_prompt="a", task_description="b")
|
||||
assert config.config_version == 1
|
||||
|
||||
def test_config_version_stamps_current(self) -> None:
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", config_version=0,
|
||||
)
|
||||
# Migration always stamps current version
|
||||
assert config.config_version == 1
|
||||
|
||||
|
||||
class TestConfigLoading:
|
||||
"""Tests for loading OptimizationConfig from YAML via YamlPersistence."""
|
||||
|
||||
def test_minimal_config_loads(self, tmp_path: Path) -> None:
|
||||
persistence = YamlPersistence()
|
||||
data = {
|
||||
"seed_prompt": "You are helpful.",
|
||||
"task_description": "Answer questions.",
|
||||
}
|
||||
config_file = tmp_path / "config.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
raw = persistence.read_config(str(config_file))
|
||||
config = OptimizationConfig.model_validate(raw)
|
||||
|
||||
assert config.seed_prompt == "You are helpful."
|
||||
assert config.task_description == "Answer questions."
|
||||
assert config.max_iterations == 30 # default
|
||||
|
||||
def test_full_config_loads(self, tmp_path: Path) -> None:
|
||||
persistence = YamlPersistence()
|
||||
data = {
|
||||
"seed_prompt": "You are helpful.",
|
||||
"task_description": "Answer questions.",
|
||||
"task_model": "openai/gpt-4o",
|
||||
"judge_model": "openai/gpt-4o-mini",
|
||||
"max_iterations": 50,
|
||||
"n_synthetic_inputs": 30,
|
||||
"minibatch_size": 8,
|
||||
"seed": 99,
|
||||
"verbose": True,
|
||||
}
|
||||
config_file = tmp_path / "config.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
raw = persistence.read_config(str(config_file))
|
||||
config = OptimizationConfig.model_validate(raw)
|
||||
|
||||
assert config.task_model == "openai/gpt-4o"
|
||||
assert config.max_iterations == 50
|
||||
assert config.verbose is True
|
||||
|
||||
def test_missing_seed_prompt_raises_validation_error(self, tmp_path: Path) -> None:
|
||||
persistence = YamlPersistence()
|
||||
data = {"task_description": "Answer questions."}
|
||||
config_file = tmp_path / "config.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
raw = persistence.read_config(str(config_file))
|
||||
|
||||
with pytest.raises(ValidationError, match="seed_prompt"):
|
||||
OptimizationConfig.model_validate(raw)
|
||||
|
||||
def test_missing_task_description_raises_validation_error(self, tmp_path: Path) -> None:
|
||||
persistence = YamlPersistence()
|
||||
data = {"seed_prompt": "You are helpful."}
|
||||
config_file = tmp_path / "config.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
raw = persistence.read_config(str(config_file))
|
||||
|
||||
with pytest.raises(ValidationError, match="task_description"):
|
||||
OptimizationConfig.model_validate(raw)
|
||||
|
||||
def test_empty_yaml_raises_on_required_fields(self, tmp_path: Path) -> None:
|
||||
persistence = YamlPersistence()
|
||||
config_file = tmp_path / "empty.yaml"
|
||||
config_file.write_text("{}", encoding="utf-8")
|
||||
|
||||
raw = persistence.read_config(str(config_file))
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
OptimizationConfig.model_validate(raw)
|
||||
|
||||
def test_partial_config_uses_defaults(self, tmp_path: Path) -> None:
|
||||
persistence = YamlPersistence()
|
||||
data = {
|
||||
"seed_prompt": "test",
|
||||
"task_description": "task",
|
||||
"max_iterations": 10,
|
||||
}
|
||||
config_file = tmp_path / "partial.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
raw = persistence.read_config(str(config_file))
|
||||
config = OptimizationConfig.model_validate(raw)
|
||||
|
||||
assert config.max_iterations == 10
|
||||
assert config.n_synthetic_inputs == 20 # default
|
||||
assert config.minibatch_size == 5 # default
|
||||
|
||||
|
||||
class TestConfigValidation:
|
||||
"""Tests for Pydantic validation edge cases."""
|
||||
|
||||
def test_wrong_type_max_iterations(self) -> None:
|
||||
with pytest.raises(ValidationError, match="max_iterations"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a",
|
||||
task_description="b",
|
||||
max_iterations="not_a_number", # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def test_wrong_type_seed(self) -> None:
|
||||
with pytest.raises(ValidationError, match="seed"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a",
|
||||
task_description="b",
|
||||
seed="not_a_number", # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def test_negative_max_iterations(self) -> None:
|
||||
with pytest.raises(ValidationError, match="greater than or equal to 1"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", max_iterations=0,
|
||||
)
|
||||
|
||||
def test_negative_seed(self) -> None:
|
||||
with pytest.raises(ValidationError, match="greater than or equal to 0"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", seed=-1,
|
||||
)
|
||||
|
||||
def test_zero_minibatch_size(self) -> None:
|
||||
with pytest.raises(ValidationError, match="greater than or equal to 1"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", minibatch_size=0,
|
||||
)
|
||||
|
||||
def test_zero_max_concurrency(self) -> None:
|
||||
with pytest.raises(ValidationError, match="greater than or equal to 1"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", max_concurrency=0,
|
||||
)
|
||||
|
||||
def test_negative_max_retries(self) -> None:
|
||||
with pytest.raises(ValidationError, match="greater than or equal to 0"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", max_retries=-1,
|
||||
)
|
||||
|
||||
def test_zero_retry_delay_base(self) -> None:
|
||||
with pytest.raises(ValidationError, match="greater than 0"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", retry_delay_base=0,
|
||||
)
|
||||
|
||||
def test_negative_retry_delay_base(self) -> None:
|
||||
with pytest.raises(ValidationError, match="greater than 0"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", retry_delay_base=-0.5,
|
||||
)
|
||||
|
||||
def test_zero_circuit_breaker_threshold(self) -> None:
|
||||
with pytest.raises(ValidationError, match="greater than or equal to 1"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", circuit_breaker_threshold=0,
|
||||
)
|
||||
|
||||
def test_perfect_score_above_one(self) -> None:
|
||||
with pytest.raises(ValidationError, match="less than or equal to 1"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", perfect_score=1.5,
|
||||
)
|
||||
|
||||
def test_perfect_score_negative(self) -> None:
|
||||
with pytest.raises(ValidationError, match="greater than or equal to 0"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", perfect_score=-0.1,
|
||||
)
|
||||
|
||||
def test_invalid_error_strategy(self) -> None:
|
||||
with pytest.raises(ValidationError, match="error_strategy must be one of"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", error_strategy="invalid",
|
||||
)
|
||||
|
||||
def test_valid_error_strategies(self) -> None:
|
||||
for strategy in ("skip", "retry", "abort"):
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", error_strategy=strategy,
|
||||
)
|
||||
assert config.error_strategy == strategy
|
||||
|
||||
def test_empty_seed_prompt(self) -> None:
|
||||
with pytest.raises(ValidationError, match="seed_prompt"):
|
||||
OptimizationConfig(seed_prompt="", task_description="b")
|
||||
|
||||
def test_empty_task_description(self) -> None:
|
||||
with pytest.raises(ValidationError, match="task_description"):
|
||||
OptimizationConfig(seed_prompt="a", task_description="")
|
||||
|
||||
def test_empty_model_string(self) -> None:
|
||||
with pytest.raises(ValidationError, match="task_model"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", task_model="",
|
||||
)
|
||||
|
||||
def test_extra_fields_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="extra"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a",
|
||||
task_description="b",
|
||||
nonexistent_field="value", # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
def test_boundary_values_accepted(self) -> None:
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="a",
|
||||
task_description="b",
|
||||
max_iterations=1,
|
||||
n_synthetic_inputs=1,
|
||||
minibatch_size=1,
|
||||
max_concurrency=1,
|
||||
max_retries=0,
|
||||
retry_delay_base=0.001,
|
||||
circuit_breaker_threshold=1,
|
||||
perfect_score=0.0,
|
||||
seed=0,
|
||||
)
|
||||
assert config.max_iterations == 1
|
||||
assert config.perfect_score == 0.0
|
||||
|
||||
|
||||
class TestEvalConfigValidation:
|
||||
"""Tests for ground-truth evaluation config fields."""
|
||||
|
||||
def test_eval_defaults(self) -> None:
|
||||
config = OptimizationConfig(seed_prompt="a", task_description="b")
|
||||
assert config.eval_dataset_path is None
|
||||
assert config.eval_metric == "bleu"
|
||||
|
||||
def test_eval_dataset_path_set(self) -> None:
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="a", task_description="b",
|
||||
eval_dataset_path="data.csv",
|
||||
)
|
||||
assert config.eval_dataset_path == "data.csv"
|
||||
|
||||
def test_valid_eval_metrics(self) -> None:
|
||||
for metric in ("exact", "bleu", "rouge_l", "cosine", "llm_judge"):
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", eval_metric=metric,
|
||||
)
|
||||
assert config.eval_metric == metric
|
||||
|
||||
def test_invalid_eval_metric_raises(self) -> None:
|
||||
with pytest.raises(ValidationError, match="eval_metric must be one of"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b",
|
||||
eval_metric="invalid_metric",
|
||||
)
|
||||
86
tests/unit/test_dataset_loader.py
Normal file
86
tests/unit/test_dataset_loader.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Tests for the ground-truth dataset loader."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.domain.entities import GroundTruthExample
|
||||
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader():
|
||||
return FileDatasetLoader()
|
||||
|
||||
|
||||
class TestCsvLoader:
|
||||
def test_load_csv(self, loader, tmp_path):
|
||||
csv_file = tmp_path / "test.csv"
|
||||
csv_file.write_text("input,expected_output\nhello,world\nfoo,bar\n")
|
||||
result = loader.load(str(csv_file))
|
||||
assert len(result) == 2
|
||||
assert result[0].input_text == "hello"
|
||||
assert result[0].expected_output == "world"
|
||||
assert result[1].input_text == "foo"
|
||||
assert result[1].expected_output == "bar"
|
||||
|
||||
def test_load_csv_skips_empty_input(self, loader, tmp_path):
|
||||
csv_file = tmp_path / "test.csv"
|
||||
csv_file.write_text("input,expected_output\n,bar\nhello,world\n")
|
||||
result = loader.load(str(csv_file))
|
||||
assert len(result) == 1
|
||||
assert result[0].input_text == "hello"
|
||||
|
||||
def test_load_csv_with_whitespace(self, loader, tmp_path):
|
||||
csv_file = tmp_path / "test.csv"
|
||||
csv_file.write_text("input,expected_output\n hello , world \n")
|
||||
result = loader.load(str(csv_file))
|
||||
assert result[0].input_text == "hello"
|
||||
assert result[0].expected_output == "world"
|
||||
|
||||
def test_load_csv_empty_file(self, loader, tmp_path):
|
||||
csv_file = tmp_path / "test.csv"
|
||||
csv_file.write_text("input,expected_output\n")
|
||||
result = loader.load(str(csv_file))
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestJsonLoader:
|
||||
def test_load_json(self, loader, tmp_path):
|
||||
json_file = tmp_path / "test.json"
|
||||
data = [
|
||||
{"input": "hello", "expected_output": "world"},
|
||||
{"input": "foo", "expected_output": "bar"},
|
||||
]
|
||||
json_file.write_text(json.dumps(data))
|
||||
result = loader.load(str(json_file))
|
||||
assert len(result) == 2
|
||||
assert result[0].input_text == "hello"
|
||||
assert result[0].expected_output == "world"
|
||||
|
||||
def test_load_json_skips_empty_input(self, loader, tmp_path):
|
||||
json_file = tmp_path / "test.json"
|
||||
data = [
|
||||
{"input": "", "expected_output": "bar"},
|
||||
{"input": "hello", "expected_output": "world"},
|
||||
]
|
||||
json_file.write_text(json.dumps(data))
|
||||
result = loader.load(str(json_file))
|
||||
assert len(result) == 1
|
||||
|
||||
def test_load_json_not_array_raises(self, loader, tmp_path):
|
||||
json_file = tmp_path / "test.json"
|
||||
json_file.write_text(json.dumps({"not": "an array"}))
|
||||
with pytest.raises(ValueError, match="must be an array"):
|
||||
loader.load(str(json_file))
|
||||
|
||||
|
||||
class TestUnsupportedFormat:
|
||||
def test_unsupported_extension_raises(self, loader, tmp_path):
|
||||
txt_file = tmp_path / "test.txt"
|
||||
txt_file.write_text("hello")
|
||||
with pytest.raises(ValueError, match="Unsupported dataset format"):
|
||||
loader.load(str(txt_file))
|
||||
313
tests/unit/test_error_handling.py
Normal file
313
tests/unit/test_error_handling.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""Unit tests for error handling: retry, circuit breaker, per-call isolation."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import CircuitBreakerOpen, EvolutionLoop
|
||||
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory
|
||||
from prometheus.infrastructure.retry import is_transient_error, retry_with_backoff
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Retry utility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsTransientError:
|
||||
def test_rate_limit_429(self):
|
||||
assert is_transient_error(RuntimeError("HTTP 429: rate limit exceeded"))
|
||||
|
||||
def test_server_error_503(self):
|
||||
assert is_transient_error(RuntimeError("503 Service Unavailable"))
|
||||
|
||||
def test_timeout(self):
|
||||
assert is_transient_error(TimeoutError("request timed out"))
|
||||
|
||||
def test_connection_error(self):
|
||||
assert is_transient_error(ConnectionError("connection refused"))
|
||||
|
||||
def test_non_transient(self):
|
||||
assert not is_transient_error(ValueError("bad input"))
|
||||
|
||||
def test_os_error(self):
|
||||
assert is_transient_error(OSError("network unreachable"))
|
||||
|
||||
|
||||
class TestRetryWithBackoff:
|
||||
def test_succeeds_on_first_try(self):
|
||||
fn = MagicMock(return_value="ok")
|
||||
result = retry_with_backoff(fn, max_retries=3, retry_delay_base=0)
|
||||
assert result == "ok"
|
||||
assert fn.call_count == 1
|
||||
|
||||
def test_retries_on_transient_then_succeeds(self):
|
||||
fn = MagicMock(
|
||||
side_effect=[
|
||||
RuntimeError("429 rate limit"),
|
||||
RuntimeError("429 rate limit"),
|
||||
"ok",
|
||||
]
|
||||
)
|
||||
with patch("prometheus.infrastructure.retry.time.sleep"):
|
||||
result = retry_with_backoff(fn, max_retries=3, retry_delay_base=0)
|
||||
assert result == "ok"
|
||||
assert fn.call_count == 3
|
||||
|
||||
def test_raises_after_max_retries(self):
|
||||
fn = MagicMock(side_effect=RuntimeError("503 overloaded"))
|
||||
with patch("prometheus.infrastructure.retry.time.sleep"):
|
||||
with pytest.raises(RuntimeError, match="503"):
|
||||
retry_with_backoff(fn, max_retries=2, retry_delay_base=0)
|
||||
assert fn.call_count == 3 # 1 initial + 2 retries
|
||||
|
||||
def test_non_transient_not_retried(self):
|
||||
fn = MagicMock(side_effect=ValueError("bad"))
|
||||
with pytest.raises(ValueError, match="bad"):
|
||||
retry_with_backoff(fn, max_retries=3, retry_delay_base=0)
|
||||
assert fn.call_count == 1
|
||||
|
||||
def test_exponential_backoff_delays(self):
|
||||
fn = MagicMock(side_effect=[RuntimeError("timeout"), "ok"])
|
||||
with patch("prometheus.infrastructure.retry.time.sleep") as mock_sleep:
|
||||
retry_with_backoff(fn, max_retries=3, retry_delay_base=2.0)
|
||||
mock_sleep.assert_called_once_with(2.0) # 2.0 * 2^0 = 2.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Circuit breaker (EvolutionLoop)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_eval_result(scores, feedbacks=None):
|
||||
"""Helper to create EvalResult with matching trajectories."""
|
||||
feedbacks = feedbacks or ["ok"] * len(scores)
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=feedbacks,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, f, "prompt")
|
||||
for i, (s, f) in enumerate(zip(scores, feedbacks))
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
@pytest.mark.asyncio
|
||||
async def test_trips_on_consecutive_failures(self):
|
||||
"""Loop stops when consecutive failures reach the threshold."""
|
||||
initial_eval = _make_eval_result([0.3, 0.4])
|
||||
evaluator = MagicMock()
|
||||
call_count = 0
|
||||
|
||||
def _evaluate(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return initial_eval # seed eval succeeds
|
||||
raise RuntimeError("LLM down")
|
||||
|
||||
evaluator.evaluate = AsyncMock(side_effect=_evaluate)
|
||||
proposer = MagicMock()
|
||||
proposer.propose = AsyncMock()
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=10,
|
||||
minibatch_size=2,
|
||||
circuit_breaker_threshold=3,
|
||||
error_strategy="skip",
|
||||
)
|
||||
state = await loop.run(
|
||||
Prompt("test"),
|
||||
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||
"task",
|
||||
)
|
||||
|
||||
error_events = [h for h in state.history if h.get("event") == "error"]
|
||||
cb_events = [h for h in state.history if h.get("event") == "circuit_breaker"]
|
||||
assert len(error_events) == 3
|
||||
assert len(cb_events) == 1
|
||||
assert state.iteration < 10 # stopped early
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_raises_on_first_error(self):
|
||||
"""With error_strategy=abort, the first error raises immediately."""
|
||||
initial_eval = _make_eval_result([0.3, 0.4])
|
||||
evaluator = MagicMock()
|
||||
call_count = 0
|
||||
|
||||
def _evaluate(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return initial_eval
|
||||
raise RuntimeError("LLM down")
|
||||
|
||||
evaluator.evaluate = AsyncMock(side_effect=_evaluate)
|
||||
proposer = MagicMock()
|
||||
proposer.propose = AsyncMock()
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=10,
|
||||
minibatch_size=2,
|
||||
circuit_breaker_threshold=3,
|
||||
error_strategy="abort",
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="LLM down"):
|
||||
await loop.run(
|
||||
Prompt("test"),
|
||||
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||
"task",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resets_on_success(self):
|
||||
"""Consecutive failure counter resets after a successful iteration."""
|
||||
initial_eval = _make_eval_result([0.3, 0.4])
|
||||
good_eval = _make_eval_result([0.8, 0.9])
|
||||
|
||||
evaluator = MagicMock()
|
||||
call_count = 0
|
||||
|
||||
def _evaluate(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# call 1: seed eval → succeed
|
||||
# call 2: iter 1 current eval → fail
|
||||
# call 3: iter 2 current eval → fail
|
||||
# call 4: iter 3 current eval → succeed (returns initial_eval)
|
||||
# call 5: iter 3 new eval → succeed (returns good_eval, accepted)
|
||||
# call 6+: iter 4+ current eval → succeed
|
||||
if call_count == 1:
|
||||
return initial_eval
|
||||
if call_count in (2, 3):
|
||||
raise RuntimeError("timeout")
|
||||
if call_count % 2 == 0:
|
||||
return initial_eval # current eval
|
||||
return good_eval # new eval
|
||||
|
||||
evaluator.evaluate = AsyncMock(side_effect=_evaluate)
|
||||
proposer = MagicMock()
|
||||
proposer.propose = AsyncMock(return_value=Prompt("better prompt"))
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = [
|
||||
SyntheticExample(f"in{i}", id=i) for i in range(2)
|
||||
]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=6,
|
||||
minibatch_size=2,
|
||||
circuit_breaker_threshold=3,
|
||||
error_strategy="skip",
|
||||
)
|
||||
state = await loop.run(
|
||||
Prompt("test"),
|
||||
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||
"task",
|
||||
)
|
||||
|
||||
# Should NOT have tripped — 2 fails, then success reset the counter
|
||||
cb_events = [h for h in state.history if h.get("event") == "circuit_breaker"]
|
||||
assert len(cb_events) == 0
|
||||
assert state.iteration == 6 # ran all iterations
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-call isolation (Evaluator)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPerCallIsolation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_evaluator_isolates_execution_failure(self):
|
||||
"""A failing execution produces a sentinel output, not a crash."""
|
||||
executor = MagicMock()
|
||||
executor.execute = AsyncMock(side_effect=[
|
||||
"good output",
|
||||
RuntimeError("API error"),
|
||||
"another good output",
|
||||
])
|
||||
judge = MagicMock()
|
||||
judge.judge_batch = AsyncMock(return_value=[
|
||||
(0.8, "good"),
|
||||
(0.0, "[judge error]"),
|
||||
(0.7, "ok"),
|
||||
])
|
||||
|
||||
evaluator = PromptEvaluator(executor, judge)
|
||||
result = await evaluator.evaluate(
|
||||
Prompt("test"),
|
||||
[
|
||||
SyntheticExample("in0", id=0),
|
||||
SyntheticExample("in1", id=1),
|
||||
SyntheticExample("in2", id=2),
|
||||
],
|
||||
"task",
|
||||
)
|
||||
|
||||
assert len(result.scores) == 3
|
||||
assert result.scores[1] == 0.0 # failed item got zero score
|
||||
assert "execution error" in result.trajectories[1].output_text
|
||||
assert result.scores[0] == 0.8 # other items unaffected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_adapter_isolates_single_failure(self):
|
||||
"""DSPyJudgeAdapter returns sentinel for a failed item, not crash."""
|
||||
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||
|
||||
adapter = DSPyJudgeAdapter.__new__(DSPyJudgeAdapter)
|
||||
adapter._lm = MagicMock()
|
||||
adapter._max_retries = 1
|
||||
adapter._retry_delay_base = 0
|
||||
adapter._semaphore = __import__("asyncio").Semaphore(5)
|
||||
adapter._judge_criteria = ""
|
||||
adapter._judge_dimensions = []
|
||||
adapter._dimension_names = ""
|
||||
adapter._weights = {}
|
||||
adapter.call_count = 0
|
||||
|
||||
# Mock _judge to fail on first call, succeed on second
|
||||
call_count = 0
|
||||
|
||||
class FakePred:
|
||||
def __init__(self):
|
||||
self.score = 0.9
|
||||
self.feedback = "good"
|
||||
|
||||
def fake_judge(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise RuntimeError("judge failure")
|
||||
return FakePred()
|
||||
|
||||
adapter._judge = fake_judge
|
||||
|
||||
with patch("prometheus.infrastructure.judge_adapter.dspy.context"):
|
||||
with patch(
|
||||
"prometheus.infrastructure.retry.asyncio.sleep",
|
||||
new=AsyncMock(),
|
||||
):
|
||||
results = await adapter.judge_batch(
|
||||
"task", [("input1", "output1"), ("input2", "output2")]
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
# First item failed even after retry → sentinel
|
||||
assert results[0] == (0.0, "[judge error: judge failure]")
|
||||
# Second item succeeded
|
||||
assert results[1] == (0.9, "good")
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Unit tests for PromptEvaluator.evaluate()."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -14,22 +14,23 @@ class TestPromptEvaluatorEvaluate:
|
||||
"""Tests for the evaluate() pipeline: execute → judge → trajectories."""
|
||||
|
||||
@pytest.fixture
|
||||
def executor(self) -> MagicMock:
|
||||
return MagicMock(spec=LLMPort)
|
||||
def executor(self) -> AsyncMock:
|
||||
return AsyncMock(spec=LLMPort)
|
||||
|
||||
@pytest.fixture
|
||||
def judge(self) -> MagicMock:
|
||||
return MagicMock(spec=JudgePort)
|
||||
def judge(self) -> AsyncMock:
|
||||
return AsyncMock(spec=JudgePort)
|
||||
|
||||
@pytest.fixture
|
||||
def evaluator(self, executor: MagicMock, judge: MagicMock) -> PromptEvaluator:
|
||||
def evaluator(self, executor: AsyncMock, judge: AsyncMock) -> PromptEvaluator:
|
||||
return PromptEvaluator(executor=executor, judge=judge)
|
||||
|
||||
def test_happy_path_builds_correct_trajectories(
|
||||
@pytest.mark.asyncio
|
||||
async def test_happy_path_builds_correct_trajectories(
|
||||
self,
|
||||
evaluator: PromptEvaluator,
|
||||
executor: MagicMock,
|
||||
judge: MagicMock,
|
||||
executor: AsyncMock,
|
||||
judge: AsyncMock,
|
||||
) -> None:
|
||||
prompt = Prompt(text="Answer the question.")
|
||||
examples = [
|
||||
@@ -42,7 +43,7 @@ class TestPromptEvaluatorEvaluate:
|
||||
(0.8, "Mostly correct."),
|
||||
]
|
||||
|
||||
result = evaluator.evaluate(prompt, examples, "math and geography")
|
||||
result = await evaluator.evaluate(prompt, examples, "math and geography")
|
||||
|
||||
assert isinstance(result, EvalResult)
|
||||
assert result.scores == [0.9, 0.8]
|
||||
@@ -55,14 +56,15 @@ class TestPromptEvaluatorEvaluate:
|
||||
assert result.trajectories[0].prompt_used == "Answer the question."
|
||||
assert result.trajectories[1].prompt_used == "Answer the question."
|
||||
|
||||
def test_empty_minibatch_returns_empty_result(
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_minibatch_returns_empty_result(
|
||||
self,
|
||||
evaluator: PromptEvaluator,
|
||||
executor: MagicMock,
|
||||
judge: MagicMock,
|
||||
executor: AsyncMock,
|
||||
judge: AsyncMock,
|
||||
) -> None:
|
||||
prompt = Prompt(text="test")
|
||||
result = evaluator.evaluate(prompt, [], "task")
|
||||
result = await evaluator.evaluate(prompt, [], "task")
|
||||
|
||||
assert result.scores == []
|
||||
assert result.feedbacks == []
|
||||
@@ -71,41 +73,44 @@ class TestPromptEvaluatorEvaluate:
|
||||
# judge_batch is called with empty pairs list
|
||||
judge.judge_batch.assert_called_once_with("task", [])
|
||||
|
||||
def test_executor_called_with_correct_prompt(
|
||||
@pytest.mark.asyncio
|
||||
async def test_executor_called_with_correct_prompt(
|
||||
self,
|
||||
evaluator: PromptEvaluator,
|
||||
executor: MagicMock,
|
||||
judge: MagicMock,
|
||||
executor: AsyncMock,
|
||||
judge: AsyncMock,
|
||||
) -> None:
|
||||
prompt = Prompt(text="Summarize this.")
|
||||
examples = [SyntheticExample(input_text="Long text here", id=0)]
|
||||
executor.execute.return_value = "Summary."
|
||||
judge.judge_batch.return_value = [(0.7, "Good summary.")]
|
||||
|
||||
evaluator.evaluate(prompt, examples, "summarization")
|
||||
await evaluator.evaluate(prompt, examples, "summarization")
|
||||
|
||||
executor.execute.assert_called_once_with(prompt, "Long text here")
|
||||
|
||||
def test_trajectories_prompt_used_matches_input_prompt(
|
||||
@pytest.mark.asyncio
|
||||
async def test_trajectories_prompt_used_matches_input_prompt(
|
||||
self,
|
||||
evaluator: PromptEvaluator,
|
||||
executor: MagicMock,
|
||||
judge: MagicMock,
|
||||
executor: AsyncMock,
|
||||
judge: AsyncMock,
|
||||
) -> None:
|
||||
prompt = Prompt(text="Translate to French.")
|
||||
examples = [SyntheticExample(input_text="Hello", id=0)]
|
||||
executor.execute.return_value = "Bonjour"
|
||||
judge.judge_batch.return_value = [(1.0, "Perfect.")]
|
||||
|
||||
result = evaluator.evaluate(prompt, examples, "translation")
|
||||
result = await evaluator.evaluate(prompt, examples, "translation")
|
||||
|
||||
assert result.trajectories[0].prompt_used == "Translate to French."
|
||||
|
||||
def test_scores_feedbacks_trajectories_lists_sized_correctly(
|
||||
@pytest.mark.asyncio
|
||||
async def test_scores_feedbacks_trajectories_lists_sized_correctly(
|
||||
self,
|
||||
evaluator: PromptEvaluator,
|
||||
executor: MagicMock,
|
||||
judge: MagicMock,
|
||||
executor: AsyncMock,
|
||||
judge: AsyncMock,
|
||||
) -> None:
|
||||
prompt = Prompt(text="test prompt")
|
||||
examples = [SyntheticExample(input_text=f"q{i}", id=i) for i in range(4)]
|
||||
@@ -114,7 +119,7 @@ class TestPromptEvaluatorEvaluate:
|
||||
(0.1 * i, f"fb{i}") for i in range(4)
|
||||
]
|
||||
|
||||
result = evaluator.evaluate(prompt, examples, "task")
|
||||
result = await evaluator.evaluate(prompt, examples, "task")
|
||||
|
||||
assert len(result.scores) == 4
|
||||
assert len(result.feedbacks) == 4
|
||||
|
||||
@@ -1,51 +1,55 @@
|
||||
"""Unit tests for the evolution loop — with full mocking."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
EvalResult,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
Trajectory,
|
||||
)
|
||||
|
||||
|
||||
def _make_eval(scores: list[float], label: str = "ok") -> EvalResult:
|
||||
"""Helper to build an EvalResult from a list of scores."""
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=[label] * len(scores),
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, label, "prompt")
|
||||
for i, s in enumerate(scores)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestEvolutionLoop:
|
||||
def test_accepts_improvement(
|
||||
"""Tests for the original single-candidate hill-climbing mode (population_size=1)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accepts_improvement(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: MagicMock,
|
||||
mock_judge_port: MagicMock,
|
||||
mock_proposer_port: MagicMock,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""When the new prompt improves the score, the best candidate is updated."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
initial_eval = EvalResult(
|
||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||
feedbacks=["bad"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, "bad", "prompt")
|
||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||
],
|
||||
)
|
||||
old_eval = EvalResult(
|
||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||
feedbacks=["bad"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, "bad", "prompt")
|
||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||
],
|
||||
)
|
||||
new_eval = EvalResult(
|
||||
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
evaluator.evaluate = MagicMock(side_effect=[initial_eval, old_eval, new_eval])
|
||||
low_eval = _make_eval([0.3, 0.4, 0.3, 0.5, 0.2], "bad")
|
||||
high_eval = _make_eval([0.8, 0.9, 0.7, 0.8, 0.9], "good")
|
||||
evaluator.evaluate = AsyncMock(side_effect=[low_eval, low_eval, high_eval])
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
@@ -54,48 +58,29 @@ class TestEvolutionLoop:
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
)
|
||||
with patch.object(loop, "_log"):
|
||||
state = loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert state.best_candidate is not None
|
||||
assert state.best_candidate.best_score > 0
|
||||
|
||||
def test_rejects_regression(
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_regression(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: MagicMock,
|
||||
mock_judge_port: MagicMock,
|
||||
mock_proposer_port: MagicMock,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""When the new prompt degrades the score, the best candidate stays unchanged."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
initial_eval = EvalResult(
|
||||
scores=[0.7, 0.8, 0.7, 0.8, 0.9],
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, "ok", "prompt")
|
||||
for i, s in enumerate([0.7, 0.8, 0.7, 0.8, 0.9])
|
||||
],
|
||||
)
|
||||
old_eval = EvalResult(
|
||||
scores=[0.7, 0.8, 0.7, 0.8, 0.9],
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, "ok", "prompt")
|
||||
for i, s in enumerate([0.7, 0.8, 0.7, 0.8, 0.9])
|
||||
],
|
||||
)
|
||||
new_eval = EvalResult(
|
||||
scores=[0.2, 0.1, 0.3, 0.2, 0.1],
|
||||
feedbacks=["bad"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
evaluator.evaluate = MagicMock(side_effect=[initial_eval, old_eval, new_eval])
|
||||
high_eval = _make_eval([0.7, 0.8, 0.7, 0.8, 0.9], "ok")
|
||||
low_eval = _make_eval([0.2, 0.1, 0.3, 0.2, 0.1], "bad")
|
||||
evaluator.evaluate = AsyncMock(side_effect=[high_eval, high_eval, low_eval])
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
@@ -104,35 +89,28 @@ class TestEvolutionLoop:
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
)
|
||||
with patch.object(loop, "_log"):
|
||||
state = loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert state.best_candidate is not None
|
||||
assert state.best_candidate.prompt.text == seed_prompt.text
|
||||
|
||||
def test_skips_perfect_scores(
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_perfect_scores(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: MagicMock,
|
||||
mock_judge_port: MagicMock,
|
||||
mock_proposer_port: MagicMock,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""When all scores are perfect, no proposition is made."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
perfect_eval = EvalResult(
|
||||
scores=[1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
feedbacks=["perfect"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", 1.0, "perfect", "prompt")
|
||||
for i in range(5)
|
||||
],
|
||||
)
|
||||
evaluator.evaluate = MagicMock(return_value=perfect_eval)
|
||||
perfect_eval = _make_eval([1.0, 1.0, 1.0, 1.0, 1.0], "perfect")
|
||||
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
@@ -141,7 +119,226 @@ class TestEvolutionLoop:
|
||||
max_iterations=3,
|
||||
minibatch_size=5,
|
||||
)
|
||||
with patch.object(loop, "_log"):
|
||||
loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
mock_proposer_port.propose.assert_not_called()
|
||||
|
||||
|
||||
class TestPopulationEvolution:
|
||||
"""Tests for population-based evolution (population_size > 1)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_population_initialization(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
mock_mutation_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Population is initialized with the right number of candidates."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
evaluator.evaluate = AsyncMock(
|
||||
return_value=_make_eval([0.5] * 5, "ok")
|
||||
)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=0, # no iterations, just initialization
|
||||
minibatch_size=5,
|
||||
population_size=4,
|
||||
mutation_port=mock_mutation_port,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
# 1 seed + 3 mutations = 4 candidates
|
||||
assert len(state.candidates) == 4
|
||||
assert mock_mutation_port.mutate.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_population_initialization_uses_proposer_fallback(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""When no mutation_port is provided, population init falls back to proposer."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
evaluator.evaluate = AsyncMock(
|
||||
return_value=_make_eval([0.5] * 5, "ok")
|
||||
)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=0,
|
||||
minibatch_size=5,
|
||||
population_size=3,
|
||||
# mutation_port intentionally omitted
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert len(state.candidates) == 3
|
||||
assert mock_proposer_port.propose.call_count == 2 # 3-1 = 2 init mutations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_population_iteration_replaces_worst(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
mock_crossover_port: AsyncMock,
|
||||
mock_mutation_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Crossover child replaces worst candidate when its fitness is higher."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
# Sequence:
|
||||
# 1. Initial eval (seed)
|
||||
# 2. Population init: 3 mutation calls use proposer.propose(), NOT evaluator.evaluate
|
||||
# 3. Population iteration: crossover produces child → eval child
|
||||
# Only 2 evaluator.evaluate calls total
|
||||
seed_eval = _make_eval([0.5] * 5, "ok")
|
||||
# Crossover child eval - high score to beat worst
|
||||
child_eval = _make_eval([0.9, 0.9, 0.8, 0.9, 0.8], "great")
|
||||
|
||||
all_evals = [seed_eval, child_eval]
|
||||
evaluator.evaluate = AsyncMock(side_effect=all_evals)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
population_size=4,
|
||||
crossover_rate=1.0,
|
||||
crossover_port=mock_crossover_port,
|
||||
mutation_rate=0.0, # disable post-crossover mutation for determinism
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
accepted_events = [h for h in state.history if h.get("event") == "pop_accepted"]
|
||||
assert len(accepted_events) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_population_iteration_rejects_inferior_child(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
mock_crossover_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Inferior child is rejected and doesn't replace any candidate."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
seed_eval = _make_eval([0.8] * 5, "ok")
|
||||
# Crossover produces very LOW-scoring child
|
||||
child_eval = _make_eval([0.1] * 5, "terrible")
|
||||
|
||||
all_evals = [seed_eval, child_eval]
|
||||
evaluator.evaluate = AsyncMock(side_effect=all_evals)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
population_size=4,
|
||||
crossover_rate=1.0,
|
||||
crossover_port=mock_crossover_port,
|
||||
mutation_rate=0.0,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
rejected_events = [h for h in state.history if h.get("event") == "pop_rejected"]
|
||||
assert len(rejected_events) >= 1
|
||||
|
||||
|
||||
class TestDiversityScore:
|
||||
"""Tests for the diversity/similarity scoring logic."""
|
||||
|
||||
def test_identical_prompts_have_high_similarity(self) -> None:
|
||||
"""Identical prompts should have very high similarity."""
|
||||
identical = Prompt(text="You are a helpful assistant. Answer the question.")
|
||||
pop_a = Candidate(prompt=identical, best_score=4.0, generation=0)
|
||||
pop_b = Candidate(
|
||||
prompt=Prompt(text="Completely different prompt about data analysis."),
|
||||
best_score=3.0,
|
||||
generation=0,
|
||||
)
|
||||
sim_same = EvolutionLoop._compute_diversity_score(identical, [pop_a, pop_b])
|
||||
# Average includes similarity to the different member, so ~0.5 not 0.9+
|
||||
assert sim_same > 0.3
|
||||
|
||||
def test_different_prompts_have_lower_similarity(self) -> None:
|
||||
"""Different prompts should have lower similarity than identical ones."""
|
||||
prompt_a = Prompt(text="You are a helpful assistant. Answer the question.")
|
||||
prompt_b = Prompt(text="Provide detailed analysis of complex data patterns with precision.")
|
||||
pop_a = Candidate(prompt=prompt_a, best_score=4.0, generation=0)
|
||||
pop_b = Candidate(prompt=prompt_b, best_score=3.0, generation=0)
|
||||
sim_a = EvolutionLoop._compute_diversity_score(prompt_a, [pop_a, pop_b])
|
||||
sim_b = EvolutionLoop._compute_diversity_score(prompt_b, [pop_a, pop_b])
|
||||
# Both should be < 1.0 since they're different
|
||||
assert sim_a < 1.0
|
||||
assert sim_b < 1.0
|
||||
|
||||
def test_single_member_population_returns_1(self) -> None:
|
||||
"""Single-member population always returns 1.0 (no penalty)."""
|
||||
prompt = Prompt(text="Any prompt text here.")
|
||||
pop = [Candidate(prompt=prompt, best_score=1.0, generation=0)]
|
||||
sim = EvolutionLoop._compute_diversity_score(prompt, pop)
|
||||
assert sim == 1.0
|
||||
|
||||
def test_empty_prompt_returns_zero(self) -> None:
|
||||
"""Empty prompt text returns 0.0 when population has >1 member."""
|
||||
prompt = Prompt(text="")
|
||||
pop = [
|
||||
Candidate(prompt=Prompt(text="some text"), best_score=1.0, generation=0),
|
||||
Candidate(prompt=Prompt(text="other text"), best_score=2.0, generation=0),
|
||||
]
|
||||
sim = EvolutionLoop._compute_diversity_score(prompt, pop)
|
||||
assert sim == 0.0
|
||||
|
||||
|
||||
class TestPromptDiff:
|
||||
"""Tests for the static _compute_prompt_diff helper."""
|
||||
|
||||
def test_identical_prompts(self) -> None:
|
||||
result = EvolutionLoop._compute_prompt_diff("hello\nworld", "hello\nworld")
|
||||
assert result["lines_added"] == 0
|
||||
assert result["lines_removed"] == 0
|
||||
assert result["chars_delta"] == 0
|
||||
|
||||
def test_added_lines(self) -> None:
|
||||
result = EvolutionLoop._compute_prompt_diff("hello", "hello\nworld")
|
||||
assert result["lines_added"] == 1
|
||||
assert result["lines_removed"] == 0
|
||||
assert result["chars_delta"] == 6 # "\nworld"
|
||||
|
||||
def test_removed_lines(self) -> None:
|
||||
result = EvolutionLoop._compute_prompt_diff("hello\nworld", "hello")
|
||||
assert result["lines_added"] == 0
|
||||
assert result["lines_removed"] == 1
|
||||
|
||||
133
tests/unit/test_ground_truth_evaluator.py
Normal file
133
tests/unit/test_ground_truth_evaluator.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tests for GroundTruthEvaluator — execution + similarity comparison."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.ground_truth_evaluator import GroundTruthEvaluator
|
||||
from prometheus.domain.entities import EvalResult, GroundTruthExample, Prompt
|
||||
from prometheus.domain.ports import LLMPort, SimilarityPort
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_executor() -> AsyncMock:
|
||||
port = AsyncMock(spec=LLMPort)
|
||||
port.execute.return_value = "Paris is the capital of France."
|
||||
return port
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_similarity() -> AsyncMock:
|
||||
port = AsyncMock(spec=SimilarityPort)
|
||||
port.compute.return_value = 0.85
|
||||
return port
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gt_dataset() -> list[GroundTruthExample]:
|
||||
return [
|
||||
GroundTruthExample(input_text="What is the capital of France?", expected_output="Paris", id=0),
|
||||
GroundTruthExample(input_text="What is 2+2?", expected_output="4", id=1),
|
||||
GroundTruthExample(input_text="What color is the sky?", expected_output="blue", id=2),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt() -> Prompt:
|
||||
return Prompt(text="Answer the following question accurately.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGroundTruthEvaluator:
|
||||
async def test_evaluate_happy_path(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor,
|
||||
similarity=mock_similarity,
|
||||
max_concurrency=2,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
|
||||
assert isinstance(result, EvalResult)
|
||||
assert len(result.scores) == 3
|
||||
assert len(result.feedbacks) == 3
|
||||
assert len(result.trajectories) == 3
|
||||
assert all(s == 0.85 for s in result.scores)
|
||||
assert result.mean_score == pytest.approx(0.85)
|
||||
assert result.total_score == pytest.approx(2.55)
|
||||
|
||||
async def test_executor_called_for_each_input(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=mock_similarity,
|
||||
)
|
||||
await evaluator.evaluate(prompt, gt_dataset)
|
||||
assert mock_executor.execute.call_count == 3
|
||||
|
||||
async def test_similarity_called_for_each_output(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=mock_similarity,
|
||||
)
|
||||
await evaluator.evaluate(prompt, gt_dataset)
|
||||
assert mock_similarity.compute.call_count == 3
|
||||
|
||||
async def test_execution_error_produces_zero_score(self, mock_similarity, gt_dataset, prompt):
|
||||
failing_executor = AsyncMock(spec=LLMPort)
|
||||
failing_executor.execute.side_effect = RuntimeError("API timeout")
|
||||
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=failing_executor, similarity=mock_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
|
||||
assert len(result.scores) == 3
|
||||
# The similarity adapter is called with the error sentinel
|
||||
assert all(isinstance(s, float) for s in result.scores)
|
||||
assert all("[execution error:" in t.output_text for t in result.trajectories)
|
||||
|
||||
async def test_empty_dataset(self, mock_executor, mock_similarity, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=mock_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, [])
|
||||
assert result.scores == []
|
||||
assert result.mean_score == 0.0
|
||||
assert result.total_score == 0.0
|
||||
|
||||
async def test_trajectory_contains_prompt_used(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=mock_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
for t in result.trajectories:
|
||||
assert t.prompt_used == prompt.text
|
||||
|
||||
async def test_scores_clamped_to_unit_range(self, mock_executor, gt_dataset, prompt):
|
||||
# Similarity returns a value > 1.0 (should be clamped)
|
||||
over_similarity = AsyncMock(spec=SimilarityPort)
|
||||
over_similarity.compute.return_value = 1.5
|
||||
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=over_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
assert all(0.0 <= s <= 1.0 for s in result.scores)
|
||||
|
||||
async def test_feedback_for_exact_match(self, mock_executor, gt_dataset, prompt):
|
||||
exact_similarity = AsyncMock(spec=SimilarityPort)
|
||||
exact_similarity.compute.return_value = 1.0
|
||||
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=exact_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
assert all("Exact match" in fb for fb in result.feedbacks)
|
||||
|
||||
async def test_feedback_for_poor_match(self, mock_executor, gt_dataset, prompt):
|
||||
poor_similarity = AsyncMock(spec=SimilarityPort)
|
||||
poor_similarity.compute.return_value = 0.1
|
||||
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=poor_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
assert all("Poor match" in fb for fb in result.feedbacks)
|
||||
316
tests/unit/test_holdout_validation.py
Normal file
316
tests/unit/test_holdout_validation.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Unit tests for hold-out validation and early stopping."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
EvalResult,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
Trajectory,
|
||||
)
|
||||
|
||||
|
||||
def _make_eval(mean_score: float, n: int = 5) -> EvalResult:
|
||||
"""Helper: create an EvalResult with a given mean score."""
|
||||
scores = [mean_score] * n
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=["feedback"] * n,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", mean_score, "feedback", "prompt")
|
||||
for i in range(n)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestBootstrapSplit:
|
||||
"""Tests for SyntheticBootstrap.split_pool."""
|
||||
|
||||
def test_split_produces_correct_sizes(self):
|
||||
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(20)]
|
||||
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
|
||||
assert len(train) + len(val) == 20
|
||||
assert len(val) == 6 # 20 * 0.3 = 6
|
||||
assert len(train) == 14
|
||||
|
||||
def test_split_zero_fraction_returns_all_train(self):
|
||||
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(10)]
|
||||
train, val = SyntheticBootstrap.split_pool(pool, 0.0)
|
||||
assert len(train) == 10
|
||||
assert len(val) == 0
|
||||
|
||||
def test_split_single_element(self):
|
||||
pool = [SyntheticExample(input_text="only", id=0)]
|
||||
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
|
||||
assert len(train) == 1
|
||||
assert len(val) == 0
|
||||
|
||||
def test_split_deterministic_with_seed(self):
|
||||
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(50)]
|
||||
train1, val1 = SyntheticBootstrap.split_pool(pool, 0.3, rng=MagicMock(wraps=__import__("random").Random(42)))
|
||||
train2, val2 = SyntheticBootstrap.split_pool(pool, 0.3, rng=MagicMock(wraps=__import__("random").Random(42)))
|
||||
assert [ex.id for ex in train1] == [ex.id for ex in train2]
|
||||
assert [ex.id for ex in val1] == [ex.id for ex in val2]
|
||||
|
||||
def test_split_no_overlap(self):
|
||||
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(30)]
|
||||
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
|
||||
train_ids = {ex.id for ex in train}
|
||||
val_ids = {ex.id for ex in val}
|
||||
assert train_ids.isdisjoint(val_ids)
|
||||
assert train_ids | val_ids == {ex.id for ex in pool}
|
||||
|
||||
|
||||
class TestValidationEvaluation:
|
||||
"""Tests for hold-out evaluation during evolution."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_pool_evaluated_after_each_iteration(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""When a validation pool is provided, the best candidate is evaluated on it."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
# Initial eval (train) + validation eval + iteration train eval + new prompt eval + validation eval
|
||||
train_eval = _make_eval(0.5)
|
||||
val_eval = _make_eval(0.6)
|
||||
new_eval = _make_eval(0.7)
|
||||
val_eval_2 = _make_eval(0.65)
|
||||
|
||||
evaluator.evaluate = AsyncMock(
|
||||
side_effect=[train_eval, val_eval, train_eval, new_eval, val_eval_2]
|
||||
)
|
||||
|
||||
validation_pool = synthetic_pool[-6:]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
validation_pool=validation_pool,
|
||||
)
|
||||
|
||||
# Should have validation metrics in state
|
||||
assert state.best_validation_score is not None
|
||||
# History should contain validation_eval entries
|
||||
val_events = [h for h in state.history if h["event"] == "validation_eval"]
|
||||
assert len(val_events) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_validation_without_pool(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Without a validation pool, no validation is performed."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
train_eval = _make_eval(0.5)
|
||||
old_eval = _make_eval(0.5)
|
||||
new_eval = _make_eval(0.7)
|
||||
evaluator.evaluate = AsyncMock(side_effect=[train_eval, old_eval, new_eval])
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert state.best_validation_score is None
|
||||
assert not state.early_stopped
|
||||
val_events = [h for h in state.history if h["event"] == "validation_eval"]
|
||||
assert len(val_events) == 0
|
||||
|
||||
|
||||
class TestEarlyStopping:
|
||||
"""Tests for early stopping when validation score degrades."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_early_stop_triggers_on_patience_exceeded(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Early stopping triggers when validation doesn't improve for K iterations."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
patience = 3
|
||||
# Build eval sequence:
|
||||
# 1. Initial train eval
|
||||
# 2. Initial validation eval (0.5)
|
||||
# Then for each of 3 iterations:
|
||||
# - train eval (current best)
|
||||
# - train eval (new prompt - accepted)
|
||||
# - validation eval (degrading)
|
||||
evals = [
|
||||
_make_eval(0.5), # initial train
|
||||
_make_eval(0.5), # initial validation
|
||||
]
|
||||
for i in range(patience):
|
||||
evals.extend([
|
||||
_make_eval(0.5 + i * 0.1), # current eval (train)
|
||||
_make_eval(0.6 + i * 0.1), # new eval (train) - accepted
|
||||
_make_eval(0.4), # validation eval (degrading)
|
||||
])
|
||||
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
|
||||
validation_pool = synthetic_pool[-5:]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=10, # would go further without early stop
|
||||
minibatch_size=5,
|
||||
early_stop_patience=patience,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
validation_pool=validation_pool,
|
||||
)
|
||||
|
||||
assert state.early_stopped is True
|
||||
assert state.iteration == patience
|
||||
assert state.best_validation_score is not None
|
||||
# Should have an early_stop event in history
|
||||
early_stop_events = [h for h in state.history if h["event"] == "early_stop"]
|
||||
assert len(early_stop_events) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_early_stop_does_not_trigger_when_improving(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""When validation keeps improving, early stopping does not trigger."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
evals = [
|
||||
_make_eval(0.3), # initial train
|
||||
_make_eval(0.3), # initial validation
|
||||
]
|
||||
# 3 iterations, each with improving validation
|
||||
for i in range(3):
|
||||
evals.extend([
|
||||
_make_eval(0.3 + i * 0.1), # current train eval
|
||||
_make_eval(0.4 + i * 0.1), # new train eval (accepted)
|
||||
_make_eval(0.3 + (i + 1) * 0.1), # validation eval (improving)
|
||||
])
|
||||
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
|
||||
validation_pool = synthetic_pool[-5:]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=5,
|
||||
early_stop_patience=5,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
validation_pool=validation_pool,
|
||||
)
|
||||
|
||||
assert state.early_stopped is False
|
||||
assert state.iteration == 3
|
||||
assert state.best_validation_score is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_patience_resets_on_improvement(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Patience counter resets when validation improves after degrading."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
evals = [
|
||||
_make_eval(0.3), # initial train
|
||||
_make_eval(0.3), # initial validation
|
||||
# iter 1: degrade
|
||||
_make_eval(0.3), # current train
|
||||
_make_eval(0.5), # new train (accepted)
|
||||
_make_eval(0.2), # validation degrade (patience=1)
|
||||
# iter 2: degrade
|
||||
_make_eval(0.5), # current train
|
||||
_make_eval(0.6), # new train (accepted)
|
||||
_make_eval(0.2), # validation degrade (patience=2)
|
||||
# iter 3: improve! (resets patience)
|
||||
_make_eval(0.6), # current train
|
||||
_make_eval(0.7), # new train (accepted)
|
||||
_make_eval(0.4), # validation improve (patience=0)
|
||||
# iter 4: degrade again
|
||||
_make_eval(0.7), # current train
|
||||
_make_eval(0.8), # new train (accepted)
|
||||
_make_eval(0.2), # validation degrade (patience=1)
|
||||
]
|
||||
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
validation_pool = synthetic_pool[-5:]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=4,
|
||||
minibatch_size=5,
|
||||
early_stop_patience=3,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
validation_pool=validation_pool,
|
||||
)
|
||||
|
||||
assert state.early_stopped is False
|
||||
assert state.iteration == 4
|
||||
189
tests/unit/test_logging.py
Normal file
189
tests/unit/test_logging.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Unit tests for structured logging configuration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from prometheus.cli.logging_setup import configure_logging, get_logger
|
||||
|
||||
|
||||
class TestConfigureLogging:
|
||||
def _count_handlers(self, name: str = "prometheus") -> int:
|
||||
return len(logging.getLogger(name).handlers)
|
||||
|
||||
def test_default_creates_console_handler(self) -> None:
|
||||
configure_logging(level=logging.INFO)
|
||||
prom = logging.getLogger("prometheus")
|
||||
assert len(prom.handlers) == 1
|
||||
assert isinstance(prom.handlers[0], logging.StreamHandler)
|
||||
prom.handlers.clear()
|
||||
|
||||
def test_json_format_produces_valid_json(self, capsys) -> None:
|
||||
configure_logging(level=logging.INFO, log_format="json")
|
||||
logger = get_logger("test_json")
|
||||
logger.info("hello", extra={"structured": {"key": "value"}})
|
||||
|
||||
captured = capsys.readouterr()
|
||||
# Output goes to stderr
|
||||
line = captured.err.strip().split("\n")[-1]
|
||||
data = json.loads(line)
|
||||
assert data["message"] == "hello"
|
||||
assert data["structured"]["key"] == "value"
|
||||
assert data["level"] == "INFO"
|
||||
assert "timestamp" in data
|
||||
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
def test_text_format_includes_structured_extras(self, capsys) -> None:
|
||||
configure_logging(level=logging.INFO, log_format="text")
|
||||
logger = get_logger("test_text")
|
||||
logger.info("msg", extra={"structured": {"foo": "bar"}})
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "foo=bar" in captured.err
|
||||
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
def test_debug_level_shows_debug_messages(self, capsys) -> None:
|
||||
configure_logging(level=logging.DEBUG)
|
||||
logger = get_logger("test_debug")
|
||||
logger.debug("debug msg")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "debug msg" in captured.err
|
||||
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
def test_warning_level_hides_debug_messages(self, capsys) -> None:
|
||||
configure_logging(level=logging.WARNING)
|
||||
logger = get_logger("test_warn")
|
||||
logger.debug("should not appear")
|
||||
logger.info("also hidden")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "should not appear" not in captured.err
|
||||
assert "also hidden" not in captured.err
|
||||
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
def test_file_handler_writes_to_file(self, tmp_path: Path) -> None:
|
||||
log_file = tmp_path / "test.log"
|
||||
configure_logging(level=logging.INFO, log_file=str(log_file))
|
||||
logger = get_logger("test_file")
|
||||
logger.info("file message")
|
||||
|
||||
prom = logging.getLogger("prometheus")
|
||||
# Flush handlers
|
||||
for h in prom.handlers:
|
||||
h.flush()
|
||||
prom.handlers.clear()
|
||||
|
||||
content = log_file.read_text()
|
||||
assert "file message" in content
|
||||
|
||||
def test_json_file_output(self, tmp_path: Path) -> None:
|
||||
log_file = tmp_path / "test.json.log"
|
||||
configure_logging(level=logging.INFO, log_format="json", log_file=str(log_file))
|
||||
logger = get_logger("test_json_file")
|
||||
logger.info("json file msg", extra={"structured": {"x": 1}})
|
||||
|
||||
prom = logging.getLogger("prometheus")
|
||||
for h in prom.handlers:
|
||||
h.flush()
|
||||
prom.handlers.clear()
|
||||
|
||||
content = log_file.read_text().strip()
|
||||
data = json.loads(content)
|
||||
assert data["message"] == "json file msg"
|
||||
assert data["structured"]["x"] == 1
|
||||
|
||||
def test_reconfigure_clears_old_handlers(self) -> None:
|
||||
configure_logging(level=logging.INFO)
|
||||
configure_logging(level=logging.DEBUG)
|
||||
prom = logging.getLogger("prometheus")
|
||||
assert len(prom.handlers) == 1
|
||||
prom.handlers.clear()
|
||||
|
||||
def test_propagate_false_prevents_duplicate_output(self, capsys) -> None:
|
||||
configure_logging(level=logging.INFO)
|
||||
prom = logging.getLogger("prometheus")
|
||||
assert prom.propagate is False
|
||||
prom.handlers.clear()
|
||||
|
||||
|
||||
class TestGetLogger:
|
||||
def test_returns_child_of_prometheus(self) -> None:
|
||||
logger = get_logger("mymodule")
|
||||
assert logger.name == "prometheus.mymodule"
|
||||
|
||||
def test_inherits_level_from_parent(self) -> None:
|
||||
configure_logging(level=logging.DEBUG)
|
||||
logger = get_logger("child")
|
||||
assert logger.getEffectiveLevel() <= logging.DEBUG
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
|
||||
class TestJsonFormatter:
|
||||
def test_exception_included(self, capsys) -> None:
|
||||
configure_logging(level=logging.ERROR, log_format="json")
|
||||
logger = get_logger("test_exc")
|
||||
try:
|
||||
raise ValueError("boom")
|
||||
except ValueError:
|
||||
logger.error("failed", exc_info=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
line = captured.err.strip().split("\n")[-1]
|
||||
data = json.loads(line)
|
||||
assert "ValueError: boom" in data["exception"]
|
||||
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
|
||||
class TestLoggingCLIIntegration:
|
||||
"""Tests for CLI flags that configure logging."""
|
||||
|
||||
def test_verbose_flag_enables_info(self, tmp_path: Path) -> None:
|
||||
"""Simulate what -v does — configure_logging at INFO level."""
|
||||
configure_logging(level=logging.INFO)
|
||||
logger = get_logger("evolution")
|
||||
logger.info("test message")
|
||||
|
||||
prom = logging.getLogger("prometheus")
|
||||
assert len(prom.handlers) == 1
|
||||
prom.handlers.clear()
|
||||
|
||||
def test_debug_flag_enables_debug(self) -> None:
|
||||
"""Simulate what --debug does — configure_logging at DEBUG level."""
|
||||
configure_logging(level=logging.DEBUG)
|
||||
logger = get_logger("evolution")
|
||||
logger.debug("debug message")
|
||||
|
||||
prom = logging.getLogger("prometheus")
|
||||
assert prom.level == logging.DEBUG
|
||||
prom.handlers.clear()
|
||||
|
||||
def test_log_format_invalid_rejected(self) -> None:
|
||||
"""Invalid log_format should be caught by OptimizationConfig validator."""
|
||||
from pydantic import ValidationError
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
|
||||
import pytest
|
||||
|
||||
with pytest.raises(ValidationError, match="log_format must be one of"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a",
|
||||
task_description="b",
|
||||
log_format="xml",
|
||||
)
|
||||
|
||||
def test_log_format_text_and_json_accepted(self) -> None:
|
||||
"""Both text and json log_format values should be valid."""
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
|
||||
for fmt in ("text", "json"):
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", log_format=fmt,
|
||||
)
|
||||
assert config.log_format == fmt
|
||||
96
tests/unit/test_scoring_extended.py
Normal file
96
tests/unit/test_scoring_extended.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Additional unit tests for scoring edge cases."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.domain.entities import EvalResult, Trajectory
|
||||
from prometheus.domain.scoring import normalize_score, should_accept
|
||||
|
||||
|
||||
def _make_eval(scores: list[float]) -> EvalResult:
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=[""] * len(scores),
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", s, "", "p")
|
||||
for i, s in enumerate(scores)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestShouldAcceptEdgeCases:
|
||||
"""Extended edge-case tests for should_accept."""
|
||||
|
||||
def test_tiny_improvement_accepted(self) -> None:
|
||||
old = _make_eval([0.5])
|
||||
new = _make_eval([0.5001])
|
||||
assert should_accept(old, new) is True
|
||||
|
||||
def test_tiny_improvement_below_threshold(self) -> None:
|
||||
old = _make_eval([0.5])
|
||||
new = _make_eval([0.5001])
|
||||
assert should_accept(old, new, min_improvement=0.01) is False
|
||||
|
||||
def test_zero_scores_equal(self) -> None:
|
||||
old = _make_eval([0.0, 0.0])
|
||||
new = _make_eval([0.0, 0.0])
|
||||
assert should_accept(old, new) is False
|
||||
|
||||
def test_negative_to_zero_not_accepted(self) -> None:
|
||||
"""Scores should be [0,1] but test should_accept with edge values."""
|
||||
old = _make_eval([-0.1])
|
||||
new = _make_eval([0.0])
|
||||
assert should_accept(old, new) is True
|
||||
|
||||
def test_large_improvement(self) -> None:
|
||||
old = _make_eval([0.0, 0.0, 0.0])
|
||||
new = _make_eval([1.0, 1.0, 1.0])
|
||||
assert should_accept(old, new) is True
|
||||
|
||||
def test_single_score_improvement(self) -> None:
|
||||
old = _make_eval([0.4])
|
||||
new = _make_eval([0.5])
|
||||
assert should_accept(old, new) is True
|
||||
|
||||
def test_min_improvement_exactly_met(self) -> None:
|
||||
"""When improvement exactly equals min_improvement, still rejected (strict >)."""
|
||||
old = _make_eval([0.5])
|
||||
new = _make_eval([0.7])
|
||||
assert should_accept(old, new, min_improvement=0.2) is False
|
||||
|
||||
def test_min_improvement_just_over(self) -> None:
|
||||
old = _make_eval([0.5])
|
||||
new = _make_eval([0.7001])
|
||||
assert should_accept(old, new, min_improvement=0.2) is True
|
||||
|
||||
|
||||
class TestNormalizeScoreEdgeCases:
|
||||
"""Extended edge-case tests for normalize_score."""
|
||||
|
||||
def test_exact_bounds(self) -> None:
|
||||
assert normalize_score(0.0) == 0.0
|
||||
assert normalize_score(1.0) == 1.0
|
||||
|
||||
def test_very_large_value(self) -> None:
|
||||
assert normalize_score(1e10) == 1.0
|
||||
|
||||
def test_very_negative_value(self) -> None:
|
||||
assert normalize_score(-1e10) == 0.0
|
||||
|
||||
def test_custom_bounds_at_edges(self) -> None:
|
||||
assert normalize_score(5.0, min_val=0.0, max_val=10.0) == 5.0
|
||||
assert normalize_score(0.0, min_val=0.0, max_val=10.0) == 0.0
|
||||
assert normalize_score(10.0, min_val=0.0, max_val=10.0) == 10.0
|
||||
|
||||
def test_negative_custom_range(self) -> None:
|
||||
assert normalize_score(0.0, min_val=-5.0, max_val=5.0) == 0.0
|
||||
assert normalize_score(-3.0, min_val=-5.0, max_val=5.0) == -3.0
|
||||
assert normalize_score(-10.0, min_val=-5.0, max_val=5.0) == -5.0
|
||||
|
||||
def test_zero_span_range(self) -> None:
|
||||
"""When min == max, clamps to min."""
|
||||
assert normalize_score(5.0, min_val=5.0, max_val=5.0) == 5.0
|
||||
assert normalize_score(0.0, min_val=5.0, max_val=5.0) == 5.0
|
||||
|
||||
def test_fractional_score(self) -> None:
|
||||
assert normalize_score(0.3333) == pytest.approx(0.3333)
|
||||
133
tests/unit/test_similarity.py
Normal file
133
tests/unit/test_similarity.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tests for similarity adapters — exact, BLEU, ROUGE-L, cosine."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.infrastructure.similarity import (
|
||||
BleuSimilarity,
|
||||
CosineSimilarity,
|
||||
ExactMatchSimilarity,
|
||||
RougeLSimilarity,
|
||||
create_similarity_adapter,
|
||||
)
|
||||
|
||||
|
||||
class TestExactMatchSimilarity:
|
||||
def test_exact_match(self):
|
||||
s = ExactMatchSimilarity()
|
||||
assert s.compute("Hello World", "Hello World") == 1.0
|
||||
|
||||
def test_case_insensitive(self):
|
||||
s = ExactMatchSimilarity()
|
||||
assert s.compute("hello world", "HELLO WORLD") == 1.0
|
||||
|
||||
def test_whitespace_trimmed(self):
|
||||
s = ExactMatchSimilarity()
|
||||
assert s.compute(" hello ", "hello") == 1.0
|
||||
|
||||
def test_no_match(self):
|
||||
s = ExactMatchSimilarity()
|
||||
assert s.compute("hello", "world") == 0.0
|
||||
|
||||
def test_partial_no_match(self):
|
||||
s = ExactMatchSimilarity()
|
||||
assert s.compute("hello world", "hello") == 0.0
|
||||
|
||||
|
||||
class TestBleuSimilarity:
|
||||
def test_perfect_match(self):
|
||||
s = BleuSimilarity()
|
||||
assert s.compute("the cat sat on the mat", "the cat sat on the mat") == 1.0
|
||||
|
||||
def test_no_overlap(self):
|
||||
s = BleuSimilarity()
|
||||
assert s.compute("aaa bbb ccc", "ddd eee fff") == 0.0
|
||||
|
||||
def test_partial_overlap(self):
|
||||
s = BleuSimilarity()
|
||||
score = s.compute("the cat sat", "the cat")
|
||||
assert 0.0 < score < 1.0
|
||||
|
||||
def test_empty_prediction(self):
|
||||
s = BleuSimilarity()
|
||||
assert s.compute("", "hello world") == 0.0
|
||||
|
||||
def test_empty_expected(self):
|
||||
s = BleuSimilarity()
|
||||
assert s.compute("hello world", "") == 0.0
|
||||
|
||||
def test_both_empty(self):
|
||||
s = BleuSimilarity()
|
||||
assert s.compute("", "") == 0.0
|
||||
|
||||
def test_shorter_prediction_gets_brevity_penalty(self):
|
||||
s = BleuSimilarity()
|
||||
short = s.compute("cat", "the cat sat on the mat")
|
||||
full = s.compute("the cat sat on the mat", "the cat sat on the mat")
|
||||
assert short < full
|
||||
|
||||
|
||||
class TestRougeLSimilarity:
|
||||
def test_perfect_match(self):
|
||||
s = RougeLSimilarity()
|
||||
assert s.compute("the cat sat", "the cat sat") == 1.0
|
||||
|
||||
def test_no_overlap(self):
|
||||
s = RougeLSimilarity()
|
||||
assert s.compute("aaa bbb", "ccc ddd") == 0.0
|
||||
|
||||
def test_partial_overlap(self):
|
||||
s = RougeLSimilarity()
|
||||
score = s.compute("the cat sat on the mat", "the cat on the rug")
|
||||
assert 0.0 < score < 1.0
|
||||
|
||||
def test_empty_prediction(self):
|
||||
s = RougeLSimilarity()
|
||||
assert s.compute("", "hello") == 0.0
|
||||
|
||||
def test_subsequence(self):
|
||||
s = RougeLSimilarity()
|
||||
# "cat mat" is a subsequence of "the cat sat on the mat"
|
||||
score = s.compute("cat mat", "the cat sat on the mat")
|
||||
assert score > 0.0
|
||||
|
||||
|
||||
class TestCosineSimilarity:
|
||||
def test_identical_texts(self):
|
||||
s = CosineSimilarity()
|
||||
assert s.compute("hello world", "hello world") == pytest.approx(1.0)
|
||||
|
||||
def test_no_overlap(self):
|
||||
s = CosineSimilarity()
|
||||
assert s.compute("aaa bbb", "ccc ddd") == 0.0
|
||||
|
||||
def test_partial_overlap(self):
|
||||
s = CosineSimilarity()
|
||||
score = s.compute("hello world foo", "hello world bar")
|
||||
assert 0.0 < score < 1.0
|
||||
|
||||
def test_empty_prediction(self):
|
||||
s = CosineSimilarity()
|
||||
assert s.compute("", "hello") == 0.0
|
||||
|
||||
|
||||
class TestCreateSimilarityAdapter:
|
||||
def test_create_exact(self):
|
||||
adapter = create_similarity_adapter("exact")
|
||||
assert isinstance(adapter, ExactMatchSimilarity)
|
||||
|
||||
def test_create_bleu(self):
|
||||
adapter = create_similarity_adapter("bleu")
|
||||
assert isinstance(adapter, BleuSimilarity)
|
||||
|
||||
def test_create_rouge_l(self):
|
||||
adapter = create_similarity_adapter("rouge_l")
|
||||
assert isinstance(adapter, RougeLSimilarity)
|
||||
|
||||
def test_create_cosine(self):
|
||||
adapter = create_similarity_adapter("cosine")
|
||||
assert isinstance(adapter, CosineSimilarity)
|
||||
|
||||
def test_unknown_metric_raises(self):
|
||||
with pytest.raises(ValueError, match="Unknown eval metric"):
|
||||
create_similarity_adapter("nonexistent")
|
||||
233
tests/unit/test_use_cases.py
Normal file
233
tests/unit/test_use_cases.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Unit tests for OptimizePromptUseCase — direct orchestration tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.dto import OptimizationConfig, OptimizationResult
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.application.use_cases import OptimizePromptUseCase
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
EvalResult,
|
||||
OptimizationState,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
Trajectory,
|
||||
)
|
||||
|
||||
|
||||
def _make_eval(scores: list[float]) -> EvalResult:
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=["feedback"] * len(scores),
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", s, "feedback", "prompt")
|
||||
for i, s in enumerate(scores)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _make_state(
|
||||
iterations: int = 3,
|
||||
initial_score: float = 0.3,
|
||||
final_score: float = 0.8,
|
||||
accepted: bool = True,
|
||||
) -> OptimizationState:
|
||||
seed = Candidate(prompt=Prompt(text="seed"), best_score=initial_score, generation=0)
|
||||
best = Candidate(
|
||||
prompt=Prompt(text="optimized" if accepted else "seed"),
|
||||
best_score=final_score,
|
||||
generation=iterations if accepted else 0,
|
||||
)
|
||||
history = []
|
||||
for i in range(1, iterations + 1):
|
||||
event = "accepted" if accepted else "rejected"
|
||||
history.append({"iteration": i, "event": event, "old_score": 0.3, "new_score": 0.8})
|
||||
|
||||
return OptimizationState(
|
||||
iteration=iterations,
|
||||
best_candidate=best,
|
||||
candidates=[seed, best] if accepted else [seed],
|
||||
total_llm_calls=iterations * 11 + 10,
|
||||
history=history,
|
||||
)
|
||||
|
||||
|
||||
class TestOptimizePromptUseCaseExecute:
|
||||
"""Tests for the execute() orchestration method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_evaluator(self) -> MagicMock:
|
||||
return MagicMock(spec=PromptEvaluator)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_proposer(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bootstrap(self) -> MagicMock:
|
||||
return MagicMock(spec=SyntheticBootstrap)
|
||||
|
||||
@pytest.fixture
|
||||
def use_case(
|
||||
self,
|
||||
mock_evaluator: MagicMock,
|
||||
mock_proposer: MagicMock,
|
||||
mock_bootstrap: MagicMock,
|
||||
) -> OptimizePromptUseCase:
|
||||
return OptimizePromptUseCase(
|
||||
evaluator=mock_evaluator,
|
||||
proposer=mock_proposer,
|
||||
bootstrap=mock_bootstrap,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> OptimizationConfig:
|
||||
return OptimizationConfig(
|
||||
seed_prompt="Answer the question.",
|
||||
task_description="Q&A task",
|
||||
max_iterations=5,
|
||||
n_synthetic_inputs=20,
|
||||
minibatch_size=5,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_optimization_result(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = [
|
||||
SyntheticExample(input_text=f"q{i}", id=i) for i in range(20)
|
||||
]
|
||||
|
||||
mock_state = _make_state(iterations=3, initial_score=0.3, final_score=0.9)
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert isinstance(result, OptimizationResult)
|
||||
assert result.initial_prompt == "Answer the question."
|
||||
assert result.final_score == 0.9
|
||||
assert result.improvement == pytest.approx(0.6)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bootstrap_called_with_config_params(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = []
|
||||
mock_state = _make_state()
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
await use_case.execute(config)
|
||||
|
||||
mock_bootstrap.run.assert_called_once_with(
|
||||
task_description="Q&A task",
|
||||
n_examples=20,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_evolution_loop_configured_from_config(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = []
|
||||
mock_state = _make_state()
|
||||
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state) as mock_run:
|
||||
await use_case.execute(config)
|
||||
|
||||
# Verify the loop was instantiated with correct params
|
||||
mock_run.assert_called_once()
|
||||
call_args = mock_run.call_args
|
||||
seed_prompt = call_args[0][0]
|
||||
assert seed_prompt.text == "Answer the question."
|
||||
synthetic_pool = call_args[0][1]
|
||||
assert len(synthetic_pool) == 0 # bootstrap returned empty
|
||||
assert call_args[0][2] == "Q&A task"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_total_llm_calls_includes_bootstrap_call(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = []
|
||||
mock_state = _make_state(iterations=3)
|
||||
# total_llm_calls from state + 1 for bootstrap
|
||||
expected = mock_state.total_llm_calls + 1
|
||||
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert result.total_llm_calls == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_candidates_fallback(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = [
|
||||
SyntheticExample(input_text=f"q{i}", id=i) for i in range(20)
|
||||
]
|
||||
mock_state = OptimizationState(
|
||||
iteration=0,
|
||||
best_candidate=None,
|
||||
candidates=[],
|
||||
total_llm_calls=0,
|
||||
)
|
||||
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert result.optimized_prompt == "Answer the question."
|
||||
assert result.initial_score == 0.0
|
||||
assert result.final_score == 0.0
|
||||
assert result.improvement == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_iterations_used_matches_state(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = []
|
||||
mock_state = _make_state(iterations=7)
|
||||
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert result.iterations_used == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_passed_through(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = []
|
||||
history = [
|
||||
{"iteration": 1, "event": "accepted"},
|
||||
{"iteration": 2, "event": "rejected"},
|
||||
]
|
||||
mock_state = _make_state()
|
||||
mock_state.history = history
|
||||
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert result.history == history
|
||||
Reference in New Issue
Block a user