feat: async/parallel execution with configurable concurrency
Parallelize LLM calls within each minibatch to reduce wall-clock time. All domain ports (LLMPort, JudgePort, ProposerPort) are now async. Adapter implementations wrap synchronous DSPy calls with asyncio.to_thread. Judge calls run in parallel within a batch using asyncio.gather + a semaphore. The evaluator runs the executions of a minibatch's items in parallel with configurable concurrency; successive evaluations are still awaited one after another. The evolution loop and use case are fully async. Proposer calls stay sequential. Added the --max-concurrency CLI flag and the max_concurrency YAML config field. Added async_retry_with_backoff for async error handling. All 139 unit tests pass. Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -16,6 +16,7 @@ dependencies = [
|
|||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest>=8.3",
|
"pytest>=8.3",
|
||||||
|
"pytest-asyncio>=0.24",
|
||||||
"pytest-cov>=6.0",
|
"pytest-cov>=6.0",
|
||||||
"ruff>=0.9",
|
"ruff>=0.9",
|
||||||
"mypy>=1.14",
|
"mypy>=1.14",
|
||||||
@@ -37,6 +38,9 @@ target-version = "py312"
|
|||||||
python_version = "3.12"
|
python_version = "3.12"
|
||||||
strict = true
|
strict = true
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = ["dspy", "dspy.*"]
|
module = ["dspy", "dspy.*"]
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ class OptimizationConfig:
|
|||||||
# --- Reproducibility ---
|
# --- Reproducibility ---
|
||||||
seed: int = 42
|
seed: int = 42
|
||||||
|
|
||||||
|
# --- Concurrency ---
|
||||||
|
max_concurrency: int = 5
|
||||||
|
|
||||||
# --- Error handling ---
|
# --- Error handling ---
|
||||||
max_retries: int = 3
|
max_retries: int = 3
|
||||||
retry_delay_base: float = 1.0
|
retry_delay_base: float = 1.0
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ Combines candidate prompt execution + LLM-as-Judge evaluation.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from prometheus.domain.entities import (
|
from prometheus.domain.entities import (
|
||||||
@@ -26,15 +27,22 @@ class PromptEvaluator:
|
|||||||
Replaces GEPA's EvaluatorFn. Instead of comparing to ground truth,
|
Replaces GEPA's EvaluatorFn. Instead of comparing to ground truth,
|
||||||
uses an LLM-as-Judge.
|
uses an LLM-as-Judge.
|
||||||
|
|
||||||
|
Execution and judge calls run in parallel (bounded by *max_concurrency*).
|
||||||
Per-call isolation: a failure on one minibatch item produces a
|
Per-call isolation: a failure on one minibatch item produces a
|
||||||
zero-score trajectory instead of crashing the whole batch.
|
zero-score trajectory instead of crashing the whole batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, executor: LLMPort, judge: JudgePort):
|
def __init__(
|
||||||
|
self,
|
||||||
|
executor: LLMPort,
|
||||||
|
judge: JudgePort,
|
||||||
|
max_concurrency: int = 5,
|
||||||
|
):
|
||||||
self._executor = executor
|
self._executor = executor
|
||||||
self._judge = judge
|
self._judge = judge
|
||||||
|
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||||
|
|
||||||
def evaluate(
|
async def evaluate(
|
||||||
self,
|
self,
|
||||||
prompt: Prompt,
|
prompt: Prompt,
|
||||||
minibatch: list[SyntheticExample],
|
minibatch: list[SyntheticExample],
|
||||||
@@ -43,27 +51,20 @@ class PromptEvaluator:
|
|||||||
"""Evaluate the prompt on the minibatch.
|
"""Evaluate the prompt on the minibatch.
|
||||||
|
|
||||||
Steps:
|
Steps:
|
||||||
1. Execute the prompt on each input in the minibatch
|
1. Execute the prompt on each input in the minibatch (parallel)
|
||||||
2. Judge each (input, output) pair
|
2. Judge each (input, output) pair
|
||||||
3. Build trajectories with feedback
|
3. Build trajectories with feedback
|
||||||
"""
|
"""
|
||||||
# Step 1: Execution (per-item isolation)
|
# Step 1: Parallel execution (per-item isolation)
|
||||||
outputs: list[str] = []
|
output_coros = [
|
||||||
for example in minibatch:
|
self._execute_single(prompt, example)
|
||||||
try:
|
for example in minibatch
|
||||||
raw_output = self._executor.execute(prompt, example.input_text)
|
]
|
||||||
outputs.append(raw_output)
|
outputs = await asyncio.gather(*output_coros)
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"Execution failed for input '%s…': %s",
|
|
||||||
example.input_text[:40],
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
outputs.append(f"[execution error: {exc}]")
|
|
||||||
|
|
||||||
# Step 2: Judgement (judge_adapter handles its own per-call isolation)
|
# Step 2: Judgement (judge_adapter handles its own per-call isolation + parallelism)
|
||||||
pairs = [(ex.input_text, out) for ex, out in zip(minibatch, outputs)]
|
pairs = [(ex.input_text, out) for ex, out in zip(minibatch, outputs)]
|
||||||
judge_results = self._judge.judge_batch(task_description, pairs)
|
judge_results = await self._judge.judge_batch(task_description, pairs)
|
||||||
|
|
||||||
# Step 3: Build trajectories
|
# Step 3: Build trajectories
|
||||||
scores: list[float] = []
|
scores: list[float] = []
|
||||||
@@ -88,3 +89,17 @@ class PromptEvaluator:
|
|||||||
feedbacks=feedbacks,
|
feedbacks=feedbacks,
|
||||||
trajectories=trajectories,
|
trajectories=trajectories,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _execute_single(
|
||||||
|
self, prompt: Prompt, example: SyntheticExample
|
||||||
|
) -> str:
|
||||||
|
async with self._semaphore:
|
||||||
|
try:
|
||||||
|
return await self._executor.execute(prompt, example.input_text)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Execution failed for input '%s…': %s",
|
||||||
|
example.input_text[:40],
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return f"[execution error: {exc}]"
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class EvolutionLoop:
|
|||||||
self._circuit_breaker_threshold = circuit_breaker_threshold
|
self._circuit_breaker_threshold = circuit_breaker_threshold
|
||||||
self._error_strategy = error_strategy
|
self._error_strategy = error_strategy
|
||||||
|
|
||||||
def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
seed_prompt: Prompt,
|
seed_prompt: Prompt,
|
||||||
synthetic_pool: list[SyntheticExample],
|
synthetic_pool: list[SyntheticExample],
|
||||||
@@ -76,7 +76,7 @@ class EvolutionLoop:
|
|||||||
initial_batch = self._bootstrap.sample_minibatch(
|
initial_batch = self._bootstrap.sample_minibatch(
|
||||||
synthetic_pool, self._minibatch_size
|
synthetic_pool, self._minibatch_size
|
||||||
)
|
)
|
||||||
initial_eval = self._evaluator.evaluate(
|
initial_eval = await self._evaluator.evaluate(
|
||||||
seed_prompt, initial_batch, task_description
|
seed_prompt, initial_batch, task_description
|
||||||
)
|
)
|
||||||
state.total_llm_calls += 2 * self._minibatch_size # N executions + N judge calls
|
state.total_llm_calls += 2 * self._minibatch_size # N executions + N judge calls
|
||||||
@@ -95,7 +95,7 @@ class EvolutionLoop:
|
|||||||
state.iteration = i
|
state.iteration = i
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._run_iteration(
|
await self._run_iteration(
|
||||||
i, state, best_candidate, synthetic_pool, task_description
|
i, state, best_candidate, synthetic_pool, task_description
|
||||||
)
|
)
|
||||||
# Update best_candidate from state after successful iteration
|
# Update best_candidate from state after successful iteration
|
||||||
@@ -144,7 +144,7 @@ class EvolutionLoop:
|
|||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def _run_iteration(
|
async def _run_iteration(
|
||||||
self,
|
self,
|
||||||
i: int,
|
i: int,
|
||||||
state: OptimizationState,
|
state: OptimizationState,
|
||||||
@@ -159,7 +159,7 @@ class EvolutionLoop:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 2. Evaluate the current candidate
|
# 2. Evaluate the current candidate
|
||||||
current_eval = self._evaluator.evaluate(
|
current_eval = await self._evaluator.evaluate(
|
||||||
best_candidate.prompt, batch, task_description
|
best_candidate.prompt, batch, task_description
|
||||||
)
|
)
|
||||||
state.total_llm_calls += 2 * self._minibatch_size
|
state.total_llm_calls += 2 * self._minibatch_size
|
||||||
@@ -176,8 +176,8 @@ class EvolutionLoop:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 4. Propose a new prompt (reflective mutation)
|
# 4. Propose a new prompt (reflective mutation) — sequential
|
||||||
new_prompt = self._proposer.propose(
|
new_prompt = await self._proposer.propose(
|
||||||
best_candidate.prompt,
|
best_candidate.prompt,
|
||||||
current_eval.trajectories,
|
current_eval.trajectories,
|
||||||
task_description,
|
task_description,
|
||||||
@@ -185,7 +185,7 @@ class EvolutionLoop:
|
|||||||
state.total_llm_calls += 1 # 1 proposition call
|
state.total_llm_calls += 1 # 1 proposition call
|
||||||
|
|
||||||
# 5. Evaluate the new prompt on the same minibatch
|
# 5. Evaluate the new prompt on the same minibatch
|
||||||
new_eval = self._evaluator.evaluate(
|
new_eval = await self._evaluator.evaluate(
|
||||||
new_prompt, batch, task_description
|
new_prompt, batch, task_description
|
||||||
)
|
)
|
||||||
state.total_llm_calls += 2 * self._minibatch_size
|
state.total_llm_calls += 2 * self._minibatch_size
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class OptimizePromptUseCase:
|
|||||||
self._proposer = proposer
|
self._proposer = proposer
|
||||||
self._bootstrap = bootstrap
|
self._bootstrap = bootstrap
|
||||||
|
|
||||||
def execute(self, config: OptimizationConfig) -> OptimizationResult:
|
async def execute(self, config: OptimizationConfig) -> OptimizationResult:
|
||||||
"""Full pipeline:
|
"""Full pipeline:
|
||||||
1. Bootstrap → generate synthetic inputs
|
1. Bootstrap → generate synthetic inputs
|
||||||
2. Evolution → optimization loop
|
2. Evolution → optimization loop
|
||||||
@@ -55,7 +55,7 @@ class OptimizePromptUseCase:
|
|||||||
error_strategy=config.error_strategy,
|
error_strategy=config.error_strategy,
|
||||||
)
|
)
|
||||||
seed_prompt = Prompt(text=config.seed_prompt)
|
seed_prompt = Prompt(text=config.seed_prompt)
|
||||||
state = loop.run(seed_prompt, synthetic_pool, config.task_description)
|
state = await loop.run(seed_prompt, synthetic_pool, config.task_description)
|
||||||
|
|
||||||
# Phase 2: Result
|
# Phase 2: Result
|
||||||
initial_score = (
|
initial_score = (
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ Typer interface with -i (input) and -o (output) options.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
@@ -66,12 +67,30 @@ def optimize(
|
|||||||
"--error-strategy",
|
"--error-strategy",
|
||||||
help="How to handle errors: skip | retry | abort.",
|
help="How to handle errors: skip | retry | abort.",
|
||||||
),
|
),
|
||||||
|
max_concurrency: int = typer.Option(
|
||||||
|
5,
|
||||||
|
"--max-concurrency",
|
||||||
|
help="Max parallel LLM calls for minibatch execution and judging.",
|
||||||
|
),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Optimize a prompt without any reference data.
|
"""Optimize a prompt without any reference data.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
prometheus optimize -i config.yaml -o result.yaml
|
prometheus optimize -i config.yaml -o result.yaml
|
||||||
"""
|
"""
|
||||||
|
asyncio.run(
|
||||||
|
_async_optimize(input, output, verbose, max_retries, error_strategy, max_concurrency)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_optimize(
|
||||||
|
input: str,
|
||||||
|
output: str,
|
||||||
|
verbose: bool,
|
||||||
|
max_retries: int,
|
||||||
|
error_strategy: str,
|
||||||
|
max_concurrency: int,
|
||||||
|
) -> None:
|
||||||
# Configure verbose logging
|
# Configure verbose logging
|
||||||
if verbose:
|
if verbose:
|
||||||
logging.basicConfig(level=logging.INFO, format="[PROMETHEUS] %(message)s")
|
logging.basicConfig(level=logging.INFO, format="[PROMETHEUS] %(message)s")
|
||||||
@@ -129,6 +148,7 @@ def optimize(
|
|||||||
retry_delay_base=raw_config.get("retry_delay_base", 1.0),
|
retry_delay_base=raw_config.get("retry_delay_base", 1.0),
|
||||||
circuit_breaker_threshold=raw_config.get("circuit_breaker_threshold", 5),
|
circuit_breaker_threshold=raw_config.get("circuit_breaker_threshold", 5),
|
||||||
error_strategy=raw_config.get("error_strategy", error_strategy),
|
error_strategy=raw_config.get("error_strategy", error_strategy),
|
||||||
|
max_concurrency=raw_config.get("max_concurrency", max_concurrency),
|
||||||
output_path=output,
|
output_path=output,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
@@ -164,6 +184,7 @@ def optimize(
|
|||||||
lm=judge_lm,
|
lm=judge_lm,
|
||||||
max_retries=config.max_retries,
|
max_retries=config.max_retries,
|
||||||
retry_delay_base=config.retry_delay_base,
|
retry_delay_base=config.retry_delay_base,
|
||||||
|
max_concurrency=config.max_concurrency,
|
||||||
)
|
)
|
||||||
proposer_adapter = DSPyProposerAdapter(
|
proposer_adapter = DSPyProposerAdapter(
|
||||||
lm=proposer_lm,
|
lm=proposer_lm,
|
||||||
@@ -171,7 +192,11 @@ def optimize(
|
|||||||
retry_delay_base=config.retry_delay_base,
|
retry_delay_base=config.retry_delay_base,
|
||||||
)
|
)
|
||||||
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
||||||
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter)
|
evaluator = PromptEvaluator(
|
||||||
|
executor=llm_adapter,
|
||||||
|
judge=judge_adapter,
|
||||||
|
max_concurrency=config.max_concurrency,
|
||||||
|
)
|
||||||
use_case = OptimizePromptUseCase(
|
use_case = OptimizePromptUseCase(
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
proposer=proposer_adapter,
|
proposer=proposer_adapter,
|
||||||
@@ -180,7 +205,7 @@ def optimize(
|
|||||||
|
|
||||||
# 4. Execute
|
# 4. Execute
|
||||||
with console.status("[bold green]Evolving prompt..."):
|
with console.status("[bold green]Evolving prompt..."):
|
||||||
result = use_case.execute(config)
|
result = await use_case.execute(config)
|
||||||
|
|
||||||
# 5. Display results
|
# 5. Display results
|
||||||
_display_result(result)
|
_display_result(result)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class LLMPort(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute(self, prompt: Prompt, input_text: str) -> str:
|
async def execute(self, prompt: Prompt, input_text: str) -> str:
|
||||||
"""Execute the prompt on the input, return the raw response."""
|
"""Execute the prompt on the input, return the raw response."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ class JudgePort(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def judge_batch(
|
async def judge_batch(
|
||||||
self,
|
self,
|
||||||
task_description: str,
|
task_description: str,
|
||||||
pairs: list[tuple[str, str]],
|
pairs: list[tuple[str, str]],
|
||||||
@@ -50,7 +50,7 @@ class ProposerPort(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def propose(
|
async def propose(
|
||||||
self,
|
self,
|
||||||
current_prompt: Prompt,
|
current_prompt: Prompt,
|
||||||
trajectories: list[Trajectory],
|
trajectories: list[Trajectory],
|
||||||
|
|||||||
@@ -5,13 +5,15 @@ Implements the JudgePort via the DSPy OutputJudge module.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
import dspy
|
import dspy
|
||||||
|
|
||||||
from prometheus.domain.ports import JudgePort
|
from prometheus.domain.ports import JudgePort
|
||||||
from prometheus.infrastructure.dspy_modules import OutputJudge
|
from prometheus.infrastructure.dspy_modules import OutputJudge
|
||||||
from prometheus.infrastructure.retry import retry_with_backoff
|
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -21,6 +23,8 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
|
|
||||||
Per-call isolation: a failure on one item returns a zero-score sentinel
|
Per-call isolation: a failure on one item returns a zero-score sentinel
|
||||||
instead of crashing the whole batch.
|
instead of crashing the whole batch.
|
||||||
|
|
||||||
|
Judge calls run in parallel (bounded by *max_concurrency*).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -28,40 +32,60 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
lm: dspy.LM,
|
lm: dspy.LM,
|
||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
retry_delay_base: float = 1.0,
|
retry_delay_base: float = 1.0,
|
||||||
|
max_concurrency: int = 5,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._lm = lm
|
self._lm = lm
|
||||||
self._judge = OutputJudge()
|
self._judge = OutputJudge()
|
||||||
self._max_retries = max_retries
|
self._max_retries = max_retries
|
||||||
self._retry_delay_base = retry_delay_base
|
self._retry_delay_base = retry_delay_base
|
||||||
|
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||||
|
|
||||||
def judge_batch(
|
async def judge_batch(
|
||||||
self,
|
self,
|
||||||
task_description: str,
|
task_description: str,
|
||||||
pairs: list[tuple[str, str]],
|
pairs: list[tuple[str, str]],
|
||||||
) -> list[tuple[float, str]]:
|
) -> list[tuple[float, str]]:
|
||||||
results: list[tuple[float, str]] = []
|
tasks = [
|
||||||
with dspy.context(lm=self._lm):
|
self._judge_single_safe(task_description, input_text, output_text)
|
||||||
for input_text, output_text in pairs:
|
for input_text, output_text in pairs
|
||||||
results.append(self._judge_single(task_description, input_text, output_text))
|
]
|
||||||
return results
|
return list(await asyncio.gather(*tasks))
|
||||||
|
|
||||||
def _judge_single(
|
async def _judge_single_safe(
|
||||||
self,
|
self,
|
||||||
task_description: str,
|
task_description: str,
|
||||||
input_text: str,
|
input_text: str,
|
||||||
output_text: str,
|
output_text: str,
|
||||||
) -> tuple[float, str]:
|
) -> tuple[float, str]:
|
||||||
|
async with self._semaphore:
|
||||||
try:
|
try:
|
||||||
pred = retry_with_backoff(
|
return await self._judge_single(task_description, input_text, output_text)
|
||||||
lambda: self._judge(
|
|
||||||
task_description=task_description,
|
|
||||||
input_text=input_text,
|
|
||||||
output_text=output_text,
|
|
||||||
),
|
|
||||||
max_retries=self._max_retries,
|
|
||||||
retry_delay_base=self._retry_delay_base,
|
|
||||||
)
|
|
||||||
return (pred.score, pred.feedback)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Judge call failed for input '%s…': %s", input_text[:40], exc)
|
logger.warning("Judge call failed for input '%s…': %s", input_text[:40], exc)
|
||||||
return (0.0, f"[judge error: {exc}]")
|
return (0.0, f"[judge error: {exc}]")
|
||||||
|
|
||||||
|
async def _judge_single(
|
||||||
|
self,
|
||||||
|
task_description: str,
|
||||||
|
input_text: str,
|
||||||
|
output_text: str,
|
||||||
|
) -> tuple[float, str]:
|
||||||
|
async def _call() -> tuple[float, str]:
|
||||||
|
pred = await asyncio.to_thread(
|
||||||
|
self._sync_judge, task_description, input_text, output_text,
|
||||||
|
)
|
||||||
|
return (pred.score, pred.feedback)
|
||||||
|
|
||||||
|
return await async_retry_with_backoff(
|
||||||
|
_call,
|
||||||
|
max_retries=self._max_retries,
|
||||||
|
retry_delay_base=self._retry_delay_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sync_judge(self, task_description: str, input_text: str, output_text: str):
|
||||||
|
with dspy.context(lm=self._lm):
|
||||||
|
return self._judge(
|
||||||
|
task_description=task_description,
|
||||||
|
input_text=input_text,
|
||||||
|
output_text=output_text,
|
||||||
|
)
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ Implements the LLMPort via DSPy.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import dspy
|
import dspy
|
||||||
|
|
||||||
from prometheus.domain.entities import Prompt
|
from prometheus.domain.entities import Prompt
|
||||||
from prometheus.domain.ports import LLMPort
|
from prometheus.domain.ports import LLMPort
|
||||||
from prometheus.infrastructure.retry import retry_with_backoff
|
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||||
|
|
||||||
|
|
||||||
class DSPyLLMAdapter(LLMPort):
|
class DSPyLLMAdapter(LLMPort):
|
||||||
@@ -33,17 +35,21 @@ class DSPyLLMAdapter(LLMPort):
|
|||||||
self._max_retries = max_retries
|
self._max_retries = max_retries
|
||||||
self._retry_delay_base = retry_delay_base
|
self._retry_delay_base = retry_delay_base
|
||||||
|
|
||||||
def execute(self, prompt: Prompt, input_text: str) -> str:
|
async def execute(self, prompt: Prompt, input_text: str) -> str:
|
||||||
def _call() -> str:
|
async def _call() -> str:
|
||||||
|
# DSPy is synchronous — run in a thread to avoid blocking the event loop.
|
||||||
|
return await asyncio.to_thread(self._sync_execute, prompt, input_text)
|
||||||
|
|
||||||
|
return await async_retry_with_backoff(
|
||||||
|
_call,
|
||||||
|
max_retries=self._max_retries,
|
||||||
|
retry_delay_base=self._retry_delay_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sync_execute(self, prompt: Prompt, input_text: str) -> str:
|
||||||
with dspy.context(lm=self._lm):
|
with dspy.context(lm=self._lm):
|
||||||
result = self._predictor(
|
result = self._predictor(
|
||||||
instruction=prompt.text,
|
instruction=prompt.text,
|
||||||
input_text=input_text,
|
input_text=input_text,
|
||||||
)
|
)
|
||||||
return str(result.output)
|
return str(result.output)
|
||||||
|
|
||||||
return retry_with_backoff(
|
|
||||||
_call,
|
|
||||||
max_retries=self._max_retries,
|
|
||||||
retry_delay_base=self._retry_delay_base,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ Converts trajectories into readable format for the LLM proposer.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import dspy
|
import dspy
|
||||||
|
|
||||||
from prometheus.domain.entities import Prompt, Trajectory
|
from prometheus.domain.entities import Prompt, Trajectory
|
||||||
from prometheus.domain.ports import ProposerPort
|
from prometheus.domain.ports import ProposerPort
|
||||||
from prometheus.infrastructure.dspy_modules import InstructionProposer
|
from prometheus.infrastructure.dspy_modules import InstructionProposer
|
||||||
from prometheus.infrastructure.retry import retry_with_backoff
|
from prometheus.infrastructure.retry import async_retry_with_backoff
|
||||||
|
|
||||||
|
|
||||||
class DSPyProposerAdapter(ProposerPort):
|
class DSPyProposerAdapter(ProposerPort):
|
||||||
@@ -28,7 +30,7 @@ class DSPyProposerAdapter(ProposerPort):
|
|||||||
self._max_retries = max_retries
|
self._max_retries = max_retries
|
||||||
self._retry_delay_base = retry_delay_base
|
self._retry_delay_base = retry_delay_base
|
||||||
|
|
||||||
def propose(
|
async def propose(
|
||||||
self,
|
self,
|
||||||
current_prompt: Prompt,
|
current_prompt: Prompt,
|
||||||
trajectories: list[Trajectory],
|
trajectories: list[Trajectory],
|
||||||
@@ -36,7 +38,18 @@ class DSPyProposerAdapter(ProposerPort):
|
|||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
failure_examples = self._format_failures(trajectories)
|
failure_examples = self._format_failures(trajectories)
|
||||||
|
|
||||||
def _call() -> Prompt:
|
async def _call() -> Prompt:
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
self._sync_propose, current_prompt, task_description, failure_examples,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await async_retry_with_backoff(
|
||||||
|
_call,
|
||||||
|
max_retries=self._max_retries,
|
||||||
|
retry_delay_base=self._retry_delay_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sync_propose(self, current_prompt: Prompt, task_description: str, failure_examples: str) -> Prompt:
|
||||||
with dspy.context(lm=self._lm):
|
with dspy.context(lm=self._lm):
|
||||||
pred = self._proposer(
|
pred = self._proposer(
|
||||||
current_instruction=current_prompt.text,
|
current_instruction=current_prompt.text,
|
||||||
@@ -45,12 +58,6 @@ class DSPyProposerAdapter(ProposerPort):
|
|||||||
)
|
)
|
||||||
return Prompt(text=pred.new_instruction)
|
return Prompt(text=pred.new_instruction)
|
||||||
|
|
||||||
return retry_with_backoff(
|
|
||||||
_call,
|
|
||||||
max_retries=self._max_retries,
|
|
||||||
retry_delay_base=self._retry_delay_base,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _format_failures(trajectories: list[Trajectory]) -> str:
|
def _format_failures(trajectories: list[Trajectory]) -> str:
|
||||||
"""Convert trajectories into a structured textual report."""
|
"""Convert trajectories into a structured textual report."""
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Retry with exponential backoff for transient LLM errors."""
|
"""Retry with exponential backoff for transient LLM errors."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, TypeVar
|
from typing import Any, Callable, Coroutine, TypeVar
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -71,3 +72,31 @@ def retry_with_backoff(
|
|||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
# Should not reach here, but satisfy type-checker.
|
# Should not reach here, but satisfy type-checker.
|
||||||
raise TransientError(str(last_exc)) from last_exc
|
raise TransientError(str(last_exc)) from last_exc
|
||||||
|
|
||||||
|
|
||||||
|
async def async_retry_with_backoff(
|
||||||
|
fn: Callable[..., Coroutine[Any, Any, T]],
|
||||||
|
*args: Any,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay_base: float = 1.0,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> T:
|
||||||
|
"""Async version of retry_with_backoff — uses asyncio.sleep instead of time.sleep."""
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
return await fn(*args, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
if not is_transient_error(exc) or attempt == max_retries:
|
||||||
|
raise
|
||||||
|
delay = retry_delay_base * (2 ** attempt)
|
||||||
|
logger.warning(
|
||||||
|
"Transient error (attempt %d/%d): %s — retrying in %.1fs",
|
||||||
|
attempt + 1,
|
||||||
|
max_retries + 1,
|
||||||
|
exc,
|
||||||
|
delay,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
raise TransientError(str(last_exc)) from last_exc
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Shared test fixtures."""
|
"""Shared test fixtures."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -66,17 +66,17 @@ def mock_eval_result() -> EvalResult:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_llm_port() -> MagicMock:
|
def mock_llm_port() -> AsyncMock:
|
||||||
"""Mock LLMPort that returns canned responses."""
|
"""Mock LLMPort that returns canned responses."""
|
||||||
port = MagicMock()
|
port = AsyncMock()
|
||||||
port.execute.return_value = "This is a mock response."
|
port.execute.return_value = "This is a mock response."
|
||||||
return port
|
return port
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_judge_port() -> MagicMock:
|
def mock_judge_port() -> AsyncMock:
|
||||||
"""Mock JudgePort that returns moderate scores."""
|
"""Mock JudgePort that returns moderate scores."""
|
||||||
port = MagicMock()
|
port = AsyncMock()
|
||||||
port.judge_batch.return_value = [
|
port.judge_batch.return_value = [
|
||||||
(0.5, "Moderate quality, needs improvement."),
|
(0.5, "Moderate quality, needs improvement."),
|
||||||
] * 5
|
] * 5
|
||||||
@@ -84,9 +84,9 @@ def mock_judge_port() -> MagicMock:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_proposer_port() -> MagicMock:
|
def mock_proposer_port() -> AsyncMock:
|
||||||
"""Mock ProposerPort that returns a slightly modified prompt."""
|
"""Mock ProposerPort that returns a slightly modified prompt."""
|
||||||
port = MagicMock()
|
port = AsyncMock()
|
||||||
port.propose.return_value = Prompt(
|
port.propose.return_value = Prompt(
|
||||||
text="You are a very helpful assistant. Answer the question precisely."
|
text="You are a very helpful assistant. Answer the question precisely."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,23 +57,25 @@ def synth_lm() -> dspy.LM:
|
|||||||
class TestDSPyLLMAdapterOwnLM:
|
class TestDSPyLLMAdapterOwnLM:
|
||||||
"""Bug #2 fix: DSPyLLMAdapter must use the LM it receives, not the global one."""
|
"""Bug #2 fix: DSPyLLMAdapter must use the LM it receives, not the global one."""
|
||||||
|
|
||||||
def test_uses_provided_lm_not_global(self) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_uses_provided_lm_not_global(self) -> None:
|
||||||
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
|
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
|
||||||
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
|
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
|
||||||
dspy.configure(lm=global_lm)
|
dspy.configure(lm=global_lm)
|
||||||
|
|
||||||
adapter = DSPyLLMAdapter(lm=local_lm)
|
adapter = DSPyLLMAdapter(lm=local_lm)
|
||||||
result = adapter.execute(Prompt(text="test"), "input")
|
result = await adapter.execute(Prompt(text="test"), "input")
|
||||||
|
|
||||||
assert result == "local response"
|
assert result == "local response"
|
||||||
|
|
||||||
def test_does_not_affect_global_lm(self) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_does_not_affect_global_lm(self) -> None:
|
||||||
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
|
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
|
||||||
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
|
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
|
||||||
dspy.configure(lm=global_lm)
|
dspy.configure(lm=global_lm)
|
||||||
|
|
||||||
adapter = DSPyLLMAdapter(lm=local_lm)
|
adapter = DSPyLLMAdapter(lm=local_lm)
|
||||||
adapter.execute(Prompt(text="test"), "input")
|
await adapter.execute(Prompt(text="test"), "input")
|
||||||
|
|
||||||
# Global LM should still be the same
|
# Global LM should still be the same
|
||||||
assert dspy.settings.lm is global_lm
|
assert dspy.settings.lm is global_lm
|
||||||
@@ -82,9 +84,10 @@ class TestDSPyLLMAdapterOwnLM:
|
|||||||
class TestDSPyJudgeAdapterOwnLM:
|
class TestDSPyJudgeAdapterOwnLM:
|
||||||
"""DSPyJudgeAdapter must use its own LM instance."""
|
"""DSPyJudgeAdapter must use its own LM instance."""
|
||||||
|
|
||||||
def test_uses_provided_lm(self, judge_lm: dspy.LM) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_uses_provided_lm(self, judge_lm: dspy.LM) -> None:
|
||||||
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||||
results = adapter.judge_batch(
|
results = await adapter.judge_batch(
|
||||||
task_description="Test task",
|
task_description="Test task",
|
||||||
pairs=[("input 1", "output 1")],
|
pairs=[("input 1", "output 1")],
|
||||||
)
|
)
|
||||||
@@ -93,7 +96,8 @@ class TestDSPyJudgeAdapterOwnLM:
|
|||||||
assert score == 0.8
|
assert score == 0.8
|
||||||
assert feedback == "Good response."
|
assert feedback == "Good response."
|
||||||
|
|
||||||
def test_does_not_use_global_lm(self) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_does_not_use_global_lm(self) -> None:
|
||||||
judge_lm = dspy.utils.DummyLM(
|
judge_lm = dspy.utils.DummyLM(
|
||||||
[{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response"}]
|
[{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response"}]
|
||||||
)
|
)
|
||||||
@@ -101,14 +105,15 @@ class TestDSPyJudgeAdapterOwnLM:
|
|||||||
dspy.configure(lm=global_lm)
|
dspy.configure(lm=global_lm)
|
||||||
|
|
||||||
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||||
results = adapter.judge_batch("task", [("in", "out")])
|
results = await adapter.judge_batch("task", [("in", "out")])
|
||||||
assert results[0][0] == 0.9
|
assert results[0][0] == 0.9
|
||||||
|
|
||||||
|
|
||||||
class TestDSPyProposerAdapterOwnLM:
|
class TestDSPyProposerAdapterOwnLM:
|
||||||
"""DSPyProposerAdapter must use its own LM instance."""
|
"""DSPyProposerAdapter must use its own LM instance."""
|
||||||
|
|
||||||
def test_uses_provided_lm(self, proposer_lm: dspy.LM) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_uses_provided_lm(self, proposer_lm: dspy.LM) -> None:
|
||||||
adapter = DSPyProposerAdapter(lm=proposer_lm)
|
adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||||
trajectories = [
|
trajectories = [
|
||||||
Trajectory(
|
Trajectory(
|
||||||
@@ -119,14 +124,15 @@ class TestDSPyProposerAdapterOwnLM:
|
|||||||
prompt_used="old prompt",
|
prompt_used="old prompt",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
result = adapter.propose(
|
result = await adapter.propose(
|
||||||
current_prompt=Prompt(text="old prompt"),
|
current_prompt=Prompt(text="old prompt"),
|
||||||
trajectories=trajectories,
|
trajectories=trajectories,
|
||||||
task_description="Test task",
|
task_description="Test task",
|
||||||
)
|
)
|
||||||
assert "Improved prompt" in result.text
|
assert "Improved prompt" in result.text
|
||||||
|
|
||||||
def test_does_not_use_global_lm(self) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_does_not_use_global_lm(self) -> None:
|
||||||
proposer_lm = dspy.utils.DummyLM(
|
proposer_lm = dspy.utils.DummyLM(
|
||||||
[{"reasoning": "ok", "new_instruction": "proposer-specific"}]
|
[{"reasoning": "ok", "new_instruction": "proposer-specific"}]
|
||||||
)
|
)
|
||||||
@@ -136,7 +142,7 @@ class TestDSPyProposerAdapterOwnLM:
|
|||||||
dspy.configure(lm=global_lm)
|
dspy.configure(lm=global_lm)
|
||||||
|
|
||||||
adapter = DSPyProposerAdapter(lm=proposer_lm)
|
adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||||
result = adapter.propose(
|
result = await adapter.propose(
|
||||||
current_prompt=Prompt(text="test"),
|
current_prompt=Prompt(text="test"),
|
||||||
trajectories=[],
|
trajectories=[],
|
||||||
task_description="task",
|
task_description="task",
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Unit tests for error handling: retry, circuit breaker, per-call isolation."""
|
"""Unit tests for error handling: retry, circuit breaker, per-call isolation."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -96,7 +96,8 @@ def _make_eval_result(scores, feedbacks=None):
|
|||||||
|
|
||||||
|
|
||||||
class TestCircuitBreaker:
|
class TestCircuitBreaker:
|
||||||
def test_trips_on_consecutive_failures(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_trips_on_consecutive_failures(self):
|
||||||
"""Loop stops when consecutive failures reach the threshold."""
|
"""Loop stops when consecutive failures reach the threshold."""
|
||||||
initial_eval = _make_eval_result([0.3, 0.4])
|
initial_eval = _make_eval_result([0.3, 0.4])
|
||||||
evaluator = MagicMock()
|
evaluator = MagicMock()
|
||||||
@@ -109,8 +110,9 @@ class TestCircuitBreaker:
|
|||||||
return initial_eval # seed eval succeeds
|
return initial_eval # seed eval succeeds
|
||||||
raise RuntimeError("LLM down")
|
raise RuntimeError("LLM down")
|
||||||
|
|
||||||
evaluator.evaluate.side_effect = _evaluate
|
evaluator.evaluate = AsyncMock(side_effect=_evaluate)
|
||||||
proposer = MagicMock()
|
proposer = MagicMock()
|
||||||
|
proposer.propose = AsyncMock()
|
||||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
|
||||||
loop = EvolutionLoop(
|
loop = EvolutionLoop(
|
||||||
@@ -123,7 +125,7 @@ class TestCircuitBreaker:
|
|||||||
error_strategy="skip",
|
error_strategy="skip",
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
with patch.object(loop, "_log"):
|
||||||
state = loop.run(
|
state = await loop.run(
|
||||||
Prompt("test"),
|
Prompt("test"),
|
||||||
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||||
"task",
|
"task",
|
||||||
@@ -135,7 +137,8 @@ class TestCircuitBreaker:
|
|||||||
assert len(cb_events) == 1
|
assert len(cb_events) == 1
|
||||||
assert state.iteration < 10 # stopped early
|
assert state.iteration < 10 # stopped early
|
||||||
|
|
||||||
def test_abort_raises_on_first_error(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_abort_raises_on_first_error(self):
|
||||||
"""With error_strategy=abort, the first error raises immediately."""
|
"""With error_strategy=abort, the first error raises immediately."""
|
||||||
initial_eval = _make_eval_result([0.3, 0.4])
|
initial_eval = _make_eval_result([0.3, 0.4])
|
||||||
evaluator = MagicMock()
|
evaluator = MagicMock()
|
||||||
@@ -148,8 +151,9 @@ class TestCircuitBreaker:
|
|||||||
return initial_eval
|
return initial_eval
|
||||||
raise RuntimeError("LLM down")
|
raise RuntimeError("LLM down")
|
||||||
|
|
||||||
evaluator.evaluate.side_effect = _evaluate
|
evaluator.evaluate = AsyncMock(side_effect=_evaluate)
|
||||||
proposer = MagicMock()
|
proposer = MagicMock()
|
||||||
|
proposer.propose = AsyncMock()
|
||||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
|
||||||
loop = EvolutionLoop(
|
loop = EvolutionLoop(
|
||||||
@@ -163,13 +167,14 @@ class TestCircuitBreaker:
|
|||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
with patch.object(loop, "_log"):
|
||||||
with pytest.raises(RuntimeError, match="LLM down"):
|
with pytest.raises(RuntimeError, match="LLM down"):
|
||||||
loop.run(
|
await loop.run(
|
||||||
Prompt("test"),
|
Prompt("test"),
|
||||||
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||||
"task",
|
"task",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_resets_on_success(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_resets_on_success(self):
|
||||||
"""Consecutive failure counter resets after a successful iteration."""
|
"""Consecutive failure counter resets after a successful iteration."""
|
||||||
initial_eval = _make_eval_result([0.3, 0.4])
|
initial_eval = _make_eval_result([0.3, 0.4])
|
||||||
good_eval = _make_eval_result([0.8, 0.9])
|
good_eval = _make_eval_result([0.8, 0.9])
|
||||||
@@ -194,9 +199,9 @@ class TestCircuitBreaker:
|
|||||||
return initial_eval # current eval
|
return initial_eval # current eval
|
||||||
return good_eval # new eval
|
return good_eval # new eval
|
||||||
|
|
||||||
evaluator.evaluate.side_effect = _evaluate
|
evaluator.evaluate = AsyncMock(side_effect=_evaluate)
|
||||||
proposer = MagicMock()
|
proposer = MagicMock()
|
||||||
proposer.propose.return_value = Prompt("better prompt")
|
proposer.propose = AsyncMock(return_value=Prompt("better prompt"))
|
||||||
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
bootstrap.sample_minibatch.return_value = [
|
bootstrap.sample_minibatch.return_value = [
|
||||||
SyntheticExample(f"in{i}", id=i) for i in range(2)
|
SyntheticExample(f"in{i}", id=i) for i in range(2)
|
||||||
@@ -212,7 +217,7 @@ class TestCircuitBreaker:
|
|||||||
error_strategy="skip",
|
error_strategy="skip",
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
with patch.object(loop, "_log"):
|
||||||
state = loop.run(
|
state = await loop.run(
|
||||||
Prompt("test"),
|
Prompt("test"),
|
||||||
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||||
"task",
|
"task",
|
||||||
@@ -230,23 +235,24 @@ class TestCircuitBreaker:
|
|||||||
|
|
||||||
|
|
||||||
class TestPerCallIsolation:
|
class TestPerCallIsolation:
|
||||||
def test_evaluator_isolates_execution_failure(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_evaluator_isolates_execution_failure(self):
|
||||||
"""A failing execution produces a sentinel output, not a crash."""
|
"""A failing execution produces a sentinel output, not a crash."""
|
||||||
executor = MagicMock()
|
executor = MagicMock()
|
||||||
executor.execute.side_effect = [
|
executor.execute = AsyncMock(side_effect=[
|
||||||
"good output",
|
"good output",
|
||||||
RuntimeError("API error"),
|
RuntimeError("API error"),
|
||||||
"another good output",
|
"another good output",
|
||||||
]
|
])
|
||||||
judge = MagicMock()
|
judge = MagicMock()
|
||||||
judge.judge_batch.return_value = [
|
judge.judge_batch = AsyncMock(return_value=[
|
||||||
(0.8, "good"),
|
(0.8, "good"),
|
||||||
(0.0, "[judge error]"),
|
(0.0, "[judge error]"),
|
||||||
(0.7, "ok"),
|
(0.7, "ok"),
|
||||||
]
|
])
|
||||||
|
|
||||||
evaluator = PromptEvaluator(executor, judge)
|
evaluator = PromptEvaluator(executor, judge)
|
||||||
result = evaluator.evaluate(
|
result = await evaluator.evaluate(
|
||||||
Prompt("test"),
|
Prompt("test"),
|
||||||
[
|
[
|
||||||
SyntheticExample("in0", id=0),
|
SyntheticExample("in0", id=0),
|
||||||
@@ -261,7 +267,8 @@ class TestPerCallIsolation:
|
|||||||
assert "execution error" in result.trajectories[1].output_text
|
assert "execution error" in result.trajectories[1].output_text
|
||||||
assert result.scores[0] == 0.8 # other items unaffected
|
assert result.scores[0] == 0.8 # other items unaffected
|
||||||
|
|
||||||
def test_judge_adapter_isolates_single_failure(self):
|
@pytest.mark.asyncio
|
||||||
|
async def test_judge_adapter_isolates_single_failure(self):
|
||||||
"""DSPyJudgeAdapter returns sentinel for a failed item, not crash."""
|
"""DSPyJudgeAdapter returns sentinel for a failed item, not crash."""
|
||||||
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||||
|
|
||||||
@@ -269,6 +276,7 @@ class TestPerCallIsolation:
|
|||||||
adapter._lm = MagicMock()
|
adapter._lm = MagicMock()
|
||||||
adapter._max_retries = 1
|
adapter._max_retries = 1
|
||||||
adapter._retry_delay_base = 0
|
adapter._retry_delay_base = 0
|
||||||
|
adapter._semaphore = __import__("asyncio").Semaphore(5)
|
||||||
|
|
||||||
# Mock _judge to fail on first call, succeed on second
|
# Mock _judge to fail on first call, succeed on second
|
||||||
call_count = 0
|
call_count = 0
|
||||||
@@ -289,9 +297,10 @@ class TestPerCallIsolation:
|
|||||||
|
|
||||||
with patch("prometheus.infrastructure.judge_adapter.dspy.context"):
|
with patch("prometheus.infrastructure.judge_adapter.dspy.context"):
|
||||||
with patch(
|
with patch(
|
||||||
"prometheus.infrastructure.retry.time.sleep"
|
"prometheus.infrastructure.retry.asyncio.sleep",
|
||||||
|
new=AsyncMock(),
|
||||||
):
|
):
|
||||||
results = adapter.judge_batch(
|
results = await adapter.judge_batch(
|
||||||
"task", [("input1", "output1"), ("input2", "output2")]
|
"task", [("input1", "output1"), ("input2", "output2")]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Unit tests for PromptEvaluator.evaluate()."""
|
"""Unit tests for PromptEvaluator.evaluate()."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -14,22 +14,23 @@ class TestPromptEvaluatorEvaluate:
|
|||||||
"""Tests for the evaluate() pipeline: execute → judge → trajectories."""
|
"""Tests for the evaluate() pipeline: execute → judge → trajectories."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def executor(self) -> MagicMock:
|
def executor(self) -> AsyncMock:
|
||||||
return MagicMock(spec=LLMPort)
|
return AsyncMock(spec=LLMPort)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def judge(self) -> MagicMock:
|
def judge(self) -> AsyncMock:
|
||||||
return MagicMock(spec=JudgePort)
|
return AsyncMock(spec=JudgePort)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def evaluator(self, executor: MagicMock, judge: MagicMock) -> PromptEvaluator:
|
def evaluator(self, executor: AsyncMock, judge: AsyncMock) -> PromptEvaluator:
|
||||||
return PromptEvaluator(executor=executor, judge=judge)
|
return PromptEvaluator(executor=executor, judge=judge)
|
||||||
|
|
||||||
def test_happy_path_builds_correct_trajectories(
|
@pytest.mark.asyncio
|
||||||
|
async def test_happy_path_builds_correct_trajectories(
|
||||||
self,
|
self,
|
||||||
evaluator: PromptEvaluator,
|
evaluator: PromptEvaluator,
|
||||||
executor: MagicMock,
|
executor: AsyncMock,
|
||||||
judge: MagicMock,
|
judge: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = Prompt(text="Answer the question.")
|
prompt = Prompt(text="Answer the question.")
|
||||||
examples = [
|
examples = [
|
||||||
@@ -42,7 +43,7 @@ class TestPromptEvaluatorEvaluate:
|
|||||||
(0.8, "Mostly correct."),
|
(0.8, "Mostly correct."),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = evaluator.evaluate(prompt, examples, "math and geography")
|
result = await evaluator.evaluate(prompt, examples, "math and geography")
|
||||||
|
|
||||||
assert isinstance(result, EvalResult)
|
assert isinstance(result, EvalResult)
|
||||||
assert result.scores == [0.9, 0.8]
|
assert result.scores == [0.9, 0.8]
|
||||||
@@ -55,14 +56,15 @@ class TestPromptEvaluatorEvaluate:
|
|||||||
assert result.trajectories[0].prompt_used == "Answer the question."
|
assert result.trajectories[0].prompt_used == "Answer the question."
|
||||||
assert result.trajectories[1].prompt_used == "Answer the question."
|
assert result.trajectories[1].prompt_used == "Answer the question."
|
||||||
|
|
||||||
def test_empty_minibatch_returns_empty_result(
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_minibatch_returns_empty_result(
|
||||||
self,
|
self,
|
||||||
evaluator: PromptEvaluator,
|
evaluator: PromptEvaluator,
|
||||||
executor: MagicMock,
|
executor: AsyncMock,
|
||||||
judge: MagicMock,
|
judge: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = Prompt(text="test")
|
prompt = Prompt(text="test")
|
||||||
result = evaluator.evaluate(prompt, [], "task")
|
result = await evaluator.evaluate(prompt, [], "task")
|
||||||
|
|
||||||
assert result.scores == []
|
assert result.scores == []
|
||||||
assert result.feedbacks == []
|
assert result.feedbacks == []
|
||||||
@@ -71,41 +73,44 @@ class TestPromptEvaluatorEvaluate:
|
|||||||
# judge_batch is called with empty pairs list
|
# judge_batch is called with empty pairs list
|
||||||
judge.judge_batch.assert_called_once_with("task", [])
|
judge.judge_batch.assert_called_once_with("task", [])
|
||||||
|
|
||||||
def test_executor_called_with_correct_prompt(
|
@pytest.mark.asyncio
|
||||||
|
async def test_executor_called_with_correct_prompt(
|
||||||
self,
|
self,
|
||||||
evaluator: PromptEvaluator,
|
evaluator: PromptEvaluator,
|
||||||
executor: MagicMock,
|
executor: AsyncMock,
|
||||||
judge: MagicMock,
|
judge: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = Prompt(text="Summarize this.")
|
prompt = Prompt(text="Summarize this.")
|
||||||
examples = [SyntheticExample(input_text="Long text here", id=0)]
|
examples = [SyntheticExample(input_text="Long text here", id=0)]
|
||||||
executor.execute.return_value = "Summary."
|
executor.execute.return_value = "Summary."
|
||||||
judge.judge_batch.return_value = [(0.7, "Good summary.")]
|
judge.judge_batch.return_value = [(0.7, "Good summary.")]
|
||||||
|
|
||||||
evaluator.evaluate(prompt, examples, "summarization")
|
await evaluator.evaluate(prompt, examples, "summarization")
|
||||||
|
|
||||||
executor.execute.assert_called_once_with(prompt, "Long text here")
|
executor.execute.assert_called_once_with(prompt, "Long text here")
|
||||||
|
|
||||||
def test_trajectories_prompt_used_matches_input_prompt(
|
@pytest.mark.asyncio
|
||||||
|
async def test_trajectories_prompt_used_matches_input_prompt(
|
||||||
self,
|
self,
|
||||||
evaluator: PromptEvaluator,
|
evaluator: PromptEvaluator,
|
||||||
executor: MagicMock,
|
executor: AsyncMock,
|
||||||
judge: MagicMock,
|
judge: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = Prompt(text="Translate to French.")
|
prompt = Prompt(text="Translate to French.")
|
||||||
examples = [SyntheticExample(input_text="Hello", id=0)]
|
examples = [SyntheticExample(input_text="Hello", id=0)]
|
||||||
executor.execute.return_value = "Bonjour"
|
executor.execute.return_value = "Bonjour"
|
||||||
judge.judge_batch.return_value = [(1.0, "Perfect.")]
|
judge.judge_batch.return_value = [(1.0, "Perfect.")]
|
||||||
|
|
||||||
result = evaluator.evaluate(prompt, examples, "translation")
|
result = await evaluator.evaluate(prompt, examples, "translation")
|
||||||
|
|
||||||
assert result.trajectories[0].prompt_used == "Translate to French."
|
assert result.trajectories[0].prompt_used == "Translate to French."
|
||||||
|
|
||||||
def test_scores_feedbacks_trajectories_lists_sized_correctly(
|
@pytest.mark.asyncio
|
||||||
|
async def test_scores_feedbacks_trajectories_lists_sized_correctly(
|
||||||
self,
|
self,
|
||||||
evaluator: PromptEvaluator,
|
evaluator: PromptEvaluator,
|
||||||
executor: MagicMock,
|
executor: AsyncMock,
|
||||||
judge: MagicMock,
|
judge: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = Prompt(text="test prompt")
|
prompt = Prompt(text="test prompt")
|
||||||
examples = [SyntheticExample(input_text=f"q{i}", id=i) for i in range(4)]
|
examples = [SyntheticExample(input_text=f"q{i}", id=i) for i in range(4)]
|
||||||
@@ -114,7 +119,7 @@ class TestPromptEvaluatorEvaluate:
|
|||||||
(0.1 * i, f"fb{i}") for i in range(4)
|
(0.1 * i, f"fb{i}") for i in range(4)
|
||||||
]
|
]
|
||||||
|
|
||||||
result = evaluator.evaluate(prompt, examples, "task")
|
result = await evaluator.evaluate(prompt, examples, "task")
|
||||||
|
|
||||||
assert len(result.scores) == 4
|
assert len(result.scores) == 4
|
||||||
assert len(result.feedbacks) == 4
|
assert len(result.feedbacks) == 4
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
"""Unit tests for the evolution loop — with full mocking."""
|
"""Unit tests for the evolution loop — with full mocking."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||||
from prometheus.application.evaluator import PromptEvaluator
|
from prometheus.application.evaluator import PromptEvaluator
|
||||||
@@ -10,14 +12,15 @@ from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Tra
|
|||||||
|
|
||||||
|
|
||||||
class TestEvolutionLoop:
|
class TestEvolutionLoop:
|
||||||
def test_accepts_improvement(
|
@pytest.mark.asyncio
|
||||||
|
async def test_accepts_improvement(
|
||||||
self,
|
self,
|
||||||
seed_prompt: Prompt,
|
seed_prompt: Prompt,
|
||||||
synthetic_pool: list[SyntheticExample],
|
synthetic_pool: list[SyntheticExample],
|
||||||
task_description: str,
|
task_description: str,
|
||||||
mock_llm_port: MagicMock,
|
mock_llm_port: AsyncMock,
|
||||||
mock_judge_port: MagicMock,
|
mock_judge_port: AsyncMock,
|
||||||
mock_proposer_port: MagicMock,
|
mock_proposer_port: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""When the new prompt improves the score, the best candidate is updated."""
|
"""When the new prompt improves the score, the best candidate is updated."""
|
||||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
@@ -45,7 +48,7 @@ class TestEvolutionLoop:
|
|||||||
feedbacks=["good"] * 5,
|
feedbacks=["good"] * 5,
|
||||||
trajectories=[],
|
trajectories=[],
|
||||||
)
|
)
|
||||||
evaluator.evaluate = MagicMock(side_effect=[initial_eval, old_eval, new_eval])
|
evaluator.evaluate = AsyncMock(side_effect=[initial_eval, old_eval, new_eval])
|
||||||
|
|
||||||
loop = EvolutionLoop(
|
loop = EvolutionLoop(
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
@@ -55,19 +58,20 @@ class TestEvolutionLoop:
|
|||||||
minibatch_size=5,
|
minibatch_size=5,
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
with patch.object(loop, "_log"):
|
||||||
state = loop.run(seed_prompt, synthetic_pool, task_description)
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
assert state.best_candidate is not None
|
assert state.best_candidate is not None
|
||||||
assert state.best_candidate.best_score > 0
|
assert state.best_candidate.best_score > 0
|
||||||
|
|
||||||
def test_rejects_regression(
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_regression(
|
||||||
self,
|
self,
|
||||||
seed_prompt: Prompt,
|
seed_prompt: Prompt,
|
||||||
synthetic_pool: list[SyntheticExample],
|
synthetic_pool: list[SyntheticExample],
|
||||||
task_description: str,
|
task_description: str,
|
||||||
mock_llm_port: MagicMock,
|
mock_llm_port: AsyncMock,
|
||||||
mock_judge_port: MagicMock,
|
mock_judge_port: AsyncMock,
|
||||||
mock_proposer_port: MagicMock,
|
mock_proposer_port: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""When the new prompt degrades the score, the best candidate stays unchanged."""
|
"""When the new prompt degrades the score, the best candidate stays unchanged."""
|
||||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
@@ -95,7 +99,7 @@ class TestEvolutionLoop:
|
|||||||
feedbacks=["bad"] * 5,
|
feedbacks=["bad"] * 5,
|
||||||
trajectories=[],
|
trajectories=[],
|
||||||
)
|
)
|
||||||
evaluator.evaluate = MagicMock(side_effect=[initial_eval, old_eval, new_eval])
|
evaluator.evaluate = AsyncMock(side_effect=[initial_eval, old_eval, new_eval])
|
||||||
|
|
||||||
loop = EvolutionLoop(
|
loop = EvolutionLoop(
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
@@ -105,19 +109,20 @@ class TestEvolutionLoop:
|
|||||||
minibatch_size=5,
|
minibatch_size=5,
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
with patch.object(loop, "_log"):
|
||||||
state = loop.run(seed_prompt, synthetic_pool, task_description)
|
state = await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
assert state.best_candidate is not None
|
assert state.best_candidate is not None
|
||||||
assert state.best_candidate.prompt.text == seed_prompt.text
|
assert state.best_candidate.prompt.text == seed_prompt.text
|
||||||
|
|
||||||
def test_skips_perfect_scores(
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_perfect_scores(
|
||||||
self,
|
self,
|
||||||
seed_prompt: Prompt,
|
seed_prompt: Prompt,
|
||||||
synthetic_pool: list[SyntheticExample],
|
synthetic_pool: list[SyntheticExample],
|
||||||
task_description: str,
|
task_description: str,
|
||||||
mock_llm_port: MagicMock,
|
mock_llm_port: AsyncMock,
|
||||||
mock_judge_port: MagicMock,
|
mock_judge_port: AsyncMock,
|
||||||
mock_proposer_port: MagicMock,
|
mock_proposer_port: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""When all scores are perfect, no proposition is made."""
|
"""When all scores are perfect, no proposition is made."""
|
||||||
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
evaluator = PromptEvaluator(mock_llm_port, mock_judge_port)
|
||||||
@@ -132,7 +137,7 @@ class TestEvolutionLoop:
|
|||||||
for i in range(5)
|
for i in range(5)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
evaluator.evaluate = MagicMock(return_value=perfect_eval)
|
evaluator.evaluate = AsyncMock(return_value=perfect_eval)
|
||||||
|
|
||||||
loop = EvolutionLoop(
|
loop = EvolutionLoop(
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
@@ -142,6 +147,6 @@ class TestEvolutionLoop:
|
|||||||
minibatch_size=5,
|
minibatch_size=5,
|
||||||
)
|
)
|
||||||
with patch.object(loop, "_log"):
|
with patch.object(loop, "_log"):
|
||||||
loop.run(seed_prompt, synthetic_pool, task_description)
|
await loop.run(seed_prompt, synthetic_pool, task_description)
|
||||||
|
|
||||||
mock_proposer_port.propose.assert_not_called()
|
mock_proposer_port.propose.assert_not_called()
|
||||||
|
|||||||
Reference in New Issue
Block a user