"""Unit tests for checkpoint & resume functionality.""" from __future__ import annotations import json from pathlib import Path from unittest.mock import AsyncMock, MagicMock import pytest from prometheus.application.bootstrap import SyntheticBootstrap from prometheus.application.evaluator import PromptEvaluator from prometheus.application.evolution import EvolutionLoop from prometheus.domain.entities import ( Candidate, EvalResult, OptimizationState, Prompt, SyntheticExample, Trajectory, ) from prometheus.infrastructure.checkpoint import JsonCheckpointPersistence # --------------------------------------------------------------------------- # JsonCheckpointPersistence — save/load round-trip # --------------------------------------------------------------------------- class TestJsonCheckpointPersistence: def test_roundtrip_full_state(self, tmp_path: Path) -> None: """Saving and loading preserves all fields.""" ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts") state = OptimizationState( iteration=7, best_candidate=Candidate( prompt=Prompt(text="best prompt", metadata={"source": "test"}), best_score=0.92, generation=5, ), candidates=[ Candidate(prompt=Prompt(text="p1"), best_score=0.5, generation=0), Candidate(prompt=Prompt(text="p2"), best_score=0.92, generation=5), ], synthetic_pool=[ SyntheticExample(input_text="q1", category="cat_a", id=0), SyntheticExample(input_text="q2", category="cat_b", id=1), ], history=[{"iteration": 1, "event": "accepted", "old_score": 0.5, "new_score": 0.7}], total_llm_calls=42, ) ckpt.save(state) assert ckpt.latest_exists() loaded = ckpt.load() assert loaded is not None assert loaded.iteration == 7 assert loaded.total_llm_calls == 42 assert loaded.best_candidate is not None assert loaded.best_candidate.prompt.text == "best prompt" assert loaded.best_candidate.prompt.metadata == {"source": "test"} assert loaded.best_candidate.best_score == 0.92 assert len(loaded.candidates) == 2 assert len(loaded.synthetic_pool) == 2 assert loaded.synthetic_pool[0].input_text == "q1" assert loaded.synthetic_pool[1].category == "cat_b" assert loaded.history[0]["event"] == "accepted" def test_load_returns_none_when_no_checkpoint(self, tmp_path: Path) -> None: """Loading from empty dir returns None.""" ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "nope") assert ckpt.load() is None assert not ckpt.latest_exists() def test_creates_directory_on_save(self, tmp_path: Path) -> None: """Save creates the directory tree if it doesn't exist.""" deep_dir = tmp_path / "a" / "b" / "c" ckpt = JsonCheckpointPersistence(checkpoint_dir=deep_dir) state = OptimizationState(iteration=1) ckpt.save(state) assert (deep_dir / "latest.json").exists() def test_overwrites_previous_checkpoint(self, tmp_path: Path) -> None: """Second save overwrites the first.""" ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path) ckpt.save(OptimizationState(iteration=1, total_llm_calls=10)) ckpt.save(OptimizationState(iteration=5, total_llm_calls=50)) loaded = ckpt.load() assert loaded is not None assert loaded.iteration == 5 assert loaded.total_llm_calls == 50 def test_json_is_human_readable(self, tmp_path: Path) -> None: """Checkpoint file is valid, pretty-printed JSON.""" ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path) state = OptimizationState( iteration=3, best_candidate=Candidate(prompt=Prompt(text="hello"), best_score=0.8), ) ckpt.save(state) raw = json.loads((tmp_path / "latest.json").read_text()) assert raw["schema_version"] == 1 assert raw["iteration"] == 3 assert raw["best_candidate"]["prompt_text"] == "hello" # --------------------------------------------------------------------------- # EvolutionLoop — checkpoint integration # --------------------------------------------------------------------------- class TestEvolutionCheckpoint: @pytest.mark.asyncio async def test_checkpoint_saved_on_interval( self, seed_prompt: Prompt, synthetic_pool: list[SyntheticExample], task_description: str, ) -> None: """Checkpoint is saved every checkpoint_interval iterations.""" from unittest.mock import MagicMock evaluator = PromptEvaluator(AsyncMock(), AsyncMock()) bootstrap = MagicMock(spec=SyntheticBootstrap) bootstrap.sample_minibatch.return_value = synthetic_pool[:5] # All iterations accepted so checkpoint triggers good_eval = EvalResult( scores=[0.3, 0.4, 0.3, 0.5, 0.2], feedbacks=["ok"] * 5, trajectories=[ Trajectory(f"input{i}", f"out{i}", s, "ok", "p") for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2]) ], ) better_eval = EvalResult( scores=[0.8, 0.9, 0.7, 0.8, 0.9], feedbacks=["good"] * 5, trajectories=[], ) # initial_eval + 5 iterations (each needs old_eval + new_eval) evaluator.evaluate = AsyncMock( side_effect=[good_eval] # initial + [good_eval, better_eval] * 5 # 5 iterations ) proposer = AsyncMock() proposer.propose.return_value = Prompt(text="improved prompt") checkpoint_port = MagicMock() loop = EvolutionLoop( evaluator=evaluator, proposer=proposer, bootstrap=bootstrap, max_iterations=5, minibatch_size=5, checkpoint_port=checkpoint_port, checkpoint_interval=2, ) await loop.run(seed_prompt, synthetic_pool, task_description) # Checkpoint at iterations 2, 4 (every 2nd) save_calls = checkpoint_port.save.call_count assert save_calls >= 2 # at least at iters 2 and 4 @pytest.mark.asyncio async def test_no_checkpoint_without_port( self, seed_prompt: Prompt, synthetic_pool: list[SyntheticExample], task_description: str, ) -> None: """No checkpointing happens when checkpoint_port is None (default).""" evaluator = PromptEvaluator(AsyncMock(), AsyncMock()) bootstrap = MagicMock(spec=SyntheticBootstrap) bootstrap.sample_minibatch.return_value = synthetic_pool[:5] perfect_eval = EvalResult( scores=[1.0] * 5, feedbacks=["perfect"] * 5, trajectories=[ Trajectory(f"in{i}", f"out{i}", 1.0, "perfect", "p") for i in range(5) ], ) evaluator.evaluate = AsyncMock(return_value=perfect_eval) loop = EvolutionLoop( evaluator=evaluator, proposer=AsyncMock(), bootstrap=bootstrap, max_iterations=3, minibatch_size=5, checkpoint_port=None, ) # Should run without error — no checkpoint port, no crash await loop.run(seed_prompt, synthetic_pool, task_description) @pytest.mark.asyncio async def test_resume_skips_seed_evaluation( self, synthetic_pool: list[SyntheticExample], task_description: str, ) -> None: """When initial_state is provided, seed eval is skipped and loop starts from saved iteration.""" evaluator = PromptEvaluator(AsyncMock(), AsyncMock()) bootstrap = MagicMock(spec=SyntheticBootstrap) bootstrap.sample_minibatch.return_value = synthetic_pool[:5] proposer = AsyncMock() proposer.propose.return_value = Prompt(text="new prompt") # Only return evaluations for resumed iterations (1 iter: old_eval + new_eval) old_eval = EvalResult( scores=[0.5] * 5, feedbacks=["ok"] * 5, trajectories=[ Trajectory(f"in{i}", f"out{i}", 0.5, "ok", "p") for i in range(5) ], ) new_eval = EvalResult( scores=[0.8] * 5, feedbacks=["good"] * 5, trajectories=[], ) evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval]) # Create a state simulating checkpoint at iteration 4 initial_state = OptimizationState( iteration=4, best_candidate=Candidate( prompt=Prompt(text="checkpoint prompt"), best_score=2.5, generation=4 ), candidates=[Candidate(prompt=Prompt(text="checkpoint prompt"), best_score=2.5)], total_llm_calls=40, ) loop = EvolutionLoop( evaluator=evaluator, proposer=proposer, bootstrap=bootstrap, max_iterations=5, # only iteration 5 remains minibatch_size=5, ) state = await loop.run( seed_prompt=Prompt(text="seed"), synthetic_pool=synthetic_pool, task_description=task_description, initial_state=initial_state, ) # Should have run only 1 iteration (iter 5) assert state.iteration == 5 # total_llm_calls should include the 40 from checkpoint + new calls assert state.total_llm_calls > 40 @pytest.mark.asyncio async def test_full_save_and_resume_roundtrip( self, seed_prompt: Prompt, synthetic_pool: list[SyntheticExample], task_description: str, tmp_path: Path, ) -> None: """End-to-end: run a few iterations, checkpoint, resume, finish.""" evaluator = PromptEvaluator(AsyncMock(), AsyncMock()) bootstrap = MagicMock(spec=SyntheticBootstrap) bootstrap.sample_minibatch.return_value = synthetic_pool[:5] old_eval = EvalResult( scores=[0.3, 0.4, 0.3, 0.5, 0.2], feedbacks=["ok"] * 5, trajectories=[ Trajectory(f"in{i}", f"out{i}", s, "ok", "p") for i, s in enumerate([0.3, 0.4, 0.3, 0.5, 0.2]) ], ) new_eval = EvalResult( scores=[0.8, 0.9, 0.7, 0.8, 0.9], feedbacks=["good"] * 5, trajectories=[], ) evaluator.evaluate = AsyncMock( side_effect=[old_eval, old_eval, new_eval, old_eval, new_eval] ) proposer = AsyncMock() proposer.propose.return_value = Prompt(text="improved prompt") ckpt = JsonCheckpointPersistence(checkpoint_dir=tmp_path / "ckpts") loop = EvolutionLoop( evaluator=evaluator, proposer=proposer, bootstrap=bootstrap, max_iterations=2, minibatch_size=5, checkpoint_port=ckpt, checkpoint_interval=1, ) state = await loop.run(seed_prompt, synthetic_pool, task_description) assert state.iteration == 2 assert ckpt.latest_exists() # Capture the checkpoint state *before* resume (state is mutated in-place) loaded = ckpt.load() assert loaded is not None saved_llm_calls = loaded.total_llm_calls saved_iteration = loaded.iteration # Set up evaluator for resumed run (just 1 more iteration) evaluator.evaluate = AsyncMock(side_effect=[old_eval, new_eval]) proposer.propose.return_value = Prompt(text="even better prompt") loop2 = EvolutionLoop( evaluator=evaluator, proposer=proposer, bootstrap=bootstrap, max_iterations=3, minibatch_size=5, checkpoint_port=ckpt, checkpoint_interval=1, ) resumed = await loop2.run( seed_prompt, synthetic_pool, task_description, initial_state=loaded, ) assert resumed.iteration == 3 assert resumed.total_llm_calls > saved_llm_calls assert resumed.iteration > saved_iteration