feat: async/parallel execution with configurable concurrency

Parallelize LLM calls across the items of each minibatch to reduce wall-clock time.
All domain ports (LLMPort, JudgePort, ProposerPort) are now async.
Adapter implementations wrap synchronous DSPy calls with asyncio.to_thread.
Judge calls run in parallel within a batch using asyncio.gather + semaphore.
Evaluator parallelizes minibatch execution with configurable concurrency.
Evolution loop and use case are fully async. Proposer stays sequential.
Added --max-concurrency CLI flag and max_concurrency YAML config field.
Added async_retry_with_backoff for async error handling.
All 139 unit tests pass.

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
FullStackDev
2026-03-29 13:15:34 +00:00
parent e2d111ce5b
commit c92ca4a2b8
16 changed files with 297 additions and 159 deletions

View File

@@ -16,6 +16,7 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"pytest>=8.3", "pytest>=8.3",
"pytest-asyncio>=0.24",
"pytest-cov>=6.0", "pytest-cov>=6.0",
"ruff>=0.9", "ruff>=0.9",
"mypy>=1.14", "mypy>=1.14",
@@ -37,6 +38,9 @@ target-version = "py312"
python_version = "3.12" python_version = "3.12"
strict = true strict = true
[tool.pytest.ini_options]
asyncio_mode = "auto"
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = ["dspy", "dspy.*"] module = ["dspy", "dspy.*"]
ignore_missing_imports = true ignore_missing_imports = true

View File

@@ -38,6 +38,9 @@ class OptimizationConfig:
# --- Reproducibility --- # --- Reproducibility ---
seed: int = 42 seed: int = 42
# --- Concurrency ---
max_concurrency: int = 5
# --- Error handling --- # --- Error handling ---
max_retries: int = 3 max_retries: int = 3
retry_delay_base: float = 1.0 retry_delay_base: float = 1.0

View File

@@ -6,6 +6,7 @@ Combines candidate prompt execution + LLM-as-Judge evaluation.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from prometheus.domain.entities import ( from prometheus.domain.entities import (
@@ -26,15 +27,22 @@ class PromptEvaluator:
Replaces GEPA's EvaluatorFn. Instead of comparing to ground truth, Replaces GEPA's EvaluatorFn. Instead of comparing to ground truth,
uses an LLM-as-Judge. uses an LLM-as-Judge.
Execution and judge calls run in parallel (bounded by *max_concurrency*).
Per-call isolation: a failure on one minibatch item produces a Per-call isolation: a failure on one minibatch item produces a
zero-score trajectory instead of crashing the whole batch. zero-score trajectory instead of crashing the whole batch.
""" """
def __init__(self, executor: LLMPort, judge: JudgePort): def __init__(
self,
executor: LLMPort,
judge: JudgePort,
max_concurrency: int = 5,
):
self._executor = executor self._executor = executor
self._judge = judge self._judge = judge
self._semaphore = asyncio.Semaphore(max_concurrency)
def evaluate( async def evaluate(
self, self,
prompt: Prompt, prompt: Prompt,
minibatch: list[SyntheticExample], minibatch: list[SyntheticExample],
@@ -43,27 +51,20 @@ class PromptEvaluator:
"""Evaluate the prompt on the minibatch. """Evaluate the prompt on the minibatch.
Steps: Steps:
1. Execute the prompt on each input in the minibatch 1. Execute the prompt on each input in the minibatch (parallel)
2. Judge each (input, output) pair 2. Judge each (input, output) pair
3. Build trajectories with feedback 3. Build trajectories with feedback
""" """
# Step 1: Execution (per-item isolation) # Step 1: Parallel execution (per-item isolation)
outputs: list[str] = [] output_coros = [
for example in minibatch: self._execute_single(prompt, example)
try: for example in minibatch
raw_output = self._executor.execute(prompt, example.input_text) ]
outputs.append(raw_output) outputs = await asyncio.gather(*output_coros)
except Exception as exc:
logger.warning(
"Execution failed for input '%s': %s",
example.input_text[:40],
exc,
)
outputs.append(f"[execution error: {exc}]")
# Step 2: Judgement (judge_adapter handles its own per-call isolation) # Step 2: Judgement (judge_adapter handles its own per-call isolation + parallelism)
pairs = [(ex.input_text, out) for ex, out in zip(minibatch, outputs)] pairs = [(ex.input_text, out) for ex, out in zip(minibatch, outputs)]
judge_results = self._judge.judge_batch(task_description, pairs) judge_results = await self._judge.judge_batch(task_description, pairs)
# Step 3: Build trajectories # Step 3: Build trajectories
scores: list[float] = [] scores: list[float] = []
@@ -88,3 +89,17 @@ class PromptEvaluator:
feedbacks=feedbacks, feedbacks=feedbacks,
trajectories=trajectories, trajectories=trajectories,
) )
async def _execute_single(
    self, prompt: Prompt, example: SyntheticExample
) -> str:
    """Execute *prompt* on one example, isolating any failure.

    The call to the executor is made while holding this evaluator's
    semaphore, so the number of executions in flight at once is bounded
    by it.

    Returns:
        The executor's raw output, or the sentinel string
        ``"[execution error: ...]"`` when the executor raised an
        ``Exception``. The failure is logged as a warning together with
        the first 40 characters of the input.
    """
    async with self._semaphore:
        try:
            outcome = await self._executor.execute(prompt, example.input_text)
        except Exception as err:
            # Per-item isolation: one bad call must not sink the whole
            # gather() over the minibatch, so degrade to a sentinel.
            logger.warning(
                "Execution failed for input '%s': %s",
                example.input_text[:40],
                err,
            )
            outcome = f"[execution error: {err}]"
    return outcome

View File

@@ -62,7 +62,7 @@ class EvolutionLoop:
self._circuit_breaker_threshold = circuit_breaker_threshold self._circuit_breaker_threshold = circuit_breaker_threshold
self._error_strategy = error_strategy self._error_strategy = error_strategy
def run( async def run(
self, self,
seed_prompt: Prompt, seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample], synthetic_pool: list[SyntheticExample],
@@ -76,7 +76,7 @@ class EvolutionLoop:
initial_batch = self._bootstrap.sample_minibatch( initial_batch = self._bootstrap.sample_minibatch(
synthetic_pool, self._minibatch_size synthetic_pool, self._minibatch_size
) )
initial_eval = self._evaluator.evaluate( initial_eval = await self._evaluator.evaluate(
seed_prompt, initial_batch, task_description seed_prompt, initial_batch, task_description
) )
state.total_llm_calls += 2 * self._minibatch_size # N executions + N judge calls state.total_llm_calls += 2 * self._minibatch_size # N executions + N judge calls
@@ -95,7 +95,7 @@ class EvolutionLoop:
state.iteration = i state.iteration = i
try: try:
self._run_iteration( await self._run_iteration(
i, state, best_candidate, synthetic_pool, task_description i, state, best_candidate, synthetic_pool, task_description
) )
# Update best_candidate from state after successful iteration # Update best_candidate from state after successful iteration
@@ -144,7 +144,7 @@ class EvolutionLoop:
return state return state
def _run_iteration( async def _run_iteration(
self, self,
i: int, i: int,
state: OptimizationState, state: OptimizationState,
@@ -159,7 +159,7 @@ class EvolutionLoop:
) )
# 2. Evaluate the current candidate # 2. Evaluate the current candidate
current_eval = self._evaluator.evaluate( current_eval = await self._evaluator.evaluate(
best_candidate.prompt, batch, task_description best_candidate.prompt, batch, task_description
) )
state.total_llm_calls += 2 * self._minibatch_size state.total_llm_calls += 2 * self._minibatch_size
@@ -176,8 +176,8 @@ class EvolutionLoop:
) )
return return
# 4. Propose a new prompt (reflective mutation) # 4. Propose a new prompt (reflective mutation) — sequential
new_prompt = self._proposer.propose( new_prompt = await self._proposer.propose(
best_candidate.prompt, best_candidate.prompt,
current_eval.trajectories, current_eval.trajectories,
task_description, task_description,
@@ -185,7 +185,7 @@ class EvolutionLoop:
state.total_llm_calls += 1 # 1 proposition call state.total_llm_calls += 1 # 1 proposition call
# 5. Evaluate the new prompt on the same minibatch # 5. Evaluate the new prompt on the same minibatch
new_eval = self._evaluator.evaluate( new_eval = await self._evaluator.evaluate(
new_prompt, batch, task_description new_prompt, batch, task_description
) )
state.total_llm_calls += 2 * self._minibatch_size state.total_llm_calls += 2 * self._minibatch_size

View File

@@ -30,7 +30,7 @@ class OptimizePromptUseCase:
self._proposer = proposer self._proposer = proposer
self._bootstrap = bootstrap self._bootstrap = bootstrap
def execute(self, config: OptimizationConfig) -> OptimizationResult: async def execute(self, config: OptimizationConfig) -> OptimizationResult:
"""Full pipeline: """Full pipeline:
1. Bootstrap → generate synthetic inputs 1. Bootstrap → generate synthetic inputs
2. Evolution → optimization loop 2. Evolution → optimization loop
@@ -55,7 +55,7 @@ class OptimizePromptUseCase:
error_strategy=config.error_strategy, error_strategy=config.error_strategy,
) )
seed_prompt = Prompt(text=config.seed_prompt) seed_prompt = Prompt(text=config.seed_prompt)
state = loop.run(seed_prompt, synthetic_pool, config.task_description) state = await loop.run(seed_prompt, synthetic_pool, config.task_description)
# Phase 2: Result # Phase 2: Result
initial_score = ( initial_score = (

View File

@@ -5,6 +5,7 @@ Typer interface with -i (input) and -o (output) options.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import os import os
from dataclasses import asdict from dataclasses import asdict
@@ -66,12 +67,30 @@ def optimize(
"--error-strategy", "--error-strategy",
help="How to handle errors: skip | retry | abort.", help="How to handle errors: skip | retry | abort.",
), ),
max_concurrency: int = typer.Option(
5,
"--max-concurrency",
help="Max parallel LLM calls for minibatch execution and judging.",
),
) -> None: ) -> None:
"""Optimize a prompt without any reference data. """Optimize a prompt without any reference data.
Usage: Usage:
prometheus optimize -i config.yaml -o result.yaml prometheus optimize -i config.yaml -o result.yaml
""" """
asyncio.run(
_async_optimize(input, output, verbose, max_retries, error_strategy, max_concurrency)
)
async def _async_optimize(
input: str,
output: str,
verbose: bool,
max_retries: int,
error_strategy: str,
max_concurrency: int,
) -> None:
# Configure verbose logging # Configure verbose logging
if verbose: if verbose:
logging.basicConfig(level=logging.INFO, format="[PROMETHEUS] %(message)s") logging.basicConfig(level=logging.INFO, format="[PROMETHEUS] %(message)s")
@@ -129,6 +148,7 @@ def optimize(
retry_delay_base=raw_config.get("retry_delay_base", 1.0), retry_delay_base=raw_config.get("retry_delay_base", 1.0),
circuit_breaker_threshold=raw_config.get("circuit_breaker_threshold", 5), circuit_breaker_threshold=raw_config.get("circuit_breaker_threshold", 5),
error_strategy=raw_config.get("error_strategy", error_strategy), error_strategy=raw_config.get("error_strategy", error_strategy),
max_concurrency=raw_config.get("max_concurrency", max_concurrency),
output_path=output, output_path=output,
verbose=verbose, verbose=verbose,
) )
@@ -164,6 +184,7 @@ def optimize(
lm=judge_lm, lm=judge_lm,
max_retries=config.max_retries, max_retries=config.max_retries,
retry_delay_base=config.retry_delay_base, retry_delay_base=config.retry_delay_base,
max_concurrency=config.max_concurrency,
) )
proposer_adapter = DSPyProposerAdapter( proposer_adapter = DSPyProposerAdapter(
lm=proposer_lm, lm=proposer_lm,
@@ -171,7 +192,11 @@ def optimize(
retry_delay_base=config.retry_delay_base, retry_delay_base=config.retry_delay_base,
) )
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed) bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter) evaluator = PromptEvaluator(
executor=llm_adapter,
judge=judge_adapter,
max_concurrency=config.max_concurrency,
)
use_case = OptimizePromptUseCase( use_case = OptimizePromptUseCase(
evaluator=evaluator, evaluator=evaluator,
proposer=proposer_adapter, proposer=proposer_adapter,
@@ -180,7 +205,7 @@ def optimize(
# 4. Execute # 4. Execute
with console.status("[bold green]Evolving prompt..."): with console.status("[bold green]Evolving prompt..."):
result = use_case.execute(config) result = await use_case.execute(config)
# 5. Display results # 5. Display results
_display_result(result) _display_result(result)

View File

@@ -18,7 +18,7 @@ class LLMPort(ABC):
""" """
@abstractmethod @abstractmethod
def execute(self, prompt: Prompt, input_text: str) -> str: async def execute(self, prompt: Prompt, input_text: str) -> str:
"""Execute the prompt on the input, return the raw response.""" """Execute the prompt on the input, return the raw response."""
... ...
@@ -31,7 +31,7 @@ class JudgePort(ABC):
""" """
@abstractmethod @abstractmethod
def judge_batch( async def judge_batch(
self, self,
task_description: str, task_description: str,
pairs: list[tuple[str, str]], pairs: list[tuple[str, str]],
@@ -50,7 +50,7 @@ class ProposerPort(ABC):
""" """
@abstractmethod @abstractmethod
def propose( async def propose(
self, self,
current_prompt: Prompt, current_prompt: Prompt,
trajectories: list[Trajectory], trajectories: list[Trajectory],

View File

@@ -5,13 +5,15 @@ Implements the JudgePort via the DSPy OutputJudge module.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from typing import Self
import dspy import dspy
from prometheus.domain.ports import JudgePort from prometheus.domain.ports import JudgePort
from prometheus.infrastructure.dspy_modules import OutputJudge from prometheus.infrastructure.dspy_modules import OutputJudge
from prometheus.infrastructure.retry import retry_with_backoff from prometheus.infrastructure.retry import async_retry_with_backoff
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -21,6 +23,8 @@ class DSPyJudgeAdapter(JudgePort):
Per-call isolation: a failure on one item returns a zero-score sentinel Per-call isolation: a failure on one item returns a zero-score sentinel
instead of crashing the whole batch. instead of crashing the whole batch.
Judge calls run in parallel (bounded by *max_concurrency*).
""" """
def __init__( def __init__(
@@ -28,40 +32,60 @@ class DSPyJudgeAdapter(JudgePort):
lm: dspy.LM, lm: dspy.LM,
max_retries: int = 3, max_retries: int = 3,
retry_delay_base: float = 1.0, retry_delay_base: float = 1.0,
max_concurrency: int = 5,
) -> None: ) -> None:
self._lm = lm self._lm = lm
self._judge = OutputJudge() self._judge = OutputJudge()
self._max_retries = max_retries self._max_retries = max_retries
self._retry_delay_base = retry_delay_base self._retry_delay_base = retry_delay_base
self._semaphore = asyncio.Semaphore(max_concurrency)
def judge_batch( async def judge_batch(
self, self,
task_description: str, task_description: str,
pairs: list[tuple[str, str]], pairs: list[tuple[str, str]],
) -> list[tuple[float, str]]: ) -> list[tuple[float, str]]:
results: list[tuple[float, str]] = [] tasks = [
with dspy.context(lm=self._lm): self._judge_single_safe(task_description, input_text, output_text)
for input_text, output_text in pairs: for input_text, output_text in pairs
results.append(self._judge_single(task_description, input_text, output_text)) ]
return results return list(await asyncio.gather(*tasks))
def _judge_single( async def _judge_single_safe(
self, self,
task_description: str, task_description: str,
input_text: str, input_text: str,
output_text: str, output_text: str,
) -> tuple[float, str]: ) -> tuple[float, str]:
async with self._semaphore:
try: try:
pred = retry_with_backoff( return await self._judge_single(task_description, input_text, output_text)
lambda: self._judge(
task_description=task_description,
input_text=input_text,
output_text=output_text,
),
max_retries=self._max_retries,
retry_delay_base=self._retry_delay_base,
)
return (pred.score, pred.feedback)
except Exception as exc: except Exception as exc:
logger.warning("Judge call failed for input '%s': %s", input_text[:40], exc) logger.warning("Judge call failed for input '%s': %s", input_text[:40], exc)
return (0.0, f"[judge error: {exc}]") return (0.0, f"[judge error: {exc}]")
async def _judge_single(
    self,
    task_description: str,
    input_text: str,
    output_text: str,
) -> tuple[float, str]:
    """Judge one (input, output) pair, with retries.

    Each attempt hands the blocking ``_sync_judge`` call to a worker
    thread via ``asyncio.to_thread`` so the event loop is not blocked.
    Attempts are driven by ``async_retry_with_backoff`` using this
    adapter's ``max_retries`` / ``retry_delay_base`` settings.

    Returns:
        ``(score, feedback)`` read from the prediction's attributes.

    Exceptions are not handled here; they propagate to the caller.
    """

    async def _attempt() -> tuple[float, str]:
        prediction = await asyncio.to_thread(
            self._sync_judge,
            task_description,
            input_text,
            output_text,
        )
        return (prediction.score, prediction.feedback)

    outcome = await async_retry_with_backoff(
        _attempt,
        max_retries=self._max_retries,
        retry_delay_base=self._retry_delay_base,
    )
    return outcome
def _sync_judge(
    self, task_description: str, input_text: str, output_text: str
) -> dspy.Prediction:
    """Run the judge module synchronously under this adapter's own LM.

    Intended to be executed in a worker thread (it is what
    ``_judge_single`` passes to ``asyncio.to_thread``). The LM is
    installed with ``dspy.context`` for the duration of the call only.

    Returns:
        Whatever the ``OutputJudge`` module returns; callers read
        ``.score`` and ``.feedback`` from it.
    """
    # NOTE(review): the return annotation assumes OutputJudge yields a
    # dspy.Prediction — confirm against dspy_modules. It was added because
    # the project runs mypy in strict mode, which rejects unannotated defs.
    with dspy.context(lm=self._lm):
        return self._judge(
            task_description=task_description,
            input_text=input_text,
            output_text=output_text,
        )

View File

@@ -5,11 +5,13 @@ Implements the LLMPort via DSPy.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import dspy import dspy
from prometheus.domain.entities import Prompt from prometheus.domain.entities import Prompt
from prometheus.domain.ports import LLMPort from prometheus.domain.ports import LLMPort
from prometheus.infrastructure.retry import retry_with_backoff from prometheus.infrastructure.retry import async_retry_with_backoff
class DSPyLLMAdapter(LLMPort): class DSPyLLMAdapter(LLMPort):
@@ -33,17 +35,21 @@ class DSPyLLMAdapter(LLMPort):
self._max_retries = max_retries self._max_retries = max_retries
self._retry_delay_base = retry_delay_base self._retry_delay_base = retry_delay_base
def execute(self, prompt: Prompt, input_text: str) -> str: async def execute(self, prompt: Prompt, input_text: str) -> str:
def _call() -> str: async def _call() -> str:
# DSPy is synchronous — run in a thread to avoid blocking the event loop.
return await asyncio.to_thread(self._sync_execute, prompt, input_text)
return await async_retry_with_backoff(
_call,
max_retries=self._max_retries,
retry_delay_base=self._retry_delay_base,
)
def _sync_execute(self, prompt: Prompt, input_text: str) -> str:
with dspy.context(lm=self._lm): with dspy.context(lm=self._lm):
result = self._predictor( result = self._predictor(
instruction=prompt.text, instruction=prompt.text,
input_text=input_text, input_text=input_text,
) )
return str(result.output) return str(result.output)
return retry_with_backoff(
_call,
max_retries=self._max_retries,
retry_delay_base=self._retry_delay_base,
)

View File

@@ -6,12 +6,14 @@ Converts trajectories into readable format for the LLM proposer.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import dspy import dspy
from prometheus.domain.entities import Prompt, Trajectory from prometheus.domain.entities import Prompt, Trajectory
from prometheus.domain.ports import ProposerPort from prometheus.domain.ports import ProposerPort
from prometheus.infrastructure.dspy_modules import InstructionProposer from prometheus.infrastructure.dspy_modules import InstructionProposer
from prometheus.infrastructure.retry import retry_with_backoff from prometheus.infrastructure.retry import async_retry_with_backoff
class DSPyProposerAdapter(ProposerPort): class DSPyProposerAdapter(ProposerPort):
@@ -28,7 +30,7 @@ class DSPyProposerAdapter(ProposerPort):
self._max_retries = max_retries self._max_retries = max_retries
self._retry_delay_base = retry_delay_base self._retry_delay_base = retry_delay_base
def propose( async def propose(
self, self,
current_prompt: Prompt, current_prompt: Prompt,
trajectories: list[Trajectory], trajectories: list[Trajectory],
@@ -36,7 +38,18 @@ class DSPyProposerAdapter(ProposerPort):
) -> Prompt: ) -> Prompt:
failure_examples = self._format_failures(trajectories) failure_examples = self._format_failures(trajectories)
def _call() -> Prompt: async def _call() -> Prompt:
return await asyncio.to_thread(
self._sync_propose, current_prompt, task_description, failure_examples,
)
return await async_retry_with_backoff(
_call,
max_retries=self._max_retries,
retry_delay_base=self._retry_delay_base,
)
def _sync_propose(self, current_prompt: Prompt, task_description: str, failure_examples: str) -> Prompt:
with dspy.context(lm=self._lm): with dspy.context(lm=self._lm):
pred = self._proposer( pred = self._proposer(
current_instruction=current_prompt.text, current_instruction=current_prompt.text,
@@ -45,12 +58,6 @@ class DSPyProposerAdapter(ProposerPort):
) )
return Prompt(text=pred.new_instruction) return Prompt(text=pred.new_instruction)
return retry_with_backoff(
_call,
max_retries=self._max_retries,
retry_delay_base=self._retry_delay_base,
)
@staticmethod @staticmethod
def _format_failures(trajectories: list[Trajectory]) -> str: def _format_failures(trajectories: list[Trajectory]) -> str:
"""Convert trajectories into a structured textual report.""" """Convert trajectories into a structured textual report."""

View File

@@ -1,9 +1,10 @@
"""Retry with exponential backoff for transient LLM errors.""" """Retry with exponential backoff for transient LLM errors."""
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import time import time
from typing import Any, Callable, TypeVar from typing import Any, Callable, Coroutine, TypeVar
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -71,3 +72,31 @@ def retry_with_backoff(
time.sleep(delay) time.sleep(delay)
# Should not reach here, but satisfy type-checker. # Should not reach here, but satisfy type-checker.
raise TransientError(str(last_exc)) from last_exc raise TransientError(str(last_exc)) from last_exc
async def async_retry_with_backoff(
    fn: Callable[..., Coroutine[Any, Any, T]],
    *args: Any,
    max_retries: int = 3,
    retry_delay_base: float = 1.0,
    **kwargs: Any,
) -> T:
    """Await ``fn(*args, **kwargs)``, retrying transient errors with backoff.

    Async version of retry_with_backoff — uses asyncio.sleep instead of
    time.sleep, so other tasks keep running while this one waits.

    Args:
        fn: Coroutine function to call; invoked afresh on every attempt.
        *args: Positional arguments forwarded to ``fn``.
        max_retries: Number of retries after the first attempt. Negative
            values are treated as 0 (a single attempt, no retries).
        retry_delay_base: Delay in seconds before the first retry; the
            delay doubles on each subsequent retry.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        The value produced by the first successful ``await fn(...)``.

    Raises:
        Exception: The original exception is re-raised unchanged when
            ``is_transient_error`` rejects it or when the last attempt fails.
    """
    # A negative budget used to make the loop body never run: ``fn`` was
    # never awaited and a misleading TransientError("None") was raised.
    retries = max(0, max_retries)
    last_exc: Exception | None = None
    for attempt in range(retries + 1):
        try:
            return await fn(*args, **kwargs)
        except Exception as exc:
            last_exc = exc
            if not is_transient_error(exc) or attempt == retries:
                raise
            delay = retry_delay_base * (2 ** attempt)
            logger.warning(
                "Transient error (attempt %d/%d): %s — retrying in %.1fs",
                attempt + 1,
                retries + 1,
                exc,
                delay,
            )
            await asyncio.sleep(delay)
    # Should not reach here, but satisfy type-checker.
    raise TransientError(str(last_exc)) from last_exc

View File

@@ -1,7 +1,7 @@
"""Shared test fixtures.""" """Shared test fixtures."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@@ -66,17 +66,17 @@ def mock_eval_result() -> EvalResult:
@pytest.fixture @pytest.fixture
def mock_llm_port() -> MagicMock: def mock_llm_port() -> AsyncMock:
"""Mock LLMPort that returns canned responses.""" """Mock LLMPort that returns canned responses."""
port = MagicMock() port = AsyncMock()
port.execute.return_value = "This is a mock response." port.execute.return_value = "This is a mock response."
return port return port
@pytest.fixture @pytest.fixture
def mock_judge_port() -> MagicMock: def mock_judge_port() -> AsyncMock:
"""Mock JudgePort that returns moderate scores.""" """Mock JudgePort that returns moderate scores."""
port = MagicMock() port = AsyncMock()
port.judge_batch.return_value = [ port.judge_batch.return_value = [
(0.5, "Moderate quality, needs improvement."), (0.5, "Moderate quality, needs improvement."),
] * 5 ] * 5
@@ -84,9 +84,9 @@ def mock_judge_port() -> MagicMock:
@pytest.fixture @pytest.fixture
def mock_proposer_port() -> MagicMock: def mock_proposer_port() -> AsyncMock:
"""Mock ProposerPort that returns a slightly modified prompt.""" """Mock ProposerPort that returns a slightly modified prompt."""
port = MagicMock() port = AsyncMock()
port.propose.return_value = Prompt( port.propose.return_value = Prompt(
text="You are a very helpful assistant. Answer the question precisely." text="You are a very helpful assistant. Answer the question precisely."
) )

View File

@@ -57,23 +57,25 @@ def synth_lm() -> dspy.LM:
class TestDSPyLLMAdapterOwnLM: class TestDSPyLLMAdapterOwnLM:
"""Bug #2 fix: DSPyLLMAdapter must use the LM it receives, not the global one.""" """Bug #2 fix: DSPyLLMAdapter must use the LM it receives, not the global one."""
def test_uses_provided_lm_not_global(self) -> None: @pytest.mark.asyncio
async def test_uses_provided_lm_not_global(self) -> None:
local_lm = dspy.utils.DummyLM([{"output": "local response"}]) local_lm = dspy.utils.DummyLM([{"output": "local response"}])
global_lm = dspy.utils.DummyLM([{"output": "global response"}]) global_lm = dspy.utils.DummyLM([{"output": "global response"}])
dspy.configure(lm=global_lm) dspy.configure(lm=global_lm)
adapter = DSPyLLMAdapter(lm=local_lm) adapter = DSPyLLMAdapter(lm=local_lm)
result = adapter.execute(Prompt(text="test"), "input") result = await adapter.execute(Prompt(text="test"), "input")
assert result == "local response" assert result == "local response"
def test_does_not_affect_global_lm(self) -> None: @pytest.mark.asyncio
async def test_does_not_affect_global_lm(self) -> None:
local_lm = dspy.utils.DummyLM([{"output": "local response"}]) local_lm = dspy.utils.DummyLM([{"output": "local response"}])
global_lm = dspy.utils.DummyLM([{"output": "global response"}]) global_lm = dspy.utils.DummyLM([{"output": "global response"}])
dspy.configure(lm=global_lm) dspy.configure(lm=global_lm)
adapter = DSPyLLMAdapter(lm=local_lm) adapter = DSPyLLMAdapter(lm=local_lm)
adapter.execute(Prompt(text="test"), "input") await adapter.execute(Prompt(text="test"), "input")
# Global LM should still be the same # Global LM should still be the same
assert dspy.settings.lm is global_lm assert dspy.settings.lm is global_lm
@@ -82,9 +84,10 @@ class TestDSPyLLMAdapterOwnLM:
class TestDSPyJudgeAdapterOwnLM: class TestDSPyJudgeAdapterOwnLM:
"""DSPyJudgeAdapter must use its own LM instance.""" """DSPyJudgeAdapter must use its own LM instance."""
def test_uses_provided_lm(self, judge_lm: dspy.LM) -> None: @pytest.mark.asyncio
async def test_uses_provided_lm(self, judge_lm: dspy.LM) -> None:
adapter = DSPyJudgeAdapter(lm=judge_lm) adapter = DSPyJudgeAdapter(lm=judge_lm)
results = adapter.judge_batch( results = await adapter.judge_batch(
task_description="Test task", task_description="Test task",
pairs=[("input 1", "output 1")], pairs=[("input 1", "output 1")],
) )
@@ -93,7 +96,8 @@ class TestDSPyJudgeAdapterOwnLM:
assert score == 0.8 assert score == 0.8
assert feedback == "Good response." assert feedback == "Good response."
def test_does_not_use_global_lm(self) -> None: @pytest.mark.asyncio
async def test_does_not_use_global_lm(self) -> None:
judge_lm = dspy.utils.DummyLM( judge_lm = dspy.utils.DummyLM(
[{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response"}] [{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response"}]
) )
@@ -101,14 +105,15 @@ class TestDSPyJudgeAdapterOwnLM:
dspy.configure(lm=global_lm) dspy.configure(lm=global_lm)
adapter = DSPyJudgeAdapter(lm=judge_lm) adapter = DSPyJudgeAdapter(lm=judge_lm)
results = adapter.judge_batch("task", [("in", "out")]) results = await adapter.judge_batch("task", [("in", "out")])
assert results[0][0] == 0.9 assert results[0][0] == 0.9
class TestDSPyProposerAdapterOwnLM: class TestDSPyProposerAdapterOwnLM:
"""DSPyProposerAdapter must use its own LM instance.""" """DSPyProposerAdapter must use its own LM instance."""
def test_uses_provided_lm(self, proposer_lm: dspy.LM) -> None: @pytest.mark.asyncio
async def test_uses_provided_lm(self, proposer_lm: dspy.LM) -> None:
adapter = DSPyProposerAdapter(lm=proposer_lm) adapter = DSPyProposerAdapter(lm=proposer_lm)
trajectories = [ trajectories = [
Trajectory( Trajectory(
@@ -119,14 +124,15 @@ class TestDSPyProposerAdapterOwnLM:
prompt_used="old prompt", prompt_used="old prompt",
) )
] ]
result = adapter.propose( result = await adapter.propose(
current_prompt=Prompt(text="old prompt"), current_prompt=Prompt(text="old prompt"),
trajectories=trajectories, trajectories=trajectories,
task_description="Test task", task_description="Test task",
) )
assert "Improved prompt" in result.text assert "Improved prompt" in result.text
def test_does_not_use_global_lm(self) -> None: @pytest.mark.asyncio
async def test_does_not_use_global_lm(self) -> None:
proposer_lm = dspy.utils.DummyLM( proposer_lm = dspy.utils.DummyLM(
[{"reasoning": "ok", "new_instruction": "proposer-specific"}] [{"reasoning": "ok", "new_instruction": "proposer-specific"}]
) )
@@ -136,7 +142,7 @@ class TestDSPyProposerAdapterOwnLM:
dspy.configure(lm=global_lm) dspy.configure(lm=global_lm)
adapter = DSPyProposerAdapter(lm=proposer_lm) adapter = DSPyProposerAdapter(lm=proposer_lm)
result = adapter.propose( result = await adapter.propose(
current_prompt=Prompt(text="test"), current_prompt=Prompt(text="test"),
trajectories=[], trajectories=[],
task_description="task", task_description="task",

View File

@@ -1,7 +1,7 @@
"""Unit tests for error handling: retry, circuit breaker, per-call isolation.""" """Unit tests for error handling: retry, circuit breaker, per-call isolation."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@@ -96,7 +96,8 @@ def _make_eval_result(scores, feedbacks=None):
class TestCircuitBreaker: class TestCircuitBreaker:
def test_trips_on_consecutive_failures(self): @pytest.mark.asyncio
async def test_trips_on_consecutive_failures(self):
"""Loop stops when consecutive failures reach the threshold.""" """Loop stops when consecutive failures reach the threshold."""
initial_eval = _make_eval_result([0.3, 0.4]) initial_eval = _make_eval_result([0.3, 0.4])
evaluator = MagicMock() evaluator = MagicMock()
@@ -109,8 +110,9 @@ class TestCircuitBreaker:
return initial_eval # seed eval succeeds return initial_eval # seed eval succeeds
raise RuntimeError("LLM down") raise RuntimeError("LLM down")
evaluator.evaluate.side_effect = _evaluate evaluator.evaluate = AsyncMock(side_effect=_evaluate)
proposer = MagicMock() proposer = MagicMock()
proposer.propose = AsyncMock()
bootstrap = MagicMock(spec=SyntheticBootstrap) bootstrap = MagicMock(spec=SyntheticBootstrap)
loop = EvolutionLoop( loop = EvolutionLoop(
@@ -123,7 +125,7 @@ class TestCircuitBreaker:
error_strategy="skip", error_strategy="skip",
) )
with patch.object(loop, "_log"): with patch.object(loop, "_log"):
state = loop.run( state = await loop.run(
Prompt("test"), Prompt("test"),
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)], [SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
"task", "task",
@@ -135,7 +137,8 @@ class TestCircuitBreaker:
assert len(cb_events) == 1 assert len(cb_events) == 1
assert state.iteration < 10 # stopped early assert state.iteration < 10 # stopped early
def test_abort_raises_on_first_error(self): @pytest.mark.asyncio
async def test_abort_raises_on_first_error(self):
"""With error_strategy=abort, the first error raises immediately.""" """With error_strategy=abort, the first error raises immediately."""
initial_eval = _make_eval_result([0.3, 0.4]) initial_eval = _make_eval_result([0.3, 0.4])
evaluator = MagicMock() evaluator = MagicMock()
@@ -148,8 +151,9 @@ class TestCircuitBreaker:
return initial_eval return initial_eval
raise RuntimeError("LLM down") raise RuntimeError("LLM down")
evaluator.evaluate.side_effect = _evaluate evaluator.evaluate = AsyncMock(side_effect=_evaluate)
proposer = MagicMock() proposer = MagicMock()
proposer.propose = AsyncMock()
bootstrap = MagicMock(spec=SyntheticBootstrap) bootstrap = MagicMock(spec=SyntheticBootstrap)
loop = EvolutionLoop( loop = EvolutionLoop(
@@ -163,13 +167,14 @@ class TestCircuitBreaker:
) )
with patch.object(loop, "_log"): with patch.object(loop, "_log"):
with pytest.raises(RuntimeError, match="LLM down"): with pytest.raises(RuntimeError, match="LLM down"):
loop.run( await loop.run(
Prompt("test"), Prompt("test"),
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)], [SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
"task", "task",
) )
def test_resets_on_success(self): @pytest.mark.asyncio
async def test_resets_on_success(self):
"""Consecutive failure counter resets after a successful iteration.""" """Consecutive failure counter resets after a successful iteration."""
initial_eval = _make_eval_result([0.3, 0.4]) initial_eval = _make_eval_result([0.3, 0.4])
good_eval = _make_eval_result([0.8, 0.9]) good_eval = _make_eval_result([0.8, 0.9])
@@ -194,9 +199,9 @@ class TestCircuitBreaker:
return initial_eval # current eval return initial_eval # current eval
return good_eval # new eval return good_eval # new eval
evaluator.evaluate.side_effect = _evaluate evaluator.evaluate = AsyncMock(side_effect=_evaluate)
proposer = MagicMock() proposer = MagicMock()
proposer.propose.return_value = Prompt("better prompt") proposer.propose = AsyncMock(return_value=Prompt("better prompt"))
bootstrap = MagicMock(spec=SyntheticBootstrap) bootstrap = MagicMock(spec=SyntheticBootstrap)
bootstrap.sample_minibatch.return_value = [ bootstrap.sample_minibatch.return_value = [
SyntheticExample(f"in{i}", id=i) for i in range(2) SyntheticExample(f"in{i}", id=i) for i in range(2)
@@ -212,7 +217,7 @@ class TestCircuitBreaker:
error_strategy="skip", error_strategy="skip",
) )
with patch.object(loop, "_log"): with patch.object(loop, "_log"):
state = loop.run( state = await loop.run(
Prompt("test"), Prompt("test"),
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)], [SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
"task", "task",
@@ -230,23 +235,24 @@ class TestCircuitBreaker:
class TestPerCallIsolation: class TestPerCallIsolation:
def test_evaluator_isolates_execution_failure(self): @pytest.mark.asyncio
async def test_evaluator_isolates_execution_failure(self):
"""A failing execution produces a sentinel output, not a crash.""" """A failing execution produces a sentinel output, not a crash."""
executor = MagicMock() executor = MagicMock()
executor.execute.side_effect = [ executor.execute = AsyncMock(side_effect=[
"good output", "good output",
RuntimeError("API error"), RuntimeError("API error"),
"another good output", "another good output",
] ])
judge = MagicMock() judge = MagicMock()
judge.judge_batch.return_value = [ judge.judge_batch = AsyncMock(return_value=[
(0.8, "good"), (0.8, "good"),
(0.0, "[judge error]"), (0.0, "[judge error]"),
(0.7, "ok"), (0.7, "ok"),
] ])
evaluator = PromptEvaluator(executor, judge) evaluator = PromptEvaluator(executor, judge)
result = evaluator.evaluate( result = await evaluator.evaluate(
Prompt("test"), Prompt("test"),
[ [
SyntheticExample("in0", id=0), SyntheticExample("in0", id=0),
@@ -261,7 +267,8 @@ class TestPerCallIsolation:
assert "execution error" in result.trajectories[1].output_text assert "execution error" in result.trajectories[1].output_text
assert result.scores[0] == 0.8 # other items unaffected assert result.scores[0] == 0.8 # other items unaffected
def test_judge_adapter_isolates_single_failure(self): @pytest.mark.asyncio
async def test_judge_adapter_isolates_single_failure(self):
"""DSPyJudgeAdapter returns sentinel for a failed item, not crash.""" """DSPyJudgeAdapter returns sentinel for a failed item, not crash."""
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
@@ -269,6 +276,7 @@ class TestPerCallIsolation:
adapter._lm = MagicMock() adapter._lm = MagicMock()
adapter._max_retries = 1 adapter._max_retries = 1
adapter._retry_delay_base = 0 adapter._retry_delay_base = 0
adapter._semaphore = __import__("asyncio").Semaphore(5)
# Mock _judge to fail on first call, succeed on second # Mock _judge to fail on first call, succeed on second
call_count = 0 call_count = 0
@@ -289,9 +297,10 @@ class TestPerCallIsolation:
with patch("prometheus.infrastructure.judge_adapter.dspy.context"): with patch("prometheus.infrastructure.judge_adapter.dspy.context"):
with patch( with patch(
"prometheus.infrastructure.retry.time.sleep" "prometheus.infrastructure.retry.asyncio.sleep",
new=AsyncMock(),
): ):
results = adapter.judge_batch( results = await adapter.judge_batch(
"task", [("input1", "output1"), ("input2", "output2")] "task", [("input1", "output1"), ("input2", "output2")]
) )

View File

@@ -1,7 +1,7 @@
"""Unit tests for PromptEvaluator.evaluate().""" """Unit tests for PromptEvaluator.evaluate()."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@@ -14,22 +14,23 @@ class TestPromptEvaluatorEvaluate:
"""Tests for the evaluate() pipeline: execute → judge → trajectories.""" """Tests for the evaluate() pipeline: execute → judge → trajectories."""
@pytest.fixture @pytest.fixture
def executor(self) -> MagicMock: def executor(self) -> AsyncMock:
return MagicMock(spec=LLMPort) return AsyncMock(spec=LLMPort)
@pytest.fixture @pytest.fixture
def judge(self) -> MagicMock: def judge(self) -> AsyncMock:
return MagicMock(spec=JudgePort) return AsyncMock(spec=JudgePort)
@pytest.fixture @pytest.fixture
def evaluator(self, executor: MagicMock, judge: MagicMock) -> PromptEvaluator: def evaluator(self, executor: AsyncMock, judge: AsyncMock) -> PromptEvaluator:
return PromptEvaluator(executor=executor, judge=judge) return PromptEvaluator(executor=executor, judge=judge)
def test_happy_path_builds_correct_trajectories( @pytest.mark.asyncio
async def test_happy_path_builds_correct_trajectories(
self, self,
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
executor: MagicMock, executor: AsyncMock,
judge: MagicMock, judge: AsyncMock,
) -> None: ) -> None:
prompt = Prompt(text="Answer the question.") prompt = Prompt(text="Answer the question.")
examples = [ examples = [
@@ -42,7 +43,7 @@ class TestPromptEvaluatorEvaluate:
(0.8, "Mostly correct."), (0.8, "Mostly correct."),
] ]
result = evaluator.evaluate(prompt, examples, "math and geography") result = await evaluator.evaluate(prompt, examples, "math and geography")
assert isinstance(result, EvalResult) assert isinstance(result, EvalResult)
assert result.scores == [0.9, 0.8] assert result.scores == [0.9, 0.8]
@@ -55,14 +56,15 @@ class TestPromptEvaluatorEvaluate:
assert result.trajectories[0].prompt_used == "Answer the question." assert result.trajectories[0].prompt_used == "Answer the question."
assert result.trajectories[1].prompt_used == "Answer the question." assert result.trajectories[1].prompt_used == "Answer the question."
def test_empty_minibatch_returns_empty_result( @pytest.mark.asyncio
async def test_empty_minibatch_returns_empty_result(
self, self,
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
executor: MagicMock, executor: AsyncMock,
judge: MagicMock, judge: AsyncMock,
) -> None: ) -> None:
prompt = Prompt(text="test") prompt = Prompt(text="test")
result = evaluator.evaluate(prompt, [], "task") result = await evaluator.evaluate(prompt, [], "task")
assert result.scores == [] assert result.scores == []
assert result.feedbacks == [] assert result.feedbacks == []
@@ -71,41 +73,44 @@ class TestPromptEvaluatorEvaluate:
# judge_batch is called with empty pairs list # judge_batch is called with empty pairs list
judge.judge_batch.assert_called_once_with("task", []) judge.judge_batch.assert_called_once_with("task", [])
def test_executor_called_with_correct_prompt( @pytest.mark.asyncio
async def test_executor_called_with_correct_prompt(
self, self,
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
executor: MagicMock, executor: AsyncMock,
judge: MagicMock, judge: AsyncMock,
) -> None: ) -> None:
prompt = Prompt(text="Summarize this.") prompt = Prompt(text="Summarize this.")
examples = [SyntheticExample(input_text="Long text here", id=0)] examples = [SyntheticExample(input_text="Long text here", id=0)]
executor.execute.return_value = "Summary." executor.execute.return_value = "Summary."
judge.judge_batch.return_value = [(0.7, "Good summary.")] judge.judge_batch.return_value = [(0.7, "Good summary.")]
evaluator.evaluate(prompt, examples, "summarization") await evaluator.evaluate(prompt, examples, "summarization")
executor.execute.assert_called_once_with(prompt, "Long text here") executor.execute.assert_called_once_with(prompt, "Long text here")
def test_trajectories_prompt_used_matches_input_prompt( @pytest.mark.asyncio
async def test_trajectories_prompt_used_matches_input_prompt(
self, self,
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
executor: MagicMock, executor: AsyncMock,
judge: MagicMock, judge: AsyncMock,
) -> None: ) -> None:
prompt = Prompt(text="Translate to French.") prompt = Prompt(text="Translate to French.")
examples = [SyntheticExample(input_text="Hello", id=0)] examples = [SyntheticExample(input_text="Hello", id=0)]
executor.execute.return_value = "Bonjour" executor.execute.return_value = "Bonjour"
judge.judge_batch.return_value = [(1.0, "Perfect.")] judge.judge_batch.return_value = [(1.0, "Perfect.")]
result = evaluator.evaluate(prompt, examples, "translation") result = await evaluator.evaluate(prompt, examples, "translation")
assert result.trajectories[0].prompt_used == "Translate to French." assert result.trajectories[0].prompt_used == "Translate to French."
def test_scores_feedbacks_trajectories_lists_sized_correctly( @pytest.mark.asyncio
async def test_scores_feedbacks_trajectories_lists_sized_correctly(
self, self,
evaluator: PromptEvaluator, evaluator: PromptEvaluator,
executor: MagicMock, executor: AsyncMock,
judge: MagicMock, judge: AsyncMock,
) -> None: ) -> None:
prompt = Prompt(text="test prompt") prompt = Prompt(text="test prompt")
examples = [SyntheticExample(input_text=f"q{i}", id=i) for i in range(4)] examples = [SyntheticExample(input_text=f"q{i}", id=i) for i in range(4)]
@@ -114,7 +119,7 @@ class TestPromptEvaluatorEvaluate:
(0.1 * i, f"fb{i}") for i in range(4) (0.1 * i, f"fb{i}") for i in range(4)
] ]
result = evaluator.evaluate(prompt, examples, "task") result = await evaluator.evaluate(prompt, examples, "task")
assert len(result.scores) == 4 assert len(result.scores) == 4
assert len(result.feedbacks) == 4 assert len(result.feedbacks) == 4

View File

@@ -1,7 +1,9 @@
"""Unit tests for the evolution loop — with full mocking.""" """Unit tests for the evolution loop — with full mocking."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prometheus.application.bootstrap import SyntheticBootstrap from prometheus.application.bootstrap import SyntheticBootstrap
from prometheus.application.evaluator import PromptEvaluator from prometheus.application.evaluator import PromptEvaluator
@@ -10,14 +12,15 @@ from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Tra
class TestEvolutionLoop: class TestEvolutionLoop:
def test_accepts_improvement( @pytest.mark.asyncio
async def test_accepts_improvement(
self, self,
seed_prompt: Prompt, seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample], synthetic_pool: list[SyntheticExample],
task_description: str, task_description: str,
mock_llm_port: MagicMock, mock_llm_port: AsyncMock,
mock_judge_port: MagicMock, mock_judge_port: AsyncMock,
mock_proposer_port: MagicMock, mock_proposer_port: AsyncMock,
) -> None: ) -> None:
"""When the new prompt improves the score, the best candidate is updated.""" """When the new prompt improves the score, the best candidate is updated."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port) evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
@@ -45,7 +48,7 @@ class TestEvolutionLoop:
feedbacks=["good"] * 5, feedbacks=["good"] * 5,
trajectories=[], trajectories=[],
) )
evaluator.evaluate = MagicMock(side_effect=[initial_eval, old_eval, new_eval]) evaluator.evaluate = AsyncMock(side_effect=[initial_eval, old_eval, new_eval])
loop = EvolutionLoop( loop = EvolutionLoop(
evaluator=evaluator, evaluator=evaluator,
@@ -55,19 +58,20 @@ class TestEvolutionLoop:
minibatch_size=5, minibatch_size=5,
) )
with patch.object(loop, "_log"): with patch.object(loop, "_log"):
state = loop.run(seed_prompt, synthetic_pool, task_description) state = await loop.run(seed_prompt, synthetic_pool, task_description)
assert state.best_candidate is not None assert state.best_candidate is not None
assert state.best_candidate.best_score > 0 assert state.best_candidate.best_score > 0
def test_rejects_regression( @pytest.mark.asyncio
async def test_rejects_regression(
self, self,
seed_prompt: Prompt, seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample], synthetic_pool: list[SyntheticExample],
task_description: str, task_description: str,
mock_llm_port: MagicMock, mock_llm_port: AsyncMock,
mock_judge_port: MagicMock, mock_judge_port: AsyncMock,
mock_proposer_port: MagicMock, mock_proposer_port: AsyncMock,
) -> None: ) -> None:
"""When the new prompt degrades the score, the best candidate stays unchanged.""" """When the new prompt degrades the score, the best candidate stays unchanged."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port) evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
@@ -95,7 +99,7 @@ class TestEvolutionLoop:
feedbacks=["bad"] * 5, feedbacks=["bad"] * 5,
trajectories=[], trajectories=[],
) )
evaluator.evaluate = MagicMock(side_effect=[initial_eval, old_eval, new_eval]) evaluator.evaluate = AsyncMock(side_effect=[initial_eval, old_eval, new_eval])
loop = EvolutionLoop( loop = EvolutionLoop(
evaluator=evaluator, evaluator=evaluator,
@@ -105,19 +109,20 @@ class TestEvolutionLoop:
minibatch_size=5, minibatch_size=5,
) )
with patch.object(loop, "_log"): with patch.object(loop, "_log"):
state = loop.run(seed_prompt, synthetic_pool, task_description) state = await loop.run(seed_prompt, synthetic_pool, task_description)
assert state.best_candidate is not None assert state.best_candidate is not None
assert state.best_candidate.prompt.text == seed_prompt.text assert state.best_candidate.prompt.text == seed_prompt.text
def test_skips_perfect_scores( @pytest.mark.asyncio
async def test_skips_perfect_scores(
self, self,
seed_prompt: Prompt, seed_prompt: Prompt,
synthetic_pool: list[SyntheticExample], synthetic_pool: list[SyntheticExample],
task_description: str, task_description: str,
mock_llm_port: MagicMock, mock_llm_port: AsyncMock,
mock_judge_port: MagicMock, mock_judge_port: AsyncMock,
mock_proposer_port: MagicMock, mock_proposer_port: AsyncMock,
) -> None: ) -> None:
"""When all scores are perfect, no proposition is made.""" """When all scores are perfect, no proposition is made."""
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port) evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
@@ -132,7 +137,7 @@ class TestEvolutionLoop:
for i in range(5) for i in range(5)
], ],
) )
evaluator.evaluate = MagicMock(return_value=perfect_eval) evaluator.evaluate = AsyncMock(return_value=perfect_eval)
loop = EvolutionLoop( loop = EvolutionLoop(
evaluator=evaluator, evaluator=evaluator,
@@ -142,6 +147,6 @@ class TestEvolutionLoop:
minibatch_size=5, minibatch_size=5,
) )
with patch.object(loop, "_log"): with patch.object(loop, "_log"):
loop.run(seed_prompt, synthetic_pool, task_description) await loop.run(seed_prompt, synthetic_pool, task_description)
mock_proposer_port.propose.assert_not_called() mock_proposer_port.propose.assert_not_called()