"""Unit tests for hold-out validation and early stopping."""

from __future__ import annotations

import random
from unittest.mock import AsyncMock, MagicMock

import pytest

from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.evaluator import PromptEvaluator
from prometheus.application.evolution import EvolutionLoop
from prometheus.domain.entities import (
    Candidate,
    EvalResult,
    Prompt,
    SyntheticExample,
    Trajectory,
)


def _make_eval(mean_score: float, n: int = 5) -> EvalResult:
    """Helper: create an EvalResult with a given mean score.

    Every one of the ``n`` scores (and every trajectory's score) equals
    ``mean_score``, so the mean is ``mean_score`` by construction.

    Args:
        mean_score: Score assigned to each example.
        n: Number of scored examples / trajectories to build.
    """
    scores = [mean_score] * n
    return EvalResult(
        scores=scores,
        feedbacks=["feedback"] * n,
        trajectories=[
            Trajectory(f"input{i}", f"output{i}", mean_score, "feedback", "prompt")
            for i in range(n)
        ],
    )


def _make_bootstrap(minibatch: list[SyntheticExample]) -> MagicMock:
    """Helper: a SyntheticBootstrap stand-in whose minibatch sampling is fixed.

    Args:
        minibatch: Value returned by every ``sample_minibatch`` call.
    """
    bootstrap = MagicMock(spec=SyntheticBootstrap)
    bootstrap.sample_minibatch.return_value = minibatch
    return bootstrap


def _make_evaluator(
    mock_llm_port: AsyncMock,
    mock_judge_port: AsyncMock,
    evals: list[EvalResult],
) -> PromptEvaluator:
    """Helper: a real PromptEvaluator whose ``evaluate`` replays ``evals`` in order.

    ``evaluate`` is replaced by an AsyncMock driven by ``side_effect``; once the
    scripted results are exhausted, a further await raises StopAsyncIteration,
    so a test fails loudly if the loop evaluates more often than scripted.

    Args:
        mock_llm_port: LLM port fixture passed to the evaluator constructor.
        mock_judge_port: Judge port fixture passed to the evaluator constructor.
        evals: Results returned by successive ``evaluate`` awaits.
    """
    evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
    evaluator.evaluate = AsyncMock(side_effect=evals)
    return evaluator


class TestBootstrapSplit:
    """Tests for SyntheticBootstrap.split_pool."""

    def test_split_produces_correct_sizes(self):
        pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(20)]
        train, val = SyntheticBootstrap.split_pool(pool, 0.3)
        assert len(train) + len(val) == 20
        assert len(val) == 6  # 20 * 0.3 = 6
        assert len(train) == 14

    def test_split_zero_fraction_returns_all_train(self):
        pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(10)]
        train, val = SyntheticBootstrap.split_pool(pool, 0.0)
        assert len(train) == 10
        assert len(val) == 0

    def test_split_single_element(self):
        pool = [SyntheticExample(input_text="only", id=0)]
        train, val = SyntheticBootstrap.split_pool(pool, 0.3)
        assert len(train) == 1
        assert len(val) == 0

    def test_split_deterministic_with_seed(self):
        pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(50)]
        # Two independent RNGs seeded identically must produce identical splits.
        train1, val1 = SyntheticBootstrap.split_pool(pool, 0.3, rng=random.Random(42))
        train2, val2 = SyntheticBootstrap.split_pool(pool, 0.3, rng=random.Random(42))
        assert [ex.id for ex in train1] == [ex.id for ex in train2]
        assert [ex.id for ex in val1] == [ex.id for ex in val2]

    def test_split_no_overlap(self):
        pool = [SyntheticExample(input_text=f"ex{i}", id=i) for i in range(30)]
        train, val = SyntheticBootstrap.split_pool(pool, 0.3)
        train_ids = {ex.id for ex in train}
        val_ids = {ex.id for ex in val}
        assert train_ids.isdisjoint(val_ids)
        assert train_ids | val_ids == {ex.id for ex in pool}


class TestValidationEvaluation:
    """Tests for hold-out evaluation during evolution."""

    @pytest.mark.asyncio
    async def test_validation_pool_evaluated_after_each_iteration(
        self,
        seed_prompt: Prompt,
        synthetic_pool: list[SyntheticExample],
        task_description: str,
        mock_llm_port: AsyncMock,
        mock_judge_port: AsyncMock,
        mock_proposer_port: AsyncMock,
    ) -> None:
        """When a validation pool is provided, the best candidate is evaluated on it."""
        # Scripted order: initial eval (train) + validation eval
        # + iteration train eval + new prompt eval + validation eval.
        train_eval = _make_eval(0.5)
        val_eval = _make_eval(0.6)
        new_eval = _make_eval(0.7)
        val_eval_2 = _make_eval(0.65)
        evaluator = _make_evaluator(
            mock_llm_port,
            mock_judge_port,
            [train_eval, val_eval, train_eval, new_eval, val_eval_2],
        )
        bootstrap = _make_bootstrap(synthetic_pool[:5])
        validation_pool = synthetic_pool[-6:]

        loop = EvolutionLoop(
            evaluator=evaluator,
            proposer=mock_proposer_port,
            bootstrap=bootstrap,
            max_iterations=1,
            minibatch_size=5,
        )
        state = await loop.run(
            seed_prompt,
            synthetic_pool,
            task_description,
            validation_pool=validation_pool,
        )

        # Should have validation metrics in state
        assert state.best_validation_score is not None
        # History should contain validation_eval entries
        val_events = [h for h in state.history if h["event"] == "validation_eval"]
        assert len(val_events) >= 1

    @pytest.mark.asyncio
    async def test_no_validation_without_pool(
        self,
        seed_prompt: Prompt,
        synthetic_pool: list[SyntheticExample],
        task_description: str,
        mock_llm_port: AsyncMock,
        mock_judge_port: AsyncMock,
        mock_proposer_port: AsyncMock,
    ) -> None:
        """Without a validation pool, no validation is performed."""
        # Only train-side evaluations are scripted: initial, current best, new prompt.
        train_eval = _make_eval(0.5)
        old_eval = _make_eval(0.5)
        new_eval = _make_eval(0.7)
        evaluator = _make_evaluator(
            mock_llm_port, mock_judge_port, [train_eval, old_eval, new_eval]
        )
        bootstrap = _make_bootstrap(synthetic_pool[:5])

        loop = EvolutionLoop(
            evaluator=evaluator,
            proposer=mock_proposer_port,
            bootstrap=bootstrap,
            max_iterations=1,
            minibatch_size=5,
        )
        state = await loop.run(seed_prompt, synthetic_pool, task_description)

        assert state.best_validation_score is None
        assert not state.early_stopped
        val_events = [h for h in state.history if h["event"] == "validation_eval"]
        assert len(val_events) == 0


class TestEarlyStopping:
    """Tests for early stopping when validation score degrades."""

    @pytest.mark.asyncio
    async def test_early_stop_triggers_on_patience_exceeded(
        self,
        seed_prompt: Prompt,
        synthetic_pool: list[SyntheticExample],
        task_description: str,
        mock_llm_port: AsyncMock,
        mock_judge_port: AsyncMock,
        mock_proposer_port: AsyncMock,
    ) -> None:
        """Early stopping triggers when validation doesn't improve for K iterations."""
        patience = 3

        # Build eval sequence:
        # 1. Initial train eval
        # 2. Initial validation eval (0.5)
        # Then for each of `patience` iterations:
        #   - train eval (current best)
        #   - train eval (new prompt - accepted)
        #   - validation eval (degrading)
        evals = [
            _make_eval(0.5),  # initial train
            _make_eval(0.5),  # initial validation
        ]
        for i in range(patience):
            evals.extend([
                _make_eval(0.5 + i * 0.1),  # current eval (train)
                _make_eval(0.6 + i * 0.1),  # new eval (train) - accepted
                _make_eval(0.4),  # validation eval (degrading)
            ])
        # Only `patience` iterations are scripted: running a further iteration
        # would exhaust the side_effect and fail the test.
        evaluator = _make_evaluator(mock_llm_port, mock_judge_port, evals)
        bootstrap = _make_bootstrap(synthetic_pool[:5])
        validation_pool = synthetic_pool[-5:]

        loop = EvolutionLoop(
            evaluator=evaluator,
            proposer=mock_proposer_port,
            bootstrap=bootstrap,
            max_iterations=10,  # would go further without early stop
            minibatch_size=5,
            early_stop_patience=patience,
        )
        state = await loop.run(
            seed_prompt,
            synthetic_pool,
            task_description,
            validation_pool=validation_pool,
        )

        assert state.early_stopped is True
        assert state.iteration == patience
        assert state.best_validation_score is not None
        # Should have an early_stop event in history
        early_stop_events = [h for h in state.history if h["event"] == "early_stop"]
        assert len(early_stop_events) == 1

    @pytest.mark.asyncio
    async def test_early_stop_does_not_trigger_when_improving(
        self,
        seed_prompt: Prompt,
        synthetic_pool: list[SyntheticExample],
        task_description: str,
        mock_llm_port: AsyncMock,
        mock_judge_port: AsyncMock,
        mock_proposer_port: AsyncMock,
    ) -> None:
        """When validation keeps improving, early stopping does not trigger."""
        evals = [
            _make_eval(0.3),  # initial train
            _make_eval(0.3),  # initial validation
        ]
        # 3 iterations, each with improving validation
        for i in range(3):
            evals.extend([
                _make_eval(0.3 + i * 0.1),  # current train eval
                _make_eval(0.4 + i * 0.1),  # new train eval (accepted)
                _make_eval(0.3 + (i + 1) * 0.1),  # validation eval (improving)
            ])
        evaluator = _make_evaluator(mock_llm_port, mock_judge_port, evals)
        bootstrap = _make_bootstrap(synthetic_pool[:5])
        validation_pool = synthetic_pool[-5:]

        loop = EvolutionLoop(
            evaluator=evaluator,
            proposer=mock_proposer_port,
            bootstrap=bootstrap,
            max_iterations=3,
            minibatch_size=5,
            early_stop_patience=5,
        )
        state = await loop.run(
            seed_prompt,
            synthetic_pool,
            task_description,
            validation_pool=validation_pool,
        )

        assert state.early_stopped is False
        assert state.iteration == 3
        assert state.best_validation_score is not None

    @pytest.mark.asyncio
    async def test_validation_patience_resets_on_improvement(
        self,
        seed_prompt: Prompt,
        synthetic_pool: list[SyntheticExample],
        task_description: str,
        mock_llm_port: AsyncMock,
        mock_judge_port: AsyncMock,
        mock_proposer_port: AsyncMock,
    ) -> None:
        """Patience counter resets when validation improves after degrading.

        Without a reset, the fourth consecutive non-improving validation
        would reach patience=3 and stop the run early.
        """
        evals = [
            _make_eval(0.3),  # initial train
            _make_eval(0.3),  # initial validation
            # iter 1: degrade
            _make_eval(0.3),  # current train
            _make_eval(0.5),  # new train (accepted)
            _make_eval(0.2),  # validation degrade (patience=1)
            # iter 2: degrade
            _make_eval(0.5),  # current train
            _make_eval(0.6),  # new train (accepted)
            _make_eval(0.2),  # validation degrade (patience=2)
            # iter 3: improve! (resets patience)
            _make_eval(0.6),  # current train
            _make_eval(0.7),  # new train (accepted)
            _make_eval(0.4),  # validation improve (patience=0)
            # iter 4: degrade again
            _make_eval(0.7),  # current train
            _make_eval(0.8),  # new train (accepted)
            _make_eval(0.2),  # validation degrade (patience=1)
        ]
        evaluator = _make_evaluator(mock_llm_port, mock_judge_port, evals)
        bootstrap = _make_bootstrap(synthetic_pool[:5])
        validation_pool = synthetic_pool[-5:]

        loop = EvolutionLoop(
            evaluator=evaluator,
            proposer=mock_proposer_port,
            bootstrap=bootstrap,
            max_iterations=4,
            minibatch_size=5,
            early_stop_patience=3,
        )
        state = await loop.run(
            seed_prompt,
            synthetic_pool,
            task_description,
            validation_pool=validation_pool,
        )

        assert state.early_stopped is False
        assert state.iteration == 4