Aggregates all v0.2.0 sprint work (GARAA-30 through GARAA-40) and fixes 2 integration tests that broke when the codebase went async (DSPyLLMAdapter and full pipeline tests now properly await coroutines). 277 tests pass (260 unit + 17 integration). Co-Authored-By: Paperclip <noreply@paperclip.ing>
334 lines
12 KiB
Python
334 lines
12 KiB
Python
"""Unit tests for checkpoint & resume functionality."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
|
from prometheus.application.evaluator import PromptEvaluator
|
|
from prometheus.application.evolution import EvolutionLoop
|
|
from prometheus.domain.entities import (
|
|
Candidate,
|
|
EvalResult,
|
|
OptimizationState,
|
|
Prompt,
|
|
SyntheticExample,
|
|
Trajectory,
|
|
)
|
|
from prometheus.infrastructure.checkpoint import JsonCheckpointPersistence
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# JsonCheckpointPersistence — save/load round-trip
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestJsonCheckpointPersistence:
    """Save/load tests for ``JsonCheckpointPersistence`` using pytest's ``tmp_path``."""

    def test_roundtrip_full_state(self, tmp_path: Path) -> None:
        """Saving and then loading returns a state matching the saved fields."""
        ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")

        state = OptimizationState(
            iteration=7,
            best_candidate=Candidate(
                prompt=Prompt(text="best prompt", metadata={"source": "test"}),
                best_score=0.92,
                generation=5,
            ),
            candidates=[
                Candidate(prompt=Prompt(text="p1"), best_score=0.5, generation=0),
                Candidate(prompt=Prompt(text="p2"), best_score=0.92, generation=5),
            ],
            synthetic_pool=[
                SyntheticExample(input_text="q1", category="cat_a", id=0),
                SyntheticExample(input_text="q2", category="cat_b", id=1),
            ],
            history=[
                {"iteration": 1, "event": "accepted", "old_score": 0.5, "new_score": 0.7},
            ],
            total_llm_calls=42,
        )

        ckpt.save(state)
        assert ckpt.latest_exists()

        loaded = ckpt.load()
        assert loaded is not None
        # Scalar counters.
        assert loaded.iteration == 7
        assert loaded.total_llm_calls == 42
        # Best candidate, including the nested prompt metadata.
        assert loaded.best_candidate is not None
        assert loaded.best_candidate.prompt.text == "best prompt"
        assert loaded.best_candidate.prompt.metadata == {"source": "test"}
        assert loaded.best_candidate.best_score == 0.92
        # Collections keep their length and spot-checked contents.
        assert len(loaded.candidates) == 2
        assert len(loaded.synthetic_pool) == 2
        assert loaded.synthetic_pool[0].input_text == "q1"
        assert loaded.synthetic_pool[1].category == "cat_b"
        assert loaded.history[0]["event"] == "accepted"

    def test_load_returns_none_when_no_checkpoint(self, tmp_path: Path) -> None:
        """Loading from a directory that was never written to returns None."""
        ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "nope")
        assert ckpt.load() is None
        assert not ckpt.latest_exists()

    def test_creates_directory_on_save(self, tmp_path: Path) -> None:
        """Save creates the directory tree if it doesn't exist."""
        deep_dir = tmp_path / "a" / "b" / "c"
        ckpt = JsonCheckpointPersistence(checkpoint_dir=deep_dir)
        state = OptimizationState(iteration=1)
        ckpt.save(state)
        assert (deep_dir / "latest.json").exists()

    def test_overwrites_previous_checkpoint(self, tmp_path: Path) -> None:
        """Second save overwrites the first."""
        ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)

        ckpt.save(OptimizationState(iteration=1, total_llm_calls=10))
        ckpt.save(OptimizationState(iteration=5, total_llm_calls=50))

        loaded = ckpt.load()
        assert loaded is not None
        assert loaded.iteration == 5
        assert loaded.total_llm_calls == 50

    def test_json_is_human_readable(self, tmp_path: Path) -> None:
        """Checkpoint file parses as JSON and exposes the expected top-level keys."""
        ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path)
        state = OptimizationState(
            iteration=3,
            best_candidate=Candidate(prompt=Prompt(text="hello"), best_score=0.8),
        )
        ckpt.save(state)

        # Read the file directly, bypassing ckpt.load(), to check the on-disk keys.
        raw = json.loads((tmp_path / "latest.json").read_text())
        assert raw["schema_version"] == 1
        assert raw["iteration"] == 3
        assert raw["best_candidate"]["prompt_text"] == "hello"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# EvolutionLoop — checkpoint integration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEvolutionCheckpoint:
    """Checkpoint/resume tests for ``EvolutionLoop`` with mocked collaborators.

    ``seed_prompt``, ``synthetic_pool`` and ``task_description`` are pytest
    fixtures defined outside this module.
    """

    @pytest.mark.asyncio
    async def test_checkpoint_saved_on_interval(
        self,
        seed_prompt: Prompt,
        synthetic_pool: list[SyntheticExample],
        task_description: str,
    ) -> None:
        """Checkpoint is saved every checkpoint_interval iterations."""
        evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
        bootstrap = MagicMock(spec=SyntheticBootstrap)
        bootstrap.sample_minibatch.return_value = synthetic_pool[:5]

        # All iterations accepted so checkpoint triggers
        good_eval = EvalResult(
            scores=[0.3, 0.4, 0.3, 0.5, 0.2],
            feedbacks=["ok"] * 5,
            trajectories=[
                Trajectory(f"input{i}", f"out{i}", s, "ok", "p")
                for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
            ],
        )
        better_eval = EvalResult(
            scores=[0.8, 0.9, 0.7, 0.8, 0.9],
            feedbacks=["good"] * 5,
            trajectories=[],
        )
        # initial_eval + 5 iterations (each needs old_eval + new_eval)
        evaluator.evaluate = AsyncMock(
            side_effect=[good_eval]  # initial
            + [good_eval, better_eval] * 5  # 5 iterations
        )

        proposer = AsyncMock()
        proposer.propose.return_value = Prompt(text="improved prompt")

        checkpoint_port = MagicMock()
        loop = EvolutionLoop(
            evaluator=evaluator,
            proposer=proposer,
            bootstrap=bootstrap,
            max_iterations=5,
            minibatch_size=5,
            checkpoint_port=checkpoint_port,
            checkpoint_interval=2,
        )

        await loop.run(seed_prompt, synthetic_pool, task_description)

        # Checkpoint at iterations 2, 4 (every 2nd)
        save_calls = checkpoint_port.save.call_count
        assert save_calls >= 2  # at least at iters 2 and 4

    @pytest.mark.asyncio
    async def test_no_checkpoint_without_port(
        self,
        seed_prompt: Prompt,
        synthetic_pool: list[SyntheticExample],
        task_description: str,
    ) -> None:
        """The loop runs to completion when checkpoint_port is None."""
        evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
        bootstrap = MagicMock(spec=SyntheticBootstrap)
        bootstrap.sample_minibatch.return_value = synthetic_pool[:5]

        perfect_eval = EvalResult(
            scores=[1.0] * 5,
            feedbacks=["perfect"] * 5,
            trajectories=[
                Trajectory(f"in{i}", f"out{i}", 1.0, "perfect", "p")
                for i in range(5)
            ],
        )
        evaluator.evaluate = AsyncMock(return_value=perfect_eval)

        loop = EvolutionLoop(
            evaluator=evaluator,
            proposer=AsyncMock(),
            bootstrap=bootstrap,
            max_iterations=3,
            minibatch_size=5,
            checkpoint_port=None,
        )
        # Should run without error — no checkpoint port, no crash
        await loop.run(seed_prompt, synthetic_pool, task_description)

    @pytest.mark.asyncio
    async def test_resume_skips_seed_evaluation(
        self,
        synthetic_pool: list[SyntheticExample],
        task_description: str,
    ) -> None:
        """Resuming from initial_state continues from the saved iteration.

        Only two evaluation results are queued (old_eval + new_eval for a single
        iteration), so the run has no result available for a separate seed
        evaluation.
        """
        evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
        bootstrap = MagicMock(spec=SyntheticBootstrap)
        bootstrap.sample_minibatch.return_value = synthetic_pool[:5]

        proposer = AsyncMock()
        proposer.propose.return_value = Prompt(text="new prompt")

        # Only return evaluations for resumed iterations (1 iter: old_eval + new_eval)
        old_eval = EvalResult(
            scores=[0.5] * 5,
            feedbacks=["ok"] * 5,
            trajectories=[
                Trajectory(f"in{i}", f"out{i}", 0.5, "ok", "p") for i in range(5)
            ],
        )
        new_eval = EvalResult(
            scores=[0.8] * 5,
            feedbacks=["good"] * 5,
            trajectories=[],
        )
        evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])

        # Create a state simulating checkpoint at iteration 4
        initial_state = OptimizationState(
            iteration=4,
            best_candidate=Candidate(
                prompt=Prompt(text="checkpoint prompt"), best_score=2.5, generation=4
            ),
            candidates=[Candidate(prompt=Prompt(text="checkpoint prompt"), best_score=2.5)],
            total_llm_calls=40,
        )

        loop = EvolutionLoop(
            evaluator=evaluator,
            proposer=proposer,
            bootstrap=bootstrap,
            max_iterations=5,  # only iteration 5 remains
            minibatch_size=5,
        )
        state = await loop.run(
            seed_prompt=Prompt(text="seed"),
            synthetic_pool=synthetic_pool,
            task_description=task_description,
            initial_state=initial_state,
        )

        # Should have run only 1 iteration (iter 5)
        assert state.iteration == 5
        # total_llm_calls should include the 40 from checkpoint + new calls
        assert state.total_llm_calls > 40

    @pytest.mark.asyncio
    async def test_full_save_and_resume_roundtrip(
        self,
        seed_prompt: Prompt,
        synthetic_pool: list[SyntheticExample],
        task_description: str,
        tmp_path: Path,
    ) -> None:
        """End-to-end: run a few iterations, checkpoint, resume, finish."""
        evaluator = PromptEvaluator(AsyncMock(), AsyncMock())
        bootstrap = MagicMock(spec=SyntheticBootstrap)
        bootstrap.sample_minibatch.return_value = synthetic_pool[:5]

        old_eval = EvalResult(
            scores=[0.3, 0.4, 0.3, 0.5, 0.2],
            feedbacks=["ok"] * 5,
            trajectories=[
                Trajectory(f"in{i}", f"out{i}", s, "ok", "p")
                for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2])
            ],
        )
        new_eval = EvalResult(
            scores=[0.8, 0.9, 0.7, 0.8, 0.9],
            feedbacks=["good"] * 5,
            trajectories=[],
        )
        # First run: one initial result followed by an (old, new) pair per iteration.
        evaluator.evaluate = AsyncMock(
            side_effect=[old_eval, old_eval, new_eval, old_eval, new_eval]
        )
        proposer = AsyncMock()
        proposer.propose.return_value = Prompt(text="improved prompt")

        ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts")
        loop = EvolutionLoop(
            evaluator=evaluator,
            proposer=proposer,
            bootstrap=bootstrap,
            max_iterations=2,
            minibatch_size=5,
            checkpoint_port=ckpt,
            checkpoint_interval=1,
        )
        state = await loop.run(seed_prompt, synthetic_pool, task_description)
        assert state.iteration == 2
        assert ckpt.latest_exists()

        # Capture the checkpoint state *before* resume (state is mutated in-place)
        loaded = ckpt.load()
        assert loaded is not None
        saved_llm_calls = loaded.total_llm_calls
        saved_iteration = loaded.iteration

        # Set up evaluator for resumed run (just 1 more iteration)
        evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval])
        proposer.propose.return_value = Prompt(text="even better prompt")

        loop2 = EvolutionLoop(
            evaluator=evaluator,
            proposer=proposer,
            bootstrap=bootstrap,
            max_iterations=3,
            minibatch_size=5,
            checkpoint_port=ckpt,
            checkpoint_interval=1,
        )
        resumed = await loop2.run(
            seed_prompt, synthetic_pool, task_description,
            initial_state=loaded,
        )
        assert resumed.iteration == 3
        assert resumed.total_llm_calls > saved_llm_calls
        assert resumed.iteration > saved_iteration
|