feat: v0.2.0 sprint — ground truth eval, crossover/mutation, checkpointing, similarity guards, dataset loader, CLI commands, extended test coverage
Aggregates all v0.2.0 sprint work (GARAA-30 through GARAA-40) and fixes 2 integration tests that broke when the codebase went async (DSPyLLMAdapter and full pipeline tests now properly await coroutines). 277 tests pass (260 unit + 17 integration). Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
333
tests/unit/test_checkpoint.py
Normal file
333
tests/unit/test_checkpoint.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Unit tests for checkpoint & resume functionality."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import EvolutionLoop
|
||||
from prometheus.domain.entities import (
|
||||
Candidate,
|
||||
EvalResult,
|
||||
OptimizationState,
|
||||
Prompt,
|
||||
SyntheticExample,
|
||||
Trajectory,
|
||||
)
|
||||
from prometheus.infrastructure.checkpoint import JsonCheckpointPersistence
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JsonCheckpointPersistence — save/load round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJsonCheckpointPersistence:
|
||||
def test_roundtrip_full_state(self, tmp_path: Path) -> None:
|
||||
"""Saving and loading preserves all fields."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
|
||||
|
||||
state = OptimizationState(
|
||||
iteration=7,
|
||||
best_candidate=Candidate(
|
||||
prompt=Prompt(text="best prompt", metadata={"source": "test"}),
|
||||
best_score=0.92,
|
||||
generation=5,
|
||||
),
|
||||
candidates=[
|
||||
Candidate(prompt=Prompt(text="p1"), best_score=0.5, generation=0),
|
||||
Candidate(prompt=Prompt(text="p2"), best_score=0.92, generation=5),
|
||||
],
|
||||
synthetic_pool=[
|
||||
SyntheticExample(input_text="q1", category="cat_a", id=0),
|
||||
SyntheticExample(input_text="q2", category="cat_b", id=1),
|
||||
],
|
||||
history=[{"iteration": 1, "event": "accepted", "old_score": 0.5, "new_score": 0.7}],
|
||||
total_llm_calls=42,
|
||||
)
|
||||
|
||||
ckpt.save(state)
|
||||
assert ckpt.latest_exists()
|
||||
|
||||
loaded = ckpt.load()
|
||||
assert loaded is not None
|
||||
assert loaded.iteration == 7
|
||||
assert loaded.total_llm_calls == 42
|
||||
assert loaded.best_candidate is not None
|
||||
assert loaded.best_candidate.prompt.text == "best prompt"
|
||||
assert loaded.best_candidate.prompt.metadata == {"source": "test"}
|
||||
assert loaded.best_candidate.best_score == 0.92
|
||||
assert len(loaded.candidates) == 2
|
||||
assert len(loaded.synthetic_pool) == 2
|
||||
assert loaded.synthetic_pool[0].input_text == "q1"
|
||||
assert loaded.synthetic_pool[1].category == "cat_b"
|
||||
assert loaded.history[0]["event"] == "accepted"
|
||||
|
||||
def test_load_returns_none_when_no_checkpoint(self, tmp_path: Path) -> None:
|
||||
"""Loading from empty dir returns None."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "nope")
|
||||
assert ckpt.load() is None
|
||||
assert not ckpt.latest_exists()
|
||||
|
||||
def test_creates_directory_on_save(self, tmp_path: Path) -> None:
|
||||
"""Save creates the directory tree if it doesn't exist."""
|
||||
deep_dir = tmp_path / "a" / "b" / "c"
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=deep_dir)
|
||||
state = OptimizationState(iteration=1)
|
||||
ckpt.save(state)
|
||||
assert (deep_dir / "latest.json").exists()
|
||||
|
||||
def test_overwrites_previous_checkpoint(self, tmp_path: Path) -> None:
|
||||
"""Second save overwrites the first."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
|
||||
|
||||
ckpt.save(OptimizationState(iteration=1, total_llm_calls=10))
|
||||
ckpt.save(OptimizationState(iteration=5, total_llm_calls=50))
|
||||
|
||||
loaded = ckpt.load()
|
||||
assert loaded is not None
|
||||
assert loaded.iteration == 5
|
||||
assert loaded.total_llm_calls == 50
|
||||
|
||||
def test_json_is_human_readable(self, tmp_path: Path) -> None:
|
||||
"""Checkpoint file is valid, pretty-printed JSON."""
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
|
||||
state = OptimizationState(
|
||||
iteration=3,
|
||||
best_candidate=Candidate(prompt=Prompt(text="hello"), best_score=0.8),
|
||||
)
|
||||
ckpt.save(state)
|
||||
|
||||
raw = json.loads((tmp_path / "latest.json").read_text())
|
||||
assert raw["schema_version"] == 1
|
||||
assert raw["iteration"] == 3
|
||||
assert raw["best_candidate"]["prompt_text"] == "hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EvolutionLoop — checkpoint integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEvolutionCheckpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint_saved_on_interval(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""Checkpoint is saved every checkpoint_interval iterations."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
# All iterations accepted so checkpoint triggers
|
||||
good_eval = EvalResult(
|
||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"input{i}", f"out{i}", s, "ok", "p")
|
||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||
],
|
||||
)
|
||||
better_eval = EvalResult(
|
||||
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
# initial_eval + 5 iterations (each needs old_eval + new_eval)
|
||||
evaluator.evaluate = AsyncMock(
|
||||
side_effect=[good_eval] # initial
|
||||
+ [good_eval, better_eval] * 5 # 5 iterations
|
||||
)
|
||||
|
||||
proposer = AsyncMock()
|
||||
proposer.propose.return_value = Prompt(text="improved prompt")
|
||||
|
||||
checkpoint_port = MagicMock()
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=5,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=checkpoint_port,
|
||||
checkpoint_interval=2,
|
||||
)
|
||||
|
||||
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
# Checkpoint at iterations 2, 4 (every 2nd)
|
||||
save_calls = checkpoint_port.save.call_count
|
||||
assert save_calls >= 2 # at least at iters 2 and 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_checkpoint_without_port(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""No checkpointing happens when checkpoint_port is None (default)."""
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
perfect_eval = EvalResult(
|
||||
scores=[1.0] * 5,
|
||||
feedbacks=["perfect"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", 1.0, "perfect", "p")
|
||||
for i in range(5)
|
||||
],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=AsyncMock(),
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=None,
|
||||
)
|
||||
# Should run without error — no checkpoint port, no crash
|
||||
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_skips_seed_evaluation(
|
||||
self,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
) -> None:
|
||||
"""When initial_state is provided, seed eval is skipped and loop starts from saved iteration."""
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
proposer = AsyncMock()
|
||||
proposer.propose.return_value = Prompt(text="new prompt")
|
||||
|
||||
# Only return evaluations for resumed iterations (1 iter: old_eval + new_eval)
|
||||
old_eval = EvalResult(
|
||||
scores=[0.5] * 5,
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", 0.5, "ok", "p") for i in range(5)
|
||||
],
|
||||
)
|
||||
new_eval = EvalResult(
|
||||
scores=[0.8] * 5,
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
|
||||
|
||||
# Create a state simulating checkpoint at iteration 4
|
||||
initial_state = OptimizationState(
|
||||
iteration=4,
|
||||
best_candidate=Candidate(
|
||||
prompt=Prompt(text="checkpoint prompt"), best_score=2.5, generation=4
|
||||
),
|
||||
candidates=[Candidate(prompt=Prompt(text="checkpoint prompt"), best_score=2.5)],
|
||||
total_llm_calls=40,
|
||||
)
|
||||
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=5, # only iteration 5 remains
|
||||
minibatch_size=5,
|
||||
)
|
||||
state = await loop.run(
|
||||
seed_prompt=Prompt(text="seed"),
|
||||
synthetic_pool=synthetic_pool,
|
||||
task_description=task_description,
|
||||
initial_state=initial_state,
|
||||
)
|
||||
|
||||
# Should have run only 1 iteration (iter 5)
|
||||
assert state.iteration == 5
|
||||
# total_llm_calls should include the 40 from checkpoint + new calls
|
||||
assert state.total_llm_calls > 40
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_save_and_resume_roundtrip(
|
||||
self,
|
||||
seed_prompt: Prompt,
|
||||
synthetic_pool: list[SyntheticExample],
|
||||
task_description: str,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""End-to-end: run a few iterations, checkpoint, resume, finish."""
|
||||
evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
|
||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||
bootstrap.sample_minibatch.return_value = synthetic_pool[:5]
|
||||
|
||||
old_eval = EvalResult(
|
||||
scores=[0.3, 0.4, 0.3, 0.5, 0.2],
|
||||
feedbacks=["ok"] * 5,
|
||||
trajectories=[
|
||||
Trajectory(f"in{i}", f"out{i}", s, "ok", "p")
|
||||
for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
|
||||
],
|
||||
)
|
||||
new_eval = EvalResult(
|
||||
scores=[0.8, 0.9, 0.7, 0.8, 0.9],
|
||||
feedbacks=["good"] * 5,
|
||||
trajectories=[],
|
||||
)
|
||||
evaluator.evaluate = AsyncMock(
|
||||
side_effect=[old_eval, old_eval, new_eval, old_eval, new_eval]
|
||||
)
|
||||
proposer = AsyncMock()
|
||||
proposer.propose.return_value = Prompt(text="improved prompt")
|
||||
|
||||
ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
|
||||
loop = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=2,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=ckpt,
|
||||
checkpoint_interval=1,
|
||||
)
|
||||
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||
assert state.iteration == 2
|
||||
assert ckpt.latest_exists()
|
||||
|
||||
# Capture the checkpoint state *before* resume (state is mutated in-place)
|
||||
loaded = ckpt.load()
|
||||
assert loaded is not None
|
||||
saved_llm_calls = loaded.total_llm_calls
|
||||
saved_iteration = loaded.iteration
|
||||
|
||||
# Set up evaluator for resumed run (just 1 more iteration)
|
||||
evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
|
||||
proposer.propose.return_value = Prompt(text="even better prompt")
|
||||
|
||||
loop2 = EvolutionLoop(
|
||||
evaluator=evaluator,
|
||||
proposer=proposer,
|
||||
bootstrap=bootstrap,
|
||||
max_iterations=3,
|
||||
minibatch_size=5,
|
||||
checkpoint_port=ckpt,
|
||||
checkpoint_interval=1,
|
||||
)
|
||||
resumed = await loop2.run(
|
||||
seed_prompt, synthetic_pool, task_description,
|
||||
initial_state=loaded,
|
||||
)
|
||||
assert resumed.iteration == 3
|
||||
assert resumed.total_llm_calls > saved_llm_calls
|
||||
assert resumed.iteration > saved_iteration
|
||||
Reference in New Issue
Block a user