feat: v0.2.0 sprint — ground truth eval, crossover/mutation, checkpointing, similarity guards, dataset loader, CLI commands, extended test coverage
Aggregates all v0.2.0 sprint work (GARAA-30 through GARAA-40) and fixes 2 integration tests that broke when the codebase went async (DSPyLLMAdapter and full pipeline tests now properly await coroutines). 277 tests pass (260 unit + 17 integration). Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
294
tests/unit/test_adapters.py
Normal file
294
tests/unit/test_adapters.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Unit tests for infrastructure adapters — LLM, Judge, Proposer, Synthetic.
|
||||
|
||||
Uses mocked DSPy modules to isolate adapter logic from LLM calls.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import dspy
|
||||
import pytest
|
||||
|
||||
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory
|
||||
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
||||
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
||||
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
||||
|
||||
|
||||
# --- LLM Adapter ---
|
||||
|
||||
|
||||
class TestDSPyLLMAdapter:
|
||||
"""Tests for DSPyLLMAdapter.execute()."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lm(self) -> MagicMock:
|
||||
return MagicMock(spec=dspy.LM)
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self, mock_lm: MagicMock) -> DSPyLLMAdapter:
|
||||
return DSPyLLMAdapter(lm=mock_lm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_output_string(
|
||||
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
mock_predictor = MagicMock()
|
||||
mock_predictor.return_value = MagicMock(output="Hello response")
|
||||
adapter._predictor = mock_predictor
|
||||
|
||||
prompt = Prompt(text="Say hello.")
|
||||
result = await adapter.execute(prompt, "Hi there")
|
||||
|
||||
assert result == "Hello response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_passes_prompt_text_and_input(
|
||||
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
mock_predictor = MagicMock()
|
||||
mock_predictor.return_value = MagicMock(output="response")
|
||||
adapter._predictor = mock_predictor
|
||||
|
||||
prompt = Prompt(text="Translate this.")
|
||||
await adapter.execute(prompt, "Hello world")
|
||||
|
||||
mock_predictor.assert_called_once_with(
|
||||
instruction="Translate this.",
|
||||
input_text="Hello world",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_uses_dspy_context(
|
||||
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
mock_predictor = MagicMock()
|
||||
mock_predictor.return_value = MagicMock(output="ok")
|
||||
adapter._predictor = mock_predictor
|
||||
|
||||
with patch("prometheus.infrastructure.llm_adapter.dspy.context") as mock_ctx:
|
||||
await adapter.execute(Prompt(text="test"), "input")
|
||||
mock_ctx.assert_called_once_with(lm=mock_lm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_converts_output_to_str(
|
||||
self, adapter: DSPyLLMAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
mock_predictor = MagicMock()
|
||||
mock_predictor.return_value = MagicMock(output=42)
|
||||
adapter._predictor = mock_predictor
|
||||
|
||||
result = await adapter.execute(Prompt(text="test"), "input")
|
||||
assert isinstance(result, str)
|
||||
assert result == "42"
|
||||
|
||||
|
||||
# --- Judge Adapter ---
|
||||
|
||||
|
||||
class TestDSPyJudgeAdapter:
|
||||
"""Tests for DSPyJudgeAdapter.judge_batch()."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lm(self) -> MagicMock:
|
||||
return MagicMock(spec=dspy.LM)
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self, mock_lm: MagicMock) -> DSPyJudgeAdapter:
|
||||
return DSPyJudgeAdapter(lm=mock_lm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_batch_returns_scores_and_feedback(
|
||||
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._judge = MagicMock()
|
||||
adapter._judge.side_effect = [
|
||||
MagicMock(score=0.9, feedback="Excellent."),
|
||||
MagicMock(score=0.4, feedback="Incomplete."),
|
||||
]
|
||||
|
||||
pairs = [("What is 2+2?", "4"), ("Capital of France?", "London")]
|
||||
result = await adapter.judge_batch("math and geography", pairs)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0] == (0.9, "Excellent.")
|
||||
assert result[1] == (0.4, "Incomplete.")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_batch_empty_pairs(
|
||||
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
result = await adapter.judge_batch("task", [])
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_batch_uses_dspy_context(
|
||||
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._judge = MagicMock()
|
||||
adapter._judge.return_value = MagicMock(score=0.5, feedback="ok")
|
||||
|
||||
with patch("prometheus.infrastructure.judge_adapter.dspy.context") as mock_ctx:
|
||||
await adapter.judge_batch("task", [("in", "out")])
|
||||
mock_ctx.assert_called_once_with(lm=mock_lm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_judge_batch_returns_all_results(
|
||||
self, adapter: DSPyJudgeAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
"""Judge calls run in parallel but all results are returned."""
|
||||
adapter._judge = MagicMock()
|
||||
adapter._judge.side_effect = [
|
||||
MagicMock(score=0.5, feedback="ok"),
|
||||
MagicMock(score=0.7, feedback="better"),
|
||||
MagicMock(score=0.3, feedback="worse"),
|
||||
]
|
||||
|
||||
pairs = [("first", "out1"), ("second", "out2"), ("third", "out3")]
|
||||
results = await adapter.judge_batch("task", pairs)
|
||||
|
||||
assert len(results) == 3
|
||||
scores = [r[0] for r in results]
|
||||
assert 0.5 in scores
|
||||
assert 0.7 in scores
|
||||
assert 0.3 in scores
|
||||
|
||||
|
||||
# --- Proposer Adapter ---
|
||||
|
||||
|
||||
class TestDSPyProposerAdapter:
|
||||
"""Tests for DSPyProposerAdapter.propose()."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lm(self) -> MagicMock:
|
||||
return MagicMock(spec=dspy.LM)
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self, mock_lm: MagicMock) -> DSPyProposerAdapter:
|
||||
return DSPyProposerAdapter(lm=mock_lm)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_propose_returns_new_prompt(
|
||||
self, adapter: DSPyProposerAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._proposer = MagicMock()
|
||||
adapter._proposer.return_value = MagicMock(
|
||||
new_instruction="Be concise and accurate."
|
||||
)
|
||||
|
||||
current = Prompt(text="Answer questions.")
|
||||
trajectories = [
|
||||
Trajectory("in", "out", 0.3, "too verbose", "Answer questions.")
|
||||
]
|
||||
result = await adapter.propose(current, trajectories, "Q&A task")
|
||||
|
||||
assert isinstance(result, Prompt)
|
||||
assert result.text == "Be concise and accurate."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_propose_uses_dspy_context(
|
||||
self, adapter: DSPyProposerAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._proposer = MagicMock()
|
||||
adapter._proposer.return_value = MagicMock(new_instruction="improved")
|
||||
|
||||
with patch("prometheus.infrastructure.proposer_adapter.dspy.context") as mock_ctx:
|
||||
await adapter.propose(Prompt(text="t"), [], "task")
|
||||
mock_ctx.assert_called_once_with(lm=mock_lm)
|
||||
|
||||
def test_format_failures_single_trajectory(self) -> None:
|
||||
trajectories = [
|
||||
Trajectory("What is AI?", "A type of robot.", 0.3, "Incomplete definition.", "prompt")
|
||||
]
|
||||
result = DSPyProposerAdapter._format_failures(trajectories)
|
||||
|
||||
assert "What is AI?" in result
|
||||
assert "A type of robot." in result
|
||||
assert "0.30" in result
|
||||
assert "Incomplete definition." in result
|
||||
assert "# Example 1" in result
|
||||
|
||||
def test_format_failures_multiple_trajectories(self) -> None:
|
||||
trajectories = [
|
||||
Trajectory("input1", "output1", 0.4, "bad", "prompt"),
|
||||
Trajectory("input2", "output2", 0.2, "worse", "prompt"),
|
||||
]
|
||||
result = DSPyProposerAdapter._format_failures(trajectories)
|
||||
|
||||
assert "# Example 1" in result
|
||||
assert "# Example 2" in result
|
||||
assert "---" in result
|
||||
assert "input1" in result
|
||||
assert "input2" in result
|
||||
|
||||
def test_format_failures_empty_list(self) -> None:
|
||||
result = DSPyProposerAdapter._format_failures([])
|
||||
assert result == ""
|
||||
|
||||
|
||||
# --- Synthetic Adapter ---
|
||||
|
||||
|
||||
class TestDSPySyntheticAdapter:
|
||||
"""Tests for DSPySyntheticAdapter.generate_inputs()."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lm(self) -> MagicMock:
|
||||
return MagicMock(spec=dspy.LM)
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self, mock_lm: MagicMock) -> DSPySyntheticAdapter:
|
||||
return DSPySyntheticAdapter(lm=mock_lm)
|
||||
|
||||
def test_generate_inputs_returns_examples(
|
||||
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._generator = MagicMock()
|
||||
adapter._generator.return_value = MagicMock(
|
||||
examples=["What is AI?", "Explain ML.", "What is NLP?"]
|
||||
)
|
||||
|
||||
result = adapter.generate_inputs("AI task", 3)
|
||||
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(ex, SyntheticExample) for ex in result)
|
||||
assert result[0].input_text == "What is AI?"
|
||||
assert result[0].id == 0
|
||||
assert result[1].id == 1
|
||||
|
||||
def test_generate_inputs_truncates_to_n(
|
||||
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._generator = MagicMock()
|
||||
adapter._generator.return_value = MagicMock(
|
||||
examples=["q1", "q2", "q3", "q4", "q5"]
|
||||
)
|
||||
|
||||
result = adapter.generate_inputs("task", 3)
|
||||
|
||||
assert len(result) == 3
|
||||
|
||||
def test_generate_inputs_passes_correct_args(
|
||||
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._generator = MagicMock()
|
||||
adapter._generator.return_value = MagicMock(examples=["q1"])
|
||||
|
||||
adapter.generate_inputs("my task", 5)
|
||||
|
||||
adapter._generator.assert_called_once_with(
|
||||
task_description="my task",
|
||||
n_examples=5,
|
||||
)
|
||||
|
||||
def test_generate_inputs_empty_list(
|
||||
self, adapter: DSPySyntheticAdapter, mock_lm: MagicMock
|
||||
) -> None:
|
||||
adapter._generator = MagicMock()
|
||||
adapter._generator.return_value = MagicMock(examples=[])
|
||||
|
||||
result = adapter.generate_inputs("task", 0)
|
||||
|
||||
assert result == []
|
||||
333
tests/unit/test_checkpoint.py
Normal file
333
tests/unit/test_checkpoint.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Unit tests for checkpoint & resume functionality."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
EvalResult,
|
||||
OptimizationState,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
Trajectory,
|
||||
)
|
||||
from prometheus.infrastructure.checkpoint import JsonCheckpointPersistence
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JsonCheckpointPersistence — save/load round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJsonCheckpointPersistence:
|
||||
def test_roundtrip_full_state(self, tmp_path: Path) -> None:
|
||||
"""Saving and loading preserves all fields."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
|
||||
|
||||
state = OptimizationState(
|
||||
iteration=7,
|
||||
best_candidate=Candidate(
|
||||
prompt=Prompt(text="best prompt", metadata={"source": "test"}),
|
||||
best_score=0.92,
|
||||
generation=5,
|
||||
),
|
||||
candidates=[
|
||||
Candidate(prompt=Prompt(text="p1"), best_score=0.5, generation=0),
|
||||
Candidate(prompt=Prompt(text="p2"), best_score=0.92, generation=5),
|
||||
],
|
||||
synthetic_pool=[
|
||||
SyntheticExample(input_text="q1", category="cat_a", id=0),
|
||||
SyntheticExample(input_text="q2", category="cat_b", id=1),
|
||||
],
|
||||
history=[{"iteration": 1, "event": "accepted", "old_score": 0.5, "new_score": 0.7}],
|
||||
total_llm_calls=42,
|
||||
)
|
||||
|
||||
ckpt.save(state)
|
||||
assert ckpt.latest_exists()
|
||||
|
||||
loaded = ckpt.load()
|
||||
assert loaded is not None
|
||||
assert loaded.iteration == 7
|
||||
assert loaded.total_llm_calls == 42
|
||||
assert loaded.best_candidate is not None
|
||||
assert loaded.best_candidate.prompt.text == "best prompt"
|
||||
assert loaded.best_candidate.prompt.metadata == {"source": "test"}
|
||||
assert loaded.best_candidate.best_score == 0.92
|
||||
assert len(loaded.candidates) == 2
|
||||
assert len(loaded.synthetic_pool) == 2
|
||||
assert loaded.synthetic_pool[0].input_text == "q1"
|
||||
assert loaded.synthetic_pool[1].category == "cat_b"
|
||||
assert loaded.history[0]["event"] == "accepted"
|
||||
|
||||
def test_load_returns_none_when_no_checkpoint(self, tmp_path: Path) -> None:
|
||||
"""Loading from empty dir returns None."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "nope")
|
||||
assert ckpt.load() is None
|
||||
assert not ckpt.latest_exists()
|
||||
|
||||
def test_creates_directory_on_save(self, tmp_path: Path) -> None:
|
||||
"""Save creates the directory tree if it doesn't exist."""
|
||||
deep_dir = tmp_path / "a" / "b" / "c"
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=deep_dir)
|
||||
state = OptimizationState(iteration=1)
|
||||
ckpt.save(state)
|
||||
assert (deep_dir / "latest.json").exists()
|
||||
|
||||
def test_overwrites_previous_checkpoint(self, tmp_path: Path) -> None:
|
||||
"""Second save overwrites the first."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
|
||||
|
||||
ckpt.save(OptimizationState(iteration=1, total_llm_calls=10))
|
||||
ckpt.save(OptimizationState(iteration=5, total_llm_calls=50))
|
||||
|
||||
loaded = ckpt.load()
|
||||
assert loaded is not None
|
||||
assert loaded.iteration == 5
|
||||
assert loaded.total_llm_calls == 50
|
||||
|
||||
def test_json_is_human_readable(self, tmp_path: Path) -> None:
|
||||
"""Checkpoint file is valid, pretty-printed JSON."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
|
||||
state = OptimizationState(
|
||||
iteration=3,
|
||||
best_candidate=Candidate(prompt=Prompt(text="hello"), best_score=0.8),
|
||||
)
|
||||
ckpt.save(state)
|
||||
|
||||
raw = json.loads((tmp_path / "latest.json").read_text())
|
||||
assert raw["schema_version"] == 1
|
||||
assert raw["iteration"] == 3
|
||||
assert raw["best_candidate"]["prompt_text"] == "hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EvolutionLoop — checkpoint integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEvolutionCheckpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint_saved_on_interval(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""Checkpoint is saved every checkpoint_interval iterations."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
# All iterations accepted so checkpoint triggers
|
||||
good_eval = EvalResult(
|
||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"out{i}", s, "ok", "p")
|
||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||
],
|
||||
)
|
||||
better_eval = EvalResult(
|
||||
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
# initial_eval + 5 iterations (each needs old_eval + new_eval)
|
||||
evaluator.evaluate = AsyncMock(
|
||||
side_effect=[good_eval] # initial
|
||||
+ [good_eval, better_eval] * 5 # 5 iterations
|
||||
)
|
||||
|
||||
proposer = AsyncMock()
|
||||
proposer.propose.return_value = Prompt(text="improved prompt")
|
||||
|
||||
checkpoint_port = MagicMock()
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=5,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=checkpoint_port,
|
||||
checkpoint_interval=2,
|
||||
)
|
||||
|
||||
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
# Checkpoint at iterations 2, 4 (every 2nd)
|
||||
save_calls = checkpoint_port.save.call_count
|
||||
assert save_calls >= 2 # at least at iters 2 and 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_checkpoint_without_port(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""No checkpointing happens when checkpoint_port is None (default)."""
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
perfect_eval = EvalResult(
|
||||
scores=[1.0] * 5,
|
||||
feedbacks=["perfect"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", 1.0, "perfect", "p")
|
||||
for i in range(5)
|
||||
],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=AsyncMock(),
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=None,
|
||||
)
|
||||
# Should run without error — no checkpoint port, no crash
|
||||
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_skips_seed_evaluation(
|
||||
self,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""When initial_state is provided, seed eval is skipped and loop starts from saved iteration."""
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
proposer = AsyncMock()
|
||||
proposer.propose.return_value = Prompt(text="new prompt")
|
||||
|
||||
# Only return evaluations for resumed iterations (1 iter: old_eval + new_eval)
|
||||
old_eval = EvalResult(
|
||||
scores=[0.5] * 5,
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", 0.5, "ok", "p") for i in range(5)
|
||||
],
|
||||
)
|
||||
new_eval = EvalResult(
|
||||
scores=[0.8] * 5,
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
|
||||
|
||||
# Create a state simulating checkpoint at iteration 4
|
||||
initial_state = OptimizationState(
|
||||
iteration=4,
|
||||
best_candidate=Candidate(
|
||||
prompt=Prompt(text="checkpoint prompt"), best_score=2.5, generation=4
|
||||
),
|
||||
candidates=[Candidate(prompt=Prompt(text="checkpoint prompt"), best_score=2.5)],
|
||||
total_llm_calls=40,
|
||||
)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=5, # only iteration 5 remains
|
||||
minibatch_size=5,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt=Prompt(text="seed"),
|
||||
synthetic_pool=synthetic_pool,
|
||||
task_description=task_description,
|
||||
initial_state=initial_state,
|
||||
)
|
||||
|
||||
# Should have run only 1 iteration (iter 5)
|
||||
assert state.iteration == 5
|
||||
# total_llm_calls should include the 40 from checkpoint + new calls
|
||||
assert state.total_llm_calls > 40
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_save_and_resume_roundtrip(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""End-to-end: run a few iterations, checkpoint, resume, finish."""
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
old_eval = EvalResult(
|
||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", s, "ok", "p")
|
||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||
],
|
||||
)
|
||||
new_eval = EvalResult(
|
||||
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(
|
||||
side_effect=[old_eval, old_eval, new_eval, old_eval, new_eval]
|
||||
)
|
||||
proposer = AsyncMock()
|
||||
proposer.propose.return_value = Prompt(text="improved prompt")
|
||||
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=2,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=ckpt,
|
||||
checkpoint_interval=1,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
assert state.iteration == 2
|
||||
assert ckpt.latest_exists()
|
||||
|
||||
# Capture the checkpoint state *before* resume (state is mutated in-place)
|
||||
loaded = ckpt.load()
|
||||
assert loaded is not None
|
||||
saved_llm_calls = loaded.total_llm_calls
|
||||
saved_iteration = loaded.iteration
|
||||
|
||||
# Set up evaluator for resumed run (just 1 more iteration)
|
||||
evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
|
||||
proposer.propose.return_value = Prompt(text="even better prompt")
|
||||
|
||||
loop2 = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=ckpt,
|
||||
checkpoint_interval=1,
|
||||
)
|
||||
resumed = await loop2.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
initial_state=loaded,
|
||||
)
|
||||
assert resumed.iteration == 3
|
||||
assert resumed.total_llm_calls > saved_llm_calls
|
||||
assert resumed.iteration > saved_iteration
|
||||
278
tests/unit/test_cli.py
Normal file
278
tests/unit/test_cli.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Tests for the CLI interface — prometheus optimize, version, etc.
|
||||
|
||||
Uses Typer's CliRunner for isolated command testing.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from prometheus.application.dto import OptimizationResult
|
||||
from prometheus.cli.app import app
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
class TestCLIOptimize:
|
||||
"""Tests for the `prometheus optimize` command."""
|
||||
|
||||
def _write_config(self, tmp_path: Path, **overrides: object) -> Path:
|
||||
"""Write a minimal valid config YAML and return its path."""
|
||||
data = {
|
||||
"seed_prompt": "You are a helpful assistant.",
|
||||
"task_description": "Answer factual questions accurately.",
|
||||
}
|
||||
data.update(overrides)
|
||||
config_file = tmp_path / "config.yaml"
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(data, f)
|
||||
return config_file
|
||||
|
||||
def test_optimize_with_valid_config(self, tmp_path: Path) -> None:
|
||||
config_file = self._write_config(tmp_path)
|
||||
output_file = tmp_path / "output.yaml"
|
||||
|
||||
mock_result = OptimizationResult(
|
||||
optimized_prompt="Improved prompt",
|
||||
initial_prompt="You are a helpful assistant.",
|
||||
iterations_used=5,
|
||||
total_llm_calls=50,
|
||||
initial_score=0.3,
|
||||
final_score=0.9,
|
||||
improvement=0.6,
|
||||
history=[],
|
||||
)
|
||||
|
||||
mock_uc = AsyncMock()
|
||||
mock_uc.execute.return_value = mock_result
|
||||
|
||||
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||
mock_llm_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||
mock_judge_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||
mock_prop_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"optimize",
|
||||
"-i",
|
||||
str(config_file),
|
||||
"-o",
|
||||
str(output_file),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Optimized Prompt" in result.output
|
||||
|
||||
def test_optimize_missing_input_file(self) -> None:
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["optimize", "-i", "/nonexistent/config.yaml"],
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
|
||||
def test_optimize_with_verbose_flag(self, tmp_path: Path) -> None:
|
||||
config_file = self._write_config(tmp_path)
|
||||
output_file = tmp_path / "output.yaml"
|
||||
|
||||
mock_result = OptimizationResult(
|
||||
optimized_prompt="Improved",
|
||||
initial_prompt="test",
|
||||
iterations_used=1,
|
||||
total_llm_calls=10,
|
||||
initial_score=0.3,
|
||||
final_score=0.8,
|
||||
improvement=0.5,
|
||||
history=[],
|
||||
)
|
||||
|
||||
mock_uc = AsyncMock()
|
||||
mock_uc.execute.return_value = mock_result
|
||||
|
||||
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||
mock_llm_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||
mock_judge_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||
mock_prop_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"optimize",
|
||||
"-i",
|
||||
str(config_file),
|
||||
"-o",
|
||||
str(output_file),
|
||||
"-v",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_optimize_displays_metrics(self, tmp_path: Path) -> None:
|
||||
config_file = self._write_config(tmp_path)
|
||||
output_file = tmp_path / "output.yaml"
|
||||
|
||||
mock_result = OptimizationResult(
|
||||
optimized_prompt="Better prompt",
|
||||
initial_prompt="test",
|
||||
iterations_used=3,
|
||||
total_llm_calls=30,
|
||||
initial_score=0.40,
|
||||
final_score=0.85,
|
||||
improvement=0.45,
|
||||
history=[],
|
||||
)
|
||||
|
||||
mock_uc = AsyncMock()
|
||||
mock_uc.execute.return_value = mock_result
|
||||
|
||||
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||
mock_llm_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||
mock_judge_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||
mock_prop_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"optimize",
|
||||
"-i",
|
||||
str(config_file),
|
||||
"-o",
|
||||
str(output_file),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "0.40" in result.output
|
||||
assert "0.85" in result.output
|
||||
assert "+0.45" in result.output
|
||||
|
||||
def test_optimize_with_max_concurrency_flag(self, tmp_path: Path) -> None:
|
||||
config_file = self._write_config(tmp_path)
|
||||
output_file = tmp_path / "output.yaml"
|
||||
|
||||
mock_result = OptimizationResult(
|
||||
optimized_prompt="Better prompt",
|
||||
initial_prompt="test",
|
||||
iterations_used=1,
|
||||
total_llm_calls=10,
|
||||
initial_score=0.3,
|
||||
final_score=0.8,
|
||||
improvement=0.5,
|
||||
history=[],
|
||||
)
|
||||
|
||||
mock_uc = AsyncMock()
|
||||
mock_uc.execute.return_value = mock_result
|
||||
|
||||
with patch("prometheus.cli.commands.optimize.OptimizePromptUseCase", return_value=mock_uc):
|
||||
with patch("prometheus.cli.commands.optimize.DSPySyntheticAdapter"):
|
||||
with patch("prometheus.cli.commands.optimize.DSPyLLMAdapter") as mock_llm_cls:
|
||||
mock_llm_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyJudgeAdapter") as mock_judge_cls:
|
||||
mock_judge_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.DSPyProposerAdapter") as mock_prop_cls:
|
||||
mock_prop_cls.return_value = MagicMock()
|
||||
with patch("prometheus.cli.commands.optimize.dspy"):
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"optimize",
|
||||
"-i",
|
||||
str(config_file),
|
||||
"-o",
|
||||
str(output_file),
|
||||
"--max-concurrency",
|
||||
"10",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestCLIHelp:
|
||||
"""Tests for CLI help and no-args behavior."""
|
||||
|
||||
def test_no_args_shows_help(self) -> None:
|
||||
result = runner.invoke(app, [])
|
||||
# Typer uses exit code 2 when no_args_is_help=True
|
||||
assert result.exit_code in (0, 2)
|
||||
assert "PROMETHEUS" in result.output or "Usage" in result.output
|
||||
|
||||
def test_optimize_help(self) -> None:
|
||||
result = runner.invoke(app, ["optimize", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "input" in result.output.lower() or "INPUT" in result.output
|
||||
|
||||
def test_version_help(self) -> None:
|
||||
result = runner.invoke(app, ["version", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_init_help(self) -> None:
|
||||
result = runner.invoke(app, ["init", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_list_help(self) -> None:
|
||||
result = runner.invoke(app, ["list", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestCLIVersion:
|
||||
"""Tests for the `prometheus version` command."""
|
||||
|
||||
def test_version_prints_version(self) -> None:
|
||||
result = runner.invoke(app, ["version"])
|
||||
assert result.exit_code == 0
|
||||
assert "PROMETHEUS" in result.output
|
||||
assert "0.1.0" in result.output
|
||||
|
||||
|
||||
class TestCLIList:
|
||||
"""Tests for the `prometheus list` command."""
|
||||
|
||||
def test_list_no_runs(self, tmp_path: Path) -> None:
|
||||
result = runner.invoke(app, ["list", "-d", str(tmp_path)])
|
||||
assert result.exit_code == 0
|
||||
assert "No optimization runs found" in result.output
|
||||
|
||||
def test_list_with_result(self, tmp_path: Path) -> None:
|
||||
result_data = {
|
||||
"optimized_prompt": "Better prompt for testing",
|
||||
"initial_prompt": "test",
|
||||
"iterations_used": 5,
|
||||
"total_llm_calls": 50,
|
||||
"initial_score": 0.30,
|
||||
"final_score": 0.90,
|
||||
"improvement": 0.60,
|
||||
"history": [],
|
||||
}
|
||||
result_file = tmp_path / "output.yaml"
|
||||
import yaml as _yaml
|
||||
with open(result_file, "w") as f:
|
||||
_yaml.dump(result_data, f)
|
||||
|
||||
result = runner.invoke(app, ["list", "-d", str(tmp_path)])
|
||||
assert result.exit_code == 0
|
||||
assert "0.30" in result.output
|
||||
assert "0.90" in result.output
|
||||
|
||||
def test_list_nonexistent_directory(self) -> None:
|
||||
result = runner.invoke(app, ["list", "-d", "/nonexistent/dir"])
|
||||
assert result.exit_code == 1
|
||||
@@ -300,3 +300,33 @@ class TestConfigValidation:
|
||||
)
|
||||
assert config.max_iterations == 1
|
||||
assert config.perfect_score == 0.0
|
||||
|
||||
|
||||
class TestEvalConfigValidation:
|
||||
"""Tests for ground-truth evaluation config fields."""
|
||||
|
||||
def test_eval_defaults(self) -> None:
|
||||
config = OptimizationConfig(seed_prompt="a", task_description="b")
|
||||
assert config.eval_dataset_path is None
|
||||
assert config.eval_metric == "bleu"
|
||||
|
||||
def test_eval_dataset_path_set(self) -> None:
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="a", task_description="b",
|
||||
eval_dataset_path="data.csv",
|
||||
)
|
||||
assert config.eval_dataset_path == "data.csv"
|
||||
|
||||
def test_valid_eval_metrics(self) -> None:
|
||||
for metric in ("exact", "bleu", "rouge_l", "cosine", "llm_judge"):
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", eval_metric=metric,
|
||||
)
|
||||
assert config.eval_metric == metric
|
||||
|
||||
def test_invalid_eval_metric_raises(self) -> None:
|
||||
with pytest.raises(ValidationError, match="eval_metric must be one of"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a", task_description="b",
|
||||
eval_metric="invalid_metric",
|
||||
)
|
||||
|
||||
86
tests/unit/test_dataset_loader.py
Normal file
86
tests/unit/test_dataset_loader.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Tests for the ground-truth dataset loader."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.domain.entities import GroundTruthExample
|
||||
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader():
|
||||
return FileDatasetLoader()
|
||||
|
||||
|
||||
class TestCsvLoader:
|
||||
def test_load_csv(self, loader, tmp_path):
|
||||
csv_file = tmp_path / "test.csv"
|
||||
csv_file.write_text("input,expected_output\nhello,world\nfoo,bar\n")
|
||||
result = loader.load(str(csv_file))
|
||||
assert len(result) == 2
|
||||
assert result[0].input_text == "hello"
|
||||
assert result[0].expected_output == "world"
|
||||
assert result[1].input_text == "foo"
|
||||
assert result[1].expected_output == "bar"
|
||||
|
||||
def test_load_csv_skips_empty_input(self, loader, tmp_path):
|
||||
csv_file = tmp_path / "test.csv"
|
||||
csv_file.write_text("input,expected_output\n,bar\nhello,world\n")
|
||||
result = loader.load(str(csv_file))
|
||||
assert len(result) == 1
|
||||
assert result[0].input_text == "hello"
|
||||
|
||||
def test_load_csv_with_whitespace(self, loader, tmp_path):
|
||||
csv_file = tmp_path / "test.csv"
|
||||
csv_file.write_text("input,expected_output\n hello , world \n")
|
||||
result = loader.load(str(csv_file))
|
||||
assert result[0].input_text == "hello"
|
||||
assert result[0].expected_output == "world"
|
||||
|
||||
def test_load_csv_empty_file(self, loader, tmp_path):
|
||||
csv_file = tmp_path / "test.csv"
|
||||
csv_file.write_text("input,expected_output\n")
|
||||
result = loader.load(str(csv_file))
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestJsonLoader:
|
||||
def test_load_json(self, loader, tmp_path):
|
||||
json_file = tmp_path / "test.json"
|
||||
data = [
|
||||
{"input": "hello", "expected_output": "world"},
|
||||
{"input": "foo", "expected_output": "bar"},
|
||||
]
|
||||
json_file.write_text(json.dumps(data))
|
||||
result = loader.load(str(json_file))
|
||||
assert len(result) == 2
|
||||
assert result[0].input_text == "hello"
|
||||
assert result[0].expected_output == "world"
|
||||
|
||||
def test_load_json_skips_empty_input(self, loader, tmp_path):
|
||||
json_file = tmp_path / "test.json"
|
||||
data = [
|
||||
{"input": "", "expected_output": "bar"},
|
||||
{"input": "hello", "expected_output": "world"},
|
||||
]
|
||||
json_file.write_text(json.dumps(data))
|
||||
result = loader.load(str(json_file))
|
||||
assert len(result) == 1
|
||||
|
||||
def test_load_json_not_array_raises(self, loader, tmp_path):
|
||||
json_file = tmp_path / "test.json"
|
||||
json_file.write_text(json.dumps({"not": "an array"}))
|
||||
with pytest.raises(ValueError, match="must be an array"):
|
||||
loader.load(str(json_file))
|
||||
|
||||
|
||||
class TestUnsupportedFormat:
|
||||
def test_unsupported_extension_raises(self, loader, tmp_path):
|
||||
txt_file = tmp_path / "test.txt"
|
||||
txt_file.write_text("hello")
|
||||
with pytest.raises(ValueError, match="Unsupported dataset format"):
|
||||
loader.load(str(txt_file))
|
||||
@@ -278,6 +278,7 @@ class TestPerCallIsolation:
|
||||
adapter._judge_dimensions = []
|
||||
adapter._dimension_names = ""
|
||||
adapter._weights = {}
|
||||
adapter.call_count = 0
|
||||
|
||||
# Mock _judge to fail on first call, succeed on second
|
||||
call_count = 0
|
||||
|
||||
@@ -8,10 +8,30 @@ import pytest
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
EvalResult,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
Trajectory,
|
||||
)
|
||||
|
||||
|
||||
def _make_eval(scores: list[float], label: str = "ok") -> EvalResult:
|
||||
"""Helper to build an EvalResult from a list of scores."""
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=[label] * len(scores),
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, label, "prompt")
|
||||
for i, s in enumerate(scores)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestEvolutionLoop:
|
||||
"""Tests for the original single-candidate hill-climbing mode (population_size=1)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accepts_improvement(
|
||||
self,
|
||||
@@ -27,28 +47,9 @@ class TestEvolutionLoop:
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
initial_eval = EvalResult(
|
||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||
feedbacks=["bad"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, "bad", "prompt")
|
||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||
],
|
||||
)
|
||||
old_eval = EvalResult(
|
||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||
feedbacks=["bad"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, "bad", "prompt")
|
||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||
],
|
||||
)
|
||||
new_eval = EvalResult(
|
||||
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(side_effect=[initial_eval, old_eval, new_eval])
|
||||
low_eval = _make_eval([0.3, 0.4, 0.3, 0.5, 0.2], "bad")
|
||||
high_eval = _make_eval([0.8, 0.9, 0.7, 0.8, 0.9], "good")
|
||||
evaluator.evaluate = AsyncMock(side_effect=[low_eval, low_eval, high_eval])
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
@@ -57,8 +58,7 @@ class TestEvolutionLoop:
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
)
|
||||
with patch.object(loop, "_log"):
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert state.best_candidate is not None
|
||||
assert state.best_candidate.best_score > 0
|
||||
@@ -78,28 +78,9 @@ class TestEvolutionLoop:
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
initial_eval = EvalResult(
|
||||
scores=[0.7, 0.8, 0.7, 0.8, 0.9],
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, "ok", "prompt")
|
||||
for i, s in enumerate([0.7, 0.8, 0.7, 0.8, 0.9])
|
||||
],
|
||||
)
|
||||
old_eval = EvalResult(
|
||||
scores=[0.7, 0.8, 0.7, 0.8, 0.9],
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", s, "ok", "prompt")
|
||||
for i, s in enumerate([0.7, 0.8, 0.7, 0.8, 0.9])
|
||||
],
|
||||
)
|
||||
new_eval = EvalResult(
|
||||
scores=[0.2, 0.1, 0.3, 0.2, 0.1],
|
||||
feedbacks=["bad"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(side_effect=[initial_eval, old_eval, new_eval])
|
||||
high_eval = _make_eval([0.7, 0.8, 0.7, 0.8, 0.9], "ok")
|
||||
low_eval = _make_eval([0.2, 0.1, 0.3, 0.2, 0.1], "bad")
|
||||
evaluator.evaluate = AsyncMock(side_effect=[high_eval, high_eval, low_eval])
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
@@ -108,8 +89,7 @@ class TestEvolutionLoop:
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
)
|
||||
with patch.object(loop, "_log"):
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert state.best_candidate is not None
|
||||
assert state.best_candidate.prompt.text == seed_prompt.text
|
||||
@@ -129,14 +109,7 @@ class TestEvolutionLoop:
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
perfect_eval = EvalResult(
|
||||
scores=[1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
feedbacks=["perfect"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", 1.0, "perfect", "prompt")
|
||||
for i in range(5)
|
||||
],
|
||||
)
|
||||
perfect_eval = _make_eval([1.0, 1.0, 1.0, 1.0, 1.0], "perfect")
|
||||
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
@@ -146,7 +119,226 @@ class TestEvolutionLoop:
|
||||
max_iterations=3,
|
||||
minibatch_size=5,
|
||||
)
|
||||
with patch.object(loop, "_log"):
|
||||
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
mock_proposer_port.propose.assert_not_called()
|
||||
|
||||
|
||||
class TestPopulationEvolution:
|
||||
"""Tests for population-based evolution (population_size > 1)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_population_initialization(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
mock_mutation_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Population is initialized with the right number of candidates."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
evaluator.evaluate = AsyncMock(
|
||||
return_value=_make_eval([0.5] * 5, "ok")
|
||||
)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=0, # no iterations, just initialization
|
||||
minibatch_size=5,
|
||||
population_size=4,
|
||||
mutation_port=mock_mutation_port,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
# 1 seed + 3 mutations = 4 candidates
|
||||
assert len(state.candidates) == 4
|
||||
assert mock_mutation_port.mutate.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_population_initialization_uses_proposer_fallback(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""When no mutation_port is provided, population init falls back to proposer."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
evaluator.evaluate = AsyncMock(
|
||||
return_value=_make_eval([0.5] * 5, "ok")
|
||||
)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=0,
|
||||
minibatch_size=5,
|
||||
population_size=3,
|
||||
# mutation_port intentionally omitted
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert len(state.candidates) == 3
|
||||
assert mock_proposer_port.propose.call_count == 2 # 3-1 = 2 init mutations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_population_iteration_replaces_worst(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
mock_crossover_port: AsyncMock,
|
||||
mock_mutation_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Crossover child replaces worst candidate when its fitness is higher."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
# Sequence:
|
||||
# 1. Initial eval (seed)
|
||||
# 2. Population init: 3 mutation calls use proposer.propose(), NOT evaluator.evaluate
|
||||
# 3. Population iteration: crossover produces child → eval child
|
||||
# Only 2 evaluator.evaluate calls total
|
||||
seed_eval = _make_eval([0.5] * 5, "ok")
|
||||
# Crossover child eval - high score to beat worst
|
||||
child_eval = _make_eval([0.9, 0.9, 0.8, 0.9, 0.8], "great")
|
||||
|
||||
all_evals = [seed_eval, child_eval]
|
||||
evaluator.evaluate = AsyncMock(side_effect=all_evals)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
population_size=4,
|
||||
crossover_rate=1.0,
|
||||
crossover_port=mock_crossover_port,
|
||||
mutation_rate=0.0, # disable post-crossover mutation for determinism
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
accepted_events = [h for h in state.history if h.get("event") == "pop_accepted"]
|
||||
assert len(accepted_events) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_population_iteration_rejects_inferior_child(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
mock_crossover_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Inferior child is rejected and doesn't replace any candidate."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
seed_eval = _make_eval([0.8] * 5, "ok")
|
||||
# Crossover produces very LOW-scoring child
|
||||
child_eval = _make_eval([0.1] * 5, "terrible")
|
||||
|
||||
all_evals = [seed_eval, child_eval]
|
||||
evaluator.evaluate = AsyncMock(side_effect=all_evals)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
population_size=4,
|
||||
crossover_rate=1.0,
|
||||
crossover_port=mock_crossover_port,
|
||||
mutation_rate=0.0,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
rejected_events = [h for h in state.history if h.get("event") == "pop_rejected"]
|
||||
assert len(rejected_events) >= 1
|
||||
|
||||
|
||||
class TestDiversityScore:
|
||||
"""Tests for the diversity/similarity scoring logic."""
|
||||
|
||||
def test_identical_prompts_have_high_similarity(self) -> None:
|
||||
"""Identical prompts should have very high similarity."""
|
||||
identical = Prompt(text="You are a helpful assistant. Answer the question.")
|
||||
pop_a = Candidate(prompt=identical, best_score=4.0, generation=0)
|
||||
pop_b = Candidate(
|
||||
prompt=Prompt(text="Completely different prompt about data analysis."),
|
||||
best_score=3.0,
|
||||
generation=0,
|
||||
)
|
||||
sim_same = EvolutionLoop._compute_diversity_score(identical, [pop_a, pop_b])
|
||||
# Average includes similarity to the different member, so ~0.5 not 0.9+
|
||||
assert sim_same > 0.3
|
||||
|
||||
def test_different_prompts_have_lower_similarity(self) -> None:
|
||||
"""Different prompts should have lower similarity than identical ones."""
|
||||
prompt_a = Prompt(text="You are a helpful assistant. Answer the question.")
|
||||
prompt_b = Prompt(text="Provide detailed analysis of complex data patterns with precision.")
|
||||
pop_a = Candidate(prompt=prompt_a, best_score=4.0, generation=0)
|
||||
pop_b = Candidate(prompt=prompt_b, best_score=3.0, generation=0)
|
||||
sim_a = EvolutionLoop._compute_diversity_score(prompt_a, [pop_a, pop_b])
|
||||
sim_b = EvolutionLoop._compute_diversity_score(prompt_b, [pop_a, pop_b])
|
||||
# Both should be < 1.0 since they're different
|
||||
assert sim_a < 1.0
|
||||
assert sim_b < 1.0
|
||||
|
||||
def test_single_member_population_returns_1(self) -> None:
|
||||
"""Single-member population always returns 1.0 (no penalty)."""
|
||||
prompt = Prompt(text="Any prompt text here.")
|
||||
pop = [Candidate(prompt=prompt, best_score=1.0, generation=0)]
|
||||
sim = EvolutionLoop._compute_diversity_score(prompt, pop)
|
||||
assert sim == 1.0
|
||||
|
||||
def test_empty_prompt_returns_zero(self) -> None:
|
||||
"""Empty prompt text returns 0.0 when population has >1 member."""
|
||||
prompt = Prompt(text="")
|
||||
pop = [
|
||||
Candidate(prompt=Prompt(text="some text"), best_score=1.0, generation=0),
|
||||
Candidate(prompt=Prompt(text="other text"), best_score=2.0, generation=0),
|
||||
]
|
||||
sim = EvolutionLoop._compute_diversity_score(prompt, pop)
|
||||
assert sim == 0.0
|
||||
|
||||
|
||||
class TestPromptDiff:
|
||||
"""Tests for the static _compute_prompt_diff helper."""
|
||||
|
||||
def test_identical_prompts(self) -> None:
|
||||
result = EvolutionLoop._compute_prompt_diff("hello\nworld", "hello\nworld")
|
||||
assert result["lines_added"] == 0
|
||||
assert result["lines_removed"] == 0
|
||||
assert result["chars_delta"] == 0
|
||||
|
||||
def test_added_lines(self) -> None:
|
||||
result = EvolutionLoop._compute_prompt_diff("hello", "hello\nworld")
|
||||
assert result["lines_added"] == 1
|
||||
assert result["lines_removed"] == 0
|
||||
assert result["chars_delta"] == 6 # "\nworld"
|
||||
|
||||
def test_removed_lines(self) -> None:
|
||||
result = EvolutionLoop._compute_prompt_diff("hello\nworld", "hello")
|
||||
assert result["lines_added"] == 0
|
||||
assert result["lines_removed"] == 1
|
||||
|
||||
133
tests/unit/test_ground_truth_evaluator.py
Normal file
133
tests/unit/test_ground_truth_evaluator.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tests for GroundTruthEvaluator — execution + similarity comparison."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.ground_truth_evaluator import GroundTruthEvaluator
|
||||
from prometheus.domain.entities import EvalResult, GroundTruthExample, Prompt
|
||||
from prometheus.domain.ports import LLMPort, SimilarityPort
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_executor() -> AsyncMock:
|
||||
port = AsyncMock(spec=LLMPort)
|
||||
port.execute.return_value = "Paris is the capital of France."
|
||||
return port
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_similarity() -> AsyncMock:
|
||||
port = AsyncMock(spec=SimilarityPort)
|
||||
port.compute.return_value = 0.85
|
||||
return port
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gt_dataset() -> list[GroundTruthExample]:
|
||||
return [
|
||||
GroundTruthExample(input_text="What is the capital of France?", expected_output="Paris", id=0),
|
||||
GroundTruthExample(input_text="What is 2+2?", expected_output="4", id=1),
|
||||
GroundTruthExample(input_text="What color is the sky?", expected_output="blue", id=2),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt() -> Prompt:
|
||||
return Prompt(text="Answer the following question accurately.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestGroundTruthEvaluator:
|
||||
async def test_evaluate_happy_path(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor,
|
||||
similarity=mock_similarity,
|
||||
max_concurrency=2,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
|
||||
assert isinstance(result, EvalResult)
|
||||
assert len(result.scores) == 3
|
||||
assert len(result.feedbacks) == 3
|
||||
assert len(result.trajectories) == 3
|
||||
assert all(s == 0.85 for s in result.scores)
|
||||
assert result.mean_score == pytest.approx(0.85)
|
||||
assert result.total_score == pytest.approx(2.55)
|
||||
|
||||
async def test_executor_called_for_each_input(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=mock_similarity,
|
||||
)
|
||||
await evaluator.evaluate(prompt, gt_dataset)
|
||||
assert mock_executor.execute.call_count == 3
|
||||
|
||||
async def test_similarity_called_for_each_output(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=mock_similarity,
|
||||
)
|
||||
await evaluator.evaluate(prompt, gt_dataset)
|
||||
assert mock_similarity.compute.call_count == 3
|
||||
|
||||
async def test_execution_error_produces_zero_score(self, mock_similarity, gt_dataset, prompt):
|
||||
failing_executor = AsyncMock(spec=LLMPort)
|
||||
failing_executor.execute.side_effect = RuntimeError("API timeout")
|
||||
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=failing_executor, similarity=mock_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
|
||||
assert len(result.scores) == 3
|
||||
# The similarity adapter is called with the error sentinel
|
||||
assert all(isinstance(s, float) for s in result.scores)
|
||||
assert all("[execution error:" in t.output_text for t in result.trajectories)
|
||||
|
||||
async def test_empty_dataset(self, mock_executor, mock_similarity, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=mock_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, [])
|
||||
assert result.scores == []
|
||||
assert result.mean_score == 0.0
|
||||
assert result.total_score == 0.0
|
||||
|
||||
async def test_trajectory_contains_prompt_used(self, mock_executor, mock_similarity, gt_dataset, prompt):
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=mock_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
for t in result.trajectories:
|
||||
assert t.prompt_used == prompt.text
|
||||
|
||||
async def test_scores_clamped_to_unit_range(self, mock_executor, gt_dataset, prompt):
|
||||
# Similarity returns a value > 1.0 (should be clamped)
|
||||
over_similarity = AsyncMock(spec=SimilarityPort)
|
||||
over_similarity.compute.return_value = 1.5
|
||||
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=over_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
assert all(0.0 <= s <= 1.0 for s in result.scores)
|
||||
|
||||
async def test_feedback_for_exact_match(self, mock_executor, gt_dataset, prompt):
|
||||
exact_similarity = AsyncMock(spec=SimilarityPort)
|
||||
exact_similarity.compute.return_value = 1.0
|
||||
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=exact_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
assert all("Exact match" in fb for fb in result.feedbacks)
|
||||
|
||||
async def test_feedback_for_poor_match(self, mock_executor, gt_dataset, prompt):
|
||||
poor_similarity = AsyncMock(spec=SimilarityPort)
|
||||
poor_similarity.compute.return_value = 0.1
|
||||
|
||||
evaluator = GroundTruthEvaluator(
|
||||
executor=mock_executor, similarity=poor_similarity,
|
||||
)
|
||||
result = await evaluator.evaluate(prompt, gt_dataset)
|
||||
assert all("Poor match" in fb for fb in result.feedbacks)
|
||||
316
tests/unit/test_holdout_validation.py
Normal file
316
tests/unit/test_holdout_validation.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Unit tests for hold-out validation and early stopping."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
EvalResult,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
Trajectory,
|
||||
)
|
||||
|
||||
|
||||
def _make_eval(mean_score: float, n: int = 5) -> EvalResult:
|
||||
"""Helper: create an EvalResult with a given mean score."""
|
||||
scores = [mean_score] * n
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=["feedback"] * n,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"output{i}", mean_score, "feedback", "prompt")
|
||||
for i in range(n)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestBootstrapSplit:
|
||||
"""Tests for SyntheticBootstrap.split_pool."""
|
||||
|
||||
def test_split_produces_correct_sizes(self):
|
||||
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(20)]
|
||||
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
|
||||
assert len(train) + len(val) == 20
|
||||
assert len(val) == 6 # 20 * 0.3 = 6
|
||||
assert len(train) == 14
|
||||
|
||||
def test_split_zero_fraction_returns_all_train(self):
|
||||
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(10)]
|
||||
train, val = SyntheticBootstrap.split_pool(pool, 0.0)
|
||||
assert len(train) == 10
|
||||
assert len(val) == 0
|
||||
|
||||
def test_split_single_element(self):
|
||||
pool = [SyntheticExample(input_text="only", id=0)]
|
||||
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
|
||||
assert len(train) == 1
|
||||
assert len(val) == 0
|
||||
|
||||
def test_split_deterministic_with_seed(self):
|
||||
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(50)]
|
||||
train1, val1 = SyntheticBootstrap.split_pool(pool, 0.3, rng=MagicMock(wraps=__import__("random").Random(42)))
|
||||
train2, val2 = SyntheticBootstrap.split_pool(pool, 0.3, rng=MagicMock(wraps=__import__("random").Random(42)))
|
||||
assert [ex.id for ex in train1] == [ex.id for ex in train2]
|
||||
assert [ex.id for ex in val1] == [ex.id for ex in val2]
|
||||
|
||||
def test_split_no_overlap(self):
|
||||
pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(30)]
|
||||
train, val = SyntheticBootstrap.split_pool(pool, 0.3)
|
||||
train_ids = {ex.id for ex in train}
|
||||
val_ids = {ex.id for ex in val}
|
||||
assert train_ids.isdisjoint(val_ids)
|
||||
assert train_ids | val_ids == {ex.id for ex in pool}
|
||||
|
||||
|
||||
class TestValidationEvaluation:
|
||||
"""Tests for hold-out evaluation during evolution."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_pool_evaluated_after_each_iteration(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""When a validation pool is provided, the best candidate is evaluated on it."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
# Initial eval (train) + validation eval + iteration train eval + new prompt eval + validation eval
|
||||
train_eval = _make_eval(0.5)
|
||||
val_eval = _make_eval(0.6)
|
||||
new_eval = _make_eval(0.7)
|
||||
val_eval_2 = _make_eval(0.65)
|
||||
|
||||
evaluator.evaluate = AsyncMock(
|
||||
side_effect=[train_eval, val_eval, train_eval, new_eval, val_eval_2]
|
||||
)
|
||||
|
||||
validation_pool = synthetic_pool[-6:]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
validation_pool=validation_pool,
|
||||
)
|
||||
|
||||
# Should have validation metrics in state
|
||||
assert state.best_validation_score is not None
|
||||
# History should contain validation_eval entries
|
||||
val_events = [h for h in state.history if h["event"] == "validation_eval"]
|
||||
assert len(val_events) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_validation_without_pool(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Without a validation pool, no validation is performed."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
train_eval = _make_eval(0.5)
|
||||
old_eval = _make_eval(0.5)
|
||||
new_eval = _make_eval(0.7)
|
||||
evaluator.evaluate = AsyncMock(side_effect=[train_eval, old_eval, new_eval])
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=1,
|
||||
minibatch_size=5,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
assert state.best_validation_score is None
|
||||
assert not state.early_stopped
|
||||
val_events = [h for h in state.history if h["event"] == "validation_eval"]
|
||||
assert len(val_events) == 0
|
||||
|
||||
|
||||
class TestEarlyStopping:
|
||||
"""Tests for early stopping when validation score degrades."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_early_stop_triggers_on_patience_exceeded(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Early stopping triggers when validation doesn't improve for K iterations."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
patience = 3
|
||||
# Build eval sequence:
|
||||
# 1. Initial train eval
|
||||
# 2. Initial validation eval (0.5)
|
||||
# Then for each of 3 iterations:
|
||||
# - train eval (current best)
|
||||
# - train eval (new prompt - accepted)
|
||||
# - validation eval (degrading)
|
||||
evals = [
|
||||
_make_eval(0.5), # initial train
|
||||
_make_eval(0.5), # initial validation
|
||||
]
|
||||
for i in range(patience):
|
||||
evals.extend([
|
||||
_make_eval(0.5 + i * 0.1), # current eval (train)
|
||||
_make_eval(0.6 + i * 0.1), # new eval (train) - accepted
|
||||
_make_eval(0.4), # validation eval (degrading)
|
||||
])
|
||||
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
|
||||
validation_pool = synthetic_pool[-5:]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=10, # would go further without early stop
|
||||
minibatch_size=5,
|
||||
early_stop_patience=patience,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
validation_pool=validation_pool,
|
||||
)
|
||||
|
||||
assert state.early_stopped is True
|
||||
assert state.iteration == patience
|
||||
assert state.best_validation_score is not None
|
||||
# Should have an early_stop event in history
|
||||
early_stop_events = [h for h in state.history if h["event"] == "early_stop"]
|
||||
assert len(early_stop_events) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_early_stop_does_not_trigger_when_improving(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""When validation keeps improving, early stopping does not trigger."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
evals = [
|
||||
_make_eval(0.3), # initial train
|
||||
_make_eval(0.3), # initial validation
|
||||
]
|
||||
# 3 iterations, each with improving validation
|
||||
for i in range(3):
|
||||
evals.extend([
|
||||
_make_eval(0.3 + i * 0.1), # current train eval
|
||||
_make_eval(0.4 + i * 0.1), # new train eval (accepted)
|
||||
_make_eval(0.3 + (i + 1) * 0.1), # validation eval (improving)
|
||||
])
|
||||
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
|
||||
validation_pool = synthetic_pool[-5:]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=5,
|
||||
early_stop_patience=5,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
validation_pool=validation_pool,
|
||||
)
|
||||
|
||||
assert state.early_stopped is False
|
||||
assert state.iteration == 3
|
||||
assert state.best_validation_score is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_patience_resets_on_improvement(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
mock_llm_port: AsyncMock,
|
||||
mock_judge_port: AsyncMock,
|
||||
mock_proposer_port: AsyncMock,
|
||||
) -> None:
|
||||
"""Patience counter resets when validation improves after degrading."""
|
||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
evals = [
|
||||
_make_eval(0.3), # initial train
|
||||
_make_eval(0.3), # initial validation
|
||||
# iter 1: degrade
|
||||
_make_eval(0.3), # current train
|
||||
_make_eval(0.5), # new train (accepted)
|
||||
_make_eval(0.2), # validation degrade (patience=1)
|
||||
# iter 2: degrade
|
||||
_make_eval(0.5), # current train
|
||||
_make_eval(0.6), # new train (accepted)
|
||||
_make_eval(0.2), # validation degrade (patience=2)
|
||||
# iter 3: improve! (resets patience)
|
||||
_make_eval(0.6), # current train
|
||||
_make_eval(0.7), # new train (accepted)
|
||||
_make_eval(0.4), # validation improve (patience=0)
|
||||
# iter 4: degrade again
|
||||
_make_eval(0.7), # current train
|
||||
_make_eval(0.8), # new train (accepted)
|
||||
_make_eval(0.2), # validation degrade (patience=1)
|
||||
]
|
||||
|
||||
evaluator.evaluate = AsyncMock(side_effect=evals)
|
||||
validation_pool = synthetic_pool[-5:]
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=mock_proposer_port,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=4,
|
||||
minibatch_size=5,
|
||||
early_stop_patience=3,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
validation_pool=validation_pool,
|
||||
)
|
||||
|
||||
assert state.early_stopped is False
|
||||
assert state.iteration == 4
|
||||
189
tests/unit/test_logging.py
Normal file
189
tests/unit/test_logging.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Unit tests for structured logging configuration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from prometheus.cli.logging_setup import configure_logging, get_logger
|
||||
|
||||
|
||||
class TestConfigureLogging:
|
||||
def _count_handlers(self, name: str = "prometheus") -> int:
|
||||
return len(logging.getLogger(name).handlers)
|
||||
|
||||
def test_default_creates_console_handler(self) -> None:
|
||||
configure_logging(level=logging.INFO)
|
||||
prom = logging.getLogger("prometheus")
|
||||
assert len(prom.handlers) == 1
|
||||
assert isinstance(prom.handlers[0], logging.StreamHandler)
|
||||
prom.handlers.clear()
|
||||
|
||||
def test_json_format_produces_valid_json(self, capsys) -> None:
|
||||
configure_logging(level=logging.INFO, log_format="json")
|
||||
logger = get_logger("test_json")
|
||||
logger.info("hello", extra={"structured": {"key": "value"}})
|
||||
|
||||
captured = capsys.readouterr()
|
||||
# Output goes to stderr
|
||||
line = captured.err.strip().split("\n")[-1]
|
||||
data = json.loads(line)
|
||||
assert data["message"] == "hello"
|
||||
assert data["structured"]["key"] == "value"
|
||||
assert data["level"] == "INFO"
|
||||
assert "timestamp" in data
|
||||
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
def test_text_format_includes_structured_extras(self, capsys) -> None:
|
||||
configure_logging(level=logging.INFO, log_format="text")
|
||||
logger = get_logger("test_text")
|
||||
logger.info("msg", extra={"structured": {"foo": "bar"}})
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "foo=bar" in captured.err
|
||||
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
def test_debug_level_shows_debug_messages(self, capsys) -> None:
|
||||
configure_logging(level=logging.DEBUG)
|
||||
logger = get_logger("test_debug")
|
||||
logger.debug("debug msg")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "debug msg" in captured.err
|
||||
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
def test_warning_level_hides_debug_messages(self, capsys) -> None:
|
||||
configure_logging(level=logging.WARNING)
|
||||
logger = get_logger("test_warn")
|
||||
logger.debug("should not appear")
|
||||
logger.info("also hidden")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "should not appear" not in captured.err
|
||||
assert "also hidden" not in captured.err
|
||||
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
def test_file_handler_writes_to_file(self, tmp_path: Path) -> None:
|
||||
log_file = tmp_path / "test.log"
|
||||
configure_logging(level=logging.INFO, log_file=str(log_file))
|
||||
logger = get_logger("test_file")
|
||||
logger.info("file message")
|
||||
|
||||
prom = logging.getLogger("prometheus")
|
||||
# Flush handlers
|
||||
for h in prom.handlers:
|
||||
h.flush()
|
||||
prom.handlers.clear()
|
||||
|
||||
content = log_file.read_text()
|
||||
assert "file message" in content
|
||||
|
||||
def test_json_file_output(self, tmp_path: Path) -> None:
|
||||
log_file = tmp_path / "test.json.log"
|
||||
configure_logging(level=logging.INFO, log_format="json", log_file=str(log_file))
|
||||
logger = get_logger("test_json_file")
|
||||
logger.info("json file msg", extra={"structured": {"x": 1}})
|
||||
|
||||
prom = logging.getLogger("prometheus")
|
||||
for h in prom.handlers:
|
||||
h.flush()
|
||||
prom.handlers.clear()
|
||||
|
||||
content = log_file.read_text().strip()
|
||||
data = json.loads(content)
|
||||
assert data["message"] == "json file msg"
|
||||
assert data["structured"]["x"] == 1
|
||||
|
||||
def test_reconfigure_clears_old_handlers(self) -> None:
|
||||
configure_logging(level=logging.INFO)
|
||||
configure_logging(level=logging.DEBUG)
|
||||
prom = logging.getLogger("prometheus")
|
||||
assert len(prom.handlers) == 1
|
||||
prom.handlers.clear()
|
||||
|
||||
def test_propagate_false_prevents_duplicate_output(self, capsys) -> None:
|
||||
configure_logging(level=logging.INFO)
|
||||
prom = logging.getLogger("prometheus")
|
||||
assert prom.propagate is False
|
||||
prom.handlers.clear()
|
||||
|
||||
|
||||
class TestGetLogger:
|
||||
def test_returns_child_of_prometheus(self) -> None:
|
||||
logger = get_logger("mymodule")
|
||||
assert logger.name == "prometheus.mymodule"
|
||||
|
||||
def test_inherits_level_from_parent(self) -> None:
|
||||
configure_logging(level=logging.DEBUG)
|
||||
logger = get_logger("child")
|
||||
assert logger.getEffectiveLevel() <= logging.DEBUG
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
|
||||
class TestJsonFormatter:
|
||||
def test_exception_included(self, capsys) -> None:
|
||||
configure_logging(level=logging.ERROR, log_format="json")
|
||||
logger = get_logger("test_exc")
|
||||
try:
|
||||
raise ValueError("boom")
|
||||
except ValueError:
|
||||
logger.error("failed", exc_info=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
line = captured.err.strip().split("\n")[-1]
|
||||
data = json.loads(line)
|
||||
assert "ValueError: boom" in data["exception"]
|
||||
|
||||
logging.getLogger("prometheus").handlers.clear()
|
||||
|
||||
|
||||
class TestLoggingCLIIntegration:
|
||||
"""Tests for CLI flags that configure logging."""
|
||||
|
||||
def test_verbose_flag_enables_info(self, tmp_path: Path) -> None:
|
||||
"""Simulate what -v does — configure_logging at INFO level."""
|
||||
configure_logging(level=logging.INFO)
|
||||
logger = get_logger("evolution")
|
||||
logger.info("test message")
|
||||
|
||||
prom = logging.getLogger("prometheus")
|
||||
assert len(prom.handlers) == 1
|
||||
prom.handlers.clear()
|
||||
|
||||
def test_debug_flag_enables_debug(self) -> None:
|
||||
"""Simulate what --debug does — configure_logging at DEBUG level."""
|
||||
configure_logging(level=logging.DEBUG)
|
||||
logger = get_logger("evolution")
|
||||
logger.debug("debug message")
|
||||
|
||||
prom = logging.getLogger("prometheus")
|
||||
assert prom.level == logging.DEBUG
|
||||
prom.handlers.clear()
|
||||
|
||||
def test_log_format_invalid_rejected(self) -> None:
|
||||
"""Invalid log_format should be caught by OptimizationConfig validator."""
|
||||
from pydantic import ValidationError
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
|
||||
import pytest
|
||||
|
||||
with pytest.raises(ValidationError, match="log_format must be one of"):
|
||||
OptimizationConfig(
|
||||
seed_prompt="a",
|
||||
task_description="b",
|
||||
log_format="xml",
|
||||
)
|
||||
|
||||
def test_log_format_text_and_json_accepted(self) -> None:
|
||||
"""Both text and json log_format values should be valid."""
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
|
||||
for fmt in ("text", "json"):
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="a", task_description="b", log_format=fmt,
|
||||
)
|
||||
assert config.log_format == fmt
|
||||
96
tests/unit/test_scoring_extended.py
Normal file
96
tests/unit/test_scoring_extended.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Additional unit tests for scoring edge cases."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.domain.entities import EvalResult, Trajectory
|
||||
from prometheus.domain.scoring import normalize_score, should_accept
|
||||
|
||||
|
||||
def _make_eval(scores: list[float]) -> EvalResult:
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=[""] * len(scores),
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", s, "", "p")
|
||||
for i, s in enumerate(scores)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestShouldAcceptEdgeCases:
|
||||
"""Extended edge-case tests for should_accept."""
|
||||
|
||||
def test_tiny_improvement_accepted(self) -> None:
|
||||
old = _make_eval([0.5])
|
||||
new = _make_eval([0.5001])
|
||||
assert should_accept(old, new) is True
|
||||
|
||||
def test_tiny_improvement_below_threshold(self) -> None:
|
||||
old = _make_eval([0.5])
|
||||
new = _make_eval([0.5001])
|
||||
assert should_accept(old, new, min_improvement=0.01) is False
|
||||
|
||||
def test_zero_scores_equal(self) -> None:
|
||||
old = _make_eval([0.0, 0.0])
|
||||
new = _make_eval([0.0, 0.0])
|
||||
assert should_accept(old, new) is False
|
||||
|
||||
def test_negative_to_zero_not_accepted(self) -> None:
|
||||
"""Scores should be [0,1] but test should_accept with edge values."""
|
||||
old = _make_eval([-0.1])
|
||||
new = _make_eval([0.0])
|
||||
assert should_accept(old, new) is True
|
||||
|
||||
def test_large_improvement(self) -> None:
|
||||
old = _make_eval([0.0, 0.0, 0.0])
|
||||
new = _make_eval([1.0, 1.0, 1.0])
|
||||
assert should_accept(old, new) is True
|
||||
|
||||
def test_single_score_improvement(self) -> None:
|
||||
old = _make_eval([0.4])
|
||||
new = _make_eval([0.5])
|
||||
assert should_accept(old, new) is True
|
||||
|
||||
def test_min_improvement_exactly_met(self) -> None:
|
||||
"""When improvement exactly equals min_improvement, still rejected (strict >)."""
|
||||
old = _make_eval([0.5])
|
||||
new = _make_eval([0.7])
|
||||
assert should_accept(old, new, min_improvement=0.2) is False
|
||||
|
||||
def test_min_improvement_just_over(self) -> None:
|
||||
old = _make_eval([0.5])
|
||||
new = _make_eval([0.7001])
|
||||
assert should_accept(old, new, min_improvement=0.2) is True
|
||||
|
||||
|
||||
class TestNormalizeScoreEdgeCases:
|
||||
"""Extended edge-case tests for normalize_score."""
|
||||
|
||||
def test_exact_bounds(self) -> None:
|
||||
assert normalize_score(0.0) == 0.0
|
||||
assert normalize_score(1.0) == 1.0
|
||||
|
||||
def test_very_large_value(self) -> None:
|
||||
assert normalize_score(1e10) == 1.0
|
||||
|
||||
def test_very_negative_value(self) -> None:
|
||||
assert normalize_score(-1e10) == 0.0
|
||||
|
||||
def test_custom_bounds_at_edges(self) -> None:
|
||||
assert normalize_score(5.0, min_val=0.0, max_val=10.0) == 5.0
|
||||
assert normalize_score(0.0, min_val=0.0, max_val=10.0) == 0.0
|
||||
assert normalize_score(10.0, min_val=0.0, max_val=10.0) == 10.0
|
||||
|
||||
def test_negative_custom_range(self) -> None:
|
||||
assert normalize_score(0.0, min_val=-5.0, max_val=5.0) == 0.0
|
||||
assert normalize_score(-3.0, min_val=-5.0, max_val=5.0) == -3.0
|
||||
assert normalize_score(-10.0, min_val=-5.0, max_val=5.0) == -5.0
|
||||
|
||||
def test_zero_span_range(self) -> None:
|
||||
"""When min == max, clamps to min."""
|
||||
assert normalize_score(5.0, min_val=5.0, max_val=5.0) == 5.0
|
||||
assert normalize_score(0.0, min_val=5.0, max_val=5.0) == 5.0
|
||||
|
||||
def test_fractional_score(self) -> None:
|
||||
assert normalize_score(0.3333) == pytest.approx(0.3333)
|
||||
133
tests/unit/test_similarity.py
Normal file
133
tests/unit/test_similarity.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tests for similarity adapters — exact, BLEU, ROUGE-L, cosine."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.infrastructure.similarity import (
|
||||
BleuSimilarity,
|
||||
CosineSimilarity,
|
||||
ExactMatchSimilarity,
|
||||
RougeLSimilarity,
|
||||
create_similarity_adapter,
|
||||
)
|
||||
|
||||
|
||||
class TestExactMatchSimilarity:
|
||||
def test_exact_match(self):
|
||||
s = ExactMatchSimilarity()
|
||||
assert s.compute("Hello World", "Hello World") == 1.0
|
||||
|
||||
def test_case_insensitive(self):
|
||||
s = ExactMatchSimilarity()
|
||||
assert s.compute("hello world", "HELLO WORLD") == 1.0
|
||||
|
||||
def test_whitespace_trimmed(self):
|
||||
s = ExactMatchSimilarity()
|
||||
assert s.compute(" hello ", "hello") == 1.0
|
||||
|
||||
def test_no_match(self):
|
||||
s = ExactMatchSimilarity()
|
||||
assert s.compute("hello", "world") == 0.0
|
||||
|
||||
def test_partial_no_match(self):
|
||||
s = ExactMatchSimilarity()
|
||||
assert s.compute("hello world", "hello") == 0.0
|
||||
|
||||
|
||||
class TestBleuSimilarity:
|
||||
def test_perfect_match(self):
|
||||
s = BleuSimilarity()
|
||||
assert s.compute("the cat sat on the mat", "the cat sat on the mat") == 1.0
|
||||
|
||||
def test_no_overlap(self):
|
||||
s = BleuSimilarity()
|
||||
assert s.compute("aaa bbb ccc", "ddd eee fff") == 0.0
|
||||
|
||||
def test_partial_overlap(self):
|
||||
s = BleuSimilarity()
|
||||
score = s.compute("the cat sat", "the cat")
|
||||
assert 0.0 < score < 1.0
|
||||
|
||||
def test_empty_prediction(self):
|
||||
s = BleuSimilarity()
|
||||
assert s.compute("", "hello world") == 0.0
|
||||
|
||||
def test_empty_expected(self):
|
||||
s = BleuSimilarity()
|
||||
assert s.compute("hello world", "") == 0.0
|
||||
|
||||
def test_both_empty(self):
|
||||
s = BleuSimilarity()
|
||||
assert s.compute("", "") == 0.0
|
||||
|
||||
def test_shorter_prediction_gets_brevity_penalty(self):
|
||||
s = BleuSimilarity()
|
||||
short = s.compute("cat", "the cat sat on the mat")
|
||||
full = s.compute("the cat sat on the mat", "the cat sat on the mat")
|
||||
assert short < full
|
||||
|
||||
|
||||
class TestRougeLSimilarity:
|
||||
def test_perfect_match(self):
|
||||
s = RougeLSimilarity()
|
||||
assert s.compute("the cat sat", "the cat sat") == 1.0
|
||||
|
||||
def test_no_overlap(self):
|
||||
s = RougeLSimilarity()
|
||||
assert s.compute("aaa bbb", "ccc ddd") == 0.0
|
||||
|
||||
def test_partial_overlap(self):
|
||||
s = RougeLSimilarity()
|
||||
score = s.compute("the cat sat on the mat", "the cat on the rug")
|
||||
assert 0.0 < score < 1.0
|
||||
|
||||
def test_empty_prediction(self):
|
||||
s = RougeLSimilarity()
|
||||
assert s.compute("", "hello") == 0.0
|
||||
|
||||
def test_subsequence(self):
|
||||
s = RougeLSimilarity()
|
||||
# "cat mat" is a subsequence of "the cat sat on the mat"
|
||||
score = s.compute("cat mat", "the cat sat on the mat")
|
||||
assert score > 0.0
|
||||
|
||||
|
||||
class TestCosineSimilarity:
|
||||
def test_identical_texts(self):
|
||||
s = CosineSimilarity()
|
||||
assert s.compute("hello world", "hello world") == pytest.approx(1.0)
|
||||
|
||||
def test_no_overlap(self):
|
||||
s = CosineSimilarity()
|
||||
assert s.compute("aaa bbb", "ccc ddd") == 0.0
|
||||
|
||||
def test_partial_overlap(self):
|
||||
s = CosineSimilarity()
|
||||
score = s.compute("hello world foo", "hello world bar")
|
||||
assert 0.0 < score < 1.0
|
||||
|
||||
def test_empty_prediction(self):
|
||||
s = CosineSimilarity()
|
||||
assert s.compute("", "hello") == 0.0
|
||||
|
||||
|
||||
class TestCreateSimilarityAdapter:
|
||||
def test_create_exact(self):
|
||||
adapter = create_similarity_adapter("exact")
|
||||
assert isinstance(adapter, ExactMatchSimilarity)
|
||||
|
||||
def test_create_bleu(self):
|
||||
adapter = create_similarity_adapter("bleu")
|
||||
assert isinstance(adapter, BleuSimilarity)
|
||||
|
||||
def test_create_rouge_l(self):
|
||||
adapter = create_similarity_adapter("rouge_l")
|
||||
assert isinstance(adapter, RougeLSimilarity)
|
||||
|
||||
def test_create_cosine(self):
|
||||
adapter = create_similarity_adapter("cosine")
|
||||
assert isinstance(adapter, CosineSimilarity)
|
||||
|
||||
def test_unknown_metric_raises(self):
|
||||
with pytest.raises(ValueError, match="Unknown eval metric"):
|
||||
create_similarity_adapter("nonexistent")
|
||||
233
tests/unit/test_use_cases.py
Normal file
233
tests/unit/test_use_cases.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Unit tests for OptimizePromptUseCase — direct orchestration tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.dto import OptimizationConfig, OptimizationResult
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.application.use_cases import OptimizePromptUseCase
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
EvalResult,
|
||||
OptimizationState,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
Trajectory,
|
||||
)
|
||||
|
||||
|
||||
def _make_eval(scores: list[float]) -> EvalResult:
|
||||
return EvalResult(
|
||||
scores=scores,
|
||||
feedbacks=["feedback"] * len(scores),
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", s, "feedback", "prompt")
|
||||
for i, s in enumerate(scores)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _make_state(
|
||||
iterations: int = 3,
|
||||
initial_score: float = 0.3,
|
||||
final_score: float = 0.8,
|
||||
accepted: bool = True,
|
||||
) -> OptimizationState:
|
||||
seed = Candidate(prompt=Prompt(text="seed"), best_score=initial_score, generation=0)
|
||||
best = Candidate(
|
||||
prompt=Prompt(text="optimized" if accepted else "seed"),
|
||||
best_score=final_score,
|
||||
generation=iterations if accepted else 0,
|
||||
)
|
||||
history = []
|
||||
for i in range(1, iterations + 1):
|
||||
event = "accepted" if accepted else "rejected"
|
||||
history.append({"iteration": i, "event": event, "old_score": 0.3, "new_score": 0.8})
|
||||
|
||||
return OptimizationState(
|
||||
iteration=iterations,
|
||||
best_candidate=best,
|
||||
candidates=[seed, best] if accepted else [seed],
|
||||
total_llm_calls=iterations * 11 + 10,
|
||||
history=history,
|
||||
)
|
||||
|
||||
|
||||
class TestOptimizePromptUseCaseExecute:
|
||||
"""Tests for the execute() orchestration method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_evaluator(self) -> MagicMock:
|
||||
return MagicMock(spec=PromptEvaluator)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_proposer(self) -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bootstrap(self) -> MagicMock:
|
||||
return MagicMock(spec=SyntheticBootstrap)
|
||||
|
||||
@pytest.fixture
|
||||
def use_case(
|
||||
self,
|
||||
mock_evaluator: MagicMock,
|
||||
mock_proposer: MagicMock,
|
||||
mock_bootstrap: MagicMock,
|
||||
) -> OptimizePromptUseCase:
|
||||
return OptimizePromptUseCase(
|
||||
evaluator=mock_evaluator,
|
||||
proposer=mock_proposer,
|
||||
bootstrap=mock_bootstrap,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> OptimizationConfig:
|
||||
return OptimizationConfig(
|
||||
seed_prompt="Answer the question.",
|
||||
task_description="Q&A task",
|
||||
max_iterations=5,
|
||||
n_synthetic_inputs=20,
|
||||
minibatch_size=5,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_optimization_result(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = [
|
||||
SyntheticExample(input_text=f"q{i}", id=i) for i in range(20)
|
||||
]
|
||||
|
||||
mock_state = _make_state(iterations=3, initial_score=0.3, final_score=0.9)
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert isinstance(result, OptimizationResult)
|
||||
assert result.initial_prompt == "Answer the question."
|
||||
assert result.final_score == 0.9
|
||||
assert result.improvement == pytest.approx(0.6)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bootstrap_called_with_config_params(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = []
|
||||
mock_state = _make_state()
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
await use_case.execute(config)
|
||||
|
||||
mock_bootstrap.run.assert_called_once_with(
|
||||
task_description="Q&A task",
|
||||
n_examples=20,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_evolution_loop_configured_from_config(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = []
|
||||
mock_state = _make_state()
|
||||
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state) as mock_run:
|
||||
await use_case.execute(config)
|
||||
|
||||
# Verify the loop was instantiated with correct params
|
||||
mock_run.assert_called_once()
|
||||
call_args = mock_run.call_args
|
||||
seed_prompt = call_args[0][0]
|
||||
assert seed_prompt.text == "Answer the question."
|
||||
synthetic_pool = call_args[0][1]
|
||||
assert len(synthetic_pool) == 0 # bootstrap returned empty
|
||||
assert call_args[0][2] == "Q&A task"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_total_llm_calls_includes_bootstrap_call(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = []
|
||||
mock_state = _make_state(iterations=3)
|
||||
# total_llm_calls from state + 1 for bootstrap
|
||||
expected = mock_state.total_llm_calls + 1
|
||||
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert result.total_llm_calls == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_candidates_fallback(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = [
|
||||
SyntheticExample(input_text=f"q{i}", id=i) for i in range(20)
|
||||
]
|
||||
mock_state = OptimizationState(
|
||||
iteration=0,
|
||||
best_candidate=None,
|
||||
candidates=[],
|
||||
total_llm_calls=0,
|
||||
)
|
||||
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert result.optimized_prompt == "Answer the question."
|
||||
assert result.initial_score == 0.0
|
||||
assert result.final_score == 0.0
|
||||
assert result.improvement == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_iterations_used_matches_state(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = []
|
||||
mock_state = _make_state(iterations=7)
|
||||
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert result.iterations_used == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_passed_through(
|
||||
self,
|
||||
use_case: OptimizePromptUseCase,
|
||||
mock_bootstrap: MagicMock,
|
||||
config: OptimizationConfig,
|
||||
) -> None:
|
||||
mock_bootstrap.run.return_value = []
|
||||
history = [
|
||||
{"iteration": 1, "event": "accepted"},
|
||||
{"iteration": 2, "event": "rejected"},
|
||||
]
|
||||
mock_state = _make_state()
|
||||
mock_state.history = history
|
||||
|
||||
with patch.object(EvolutionLoop, "run", return_value=mock_state):
|
||||
result = await use_case.execute(config)
|
||||
|
||||
assert result.history == history
|
||||
Reference in New Issue
Block a user