feat: error handling, retry with backoff, and circuit breaker
Add robust error handling to the evolution loop and LLM adapters:

- Retry utility with exponential backoff for transient errors (429, 5xx, timeouts)
- Per-call error isolation in the evaluator and the judge adapter
- Circuit breaker in EvolutionLoop (trips after N consecutive failures)
- CLI flags: --max-retries, --error-strategy (skip|retry|abort)
- Config fields: max_retries, retry_delay_base, circuit_breaker_threshold, error_strategy
- 16 new unit tests covering all error-handling paths

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -38,6 +38,12 @@ class OptimizationConfig:
|
||||
# --- Reproducibility ---
|
||||
seed: int = 42
|
||||
|
||||
# --- Error handling ---
|
||||
max_retries: int = 3
|
||||
retry_delay_base: float = 1.0
|
||||
circuit_breaker_threshold: int = 5
|
||||
error_strategy: str = "retry" # skip | retry | abort
|
||||
|
||||
# --- Output ---
|
||||
output_path: str = "output.yaml"
|
||||
verbose: bool = False
|
||||
|
||||
@@ -6,6 +6,8 @@ Combines candidate prompt execution + LLM-as-Judge evaluation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from prometheus.domain.entities import (
|
||||
EvalResult,
|
||||
Prompt,
|
||||
@@ -14,6 +16,8 @@ from prometheus.domain.entities import (
|
||||
)
|
||||
from prometheus.domain.ports import JudgePort, LLMPort
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptEvaluator:
|
||||
"""Evaluates a prompt on a minibatch of synthetic inputs.
|
||||
@@ -21,6 +25,9 @@ class PromptEvaluator:
|
||||
Pipeline: execute → judge → build trajectories.
|
||||
Replaces GEPA's EvaluatorFn. Instead of comparing to ground truth,
|
||||
uses an LLM-as-Judge.
|
||||
|
||||
Per-call isolation: a failure on one minibatch item produces a
|
||||
zero-score trajectory instead of crashing the whole batch.
|
||||
"""
|
||||
|
||||
def __init__(self, executor: LLMPort, judge: JudgePort):
|
||||
@@ -40,13 +47,21 @@ class PromptEvaluator:
|
||||
2. Judge each (input, output) pair
|
||||
3. Build trajectories with feedback
|
||||
"""
|
||||
# Step 1: Execution
|
||||
# Step 1: Execution (per-item isolation)
|
||||
outputs: list[str] = []
|
||||
for example in minibatch:
|
||||
raw_output = self._executor.execute(prompt, example.input_text)
|
||||
outputs.append(raw_output)
|
||||
try:
|
||||
raw_output = self._executor.execute(prompt, example.input_text)
|
||||
outputs.append(raw_output)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Execution failed for input '%s…': %s",
|
||||
example.input_text[:40],
|
||||
exc,
|
||||
)
|
||||
outputs.append(f"[execution error: {exc}]")
|
||||
|
||||
# Step 2: Judgement
|
||||
# Step 2: Judgement (judge_adapter handles its own per-call isolation)
|
||||
pairs = [(ex.input_text, out) for ex, out in zip(minibatch, outputs)]
|
||||
judge_results = self._judge.judge_batch(task_description, pairs)
|
||||
|
||||
|
||||
@@ -22,6 +22,10 @@ from prometheus.domain.scoring import should_accept
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitBreakerOpen(Exception):
    """Signals that the circuit breaker tripped after too many consecutive failures."""
|
||||
|
||||
|
||||
class EvolutionLoop:
|
||||
"""Main evolution loop.
|
||||
|
||||
@@ -29,6 +33,11 @@ class EvolutionLoop:
|
||||
- Keeps only the best candidate (no full population).
|
||||
- Simplifies vs GEPA (no Pareto, no merge).
|
||||
- Population support deferred to v2.
|
||||
|
||||
Error handling:
|
||||
- Transient errors are retried by adapters.
|
||||
- Circuit breaker trips after N consecutive iteration failures.
|
||||
- error_strategy controls what happens on non-transient errors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -40,6 +49,8 @@ class EvolutionLoop:
|
||||
minibatch_size: int = 5,
|
||||
perfect_score: float = 1.0,
|
||||
verbose: bool = False,
|
||||
circuit_breaker_threshold: int = 5,
|
||||
error_strategy: str = "retry",
|
||||
):
|
||||
self._evaluator = evaluator
|
||||
self._proposer = proposer
|
||||
@@ -48,6 +59,8 @@ class EvolutionLoop:
|
||||
self._minibatch_size = minibatch_size
|
||||
self._perfect_score = perfect_score
|
||||
self._verbose = verbose
|
||||
self._circuit_breaker_threshold = circuit_breaker_threshold
|
||||
self._error_strategy = error_strategy
|
||||
|
||||
def run(
|
||||
self,
|
||||
@@ -57,6 +70,7 @@ class EvolutionLoop:
|
||||
) -> OptimizationState:
|
||||
"""Execute the complete evolution loop."""
|
||||
state = OptimizationState()
|
||||
consecutive_failures = 0
|
||||
|
||||
# Evaluate the seed
|
||||
initial_batch = self._bootstrap.sample_minibatch(
|
||||
@@ -81,94 +95,139 @@ class EvolutionLoop:
|
||||
state.iteration = i
|
||||
|
||||
try:
|
||||
# 1. Sample a fresh minibatch
|
||||
batch = self._bootstrap.sample_minibatch(
|
||||
synthetic_pool, self._minibatch_size
|
||||
self._run_iteration(
|
||||
i, state, best_candidate, synthetic_pool, task_description
|
||||
)
|
||||
|
||||
# 2. Evaluate the current candidate
|
||||
current_eval = self._evaluator.evaluate(
|
||||
best_candidate.prompt, batch, task_description
|
||||
)
|
||||
state.total_llm_calls += 2 * self._minibatch_size
|
||||
|
||||
# 3. Skip if perfect
|
||||
if all(s >= self._perfect_score for s in current_eval.scores):
|
||||
self._log(f"Iter {i}: All scores perfect, skipping.")
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "skip_perfect",
|
||||
"current_score": current_eval.total_score,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 4. Propose a new prompt (reflective mutation)
|
||||
new_prompt = self._proposer.propose(
|
||||
best_candidate.prompt,
|
||||
current_eval.trajectories,
|
||||
task_description,
|
||||
)
|
||||
state.total_llm_calls += 1 # 1 proposition call
|
||||
|
||||
# 5. Evaluate the new prompt on the same minibatch
|
||||
new_eval = self._evaluator.evaluate(
|
||||
new_prompt, batch, task_description
|
||||
)
|
||||
state.total_llm_calls += 2 * self._minibatch_size
|
||||
|
||||
# 6. Accept or reject
|
||||
if should_accept(current_eval, new_eval):
|
||||
best_candidate = Candidate(
|
||||
prompt=new_prompt,
|
||||
best_score=new_eval.total_score,
|
||||
generation=i,
|
||||
parent_id=id(best_candidate),
|
||||
)
|
||||
state.best_candidate = best_candidate
|
||||
state.candidates.append(best_candidate)
|
||||
self._log(
|
||||
f"Iter {i}: ACCEPTED "
|
||||
f"({current_eval.total_score:.2f} -> {new_eval.total_score:.2f})"
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "accepted",
|
||||
"old_score": current_eval.total_score,
|
||||
"new_score": new_eval.total_score,
|
||||
"improvement": new_eval.total_score
|
||||
- current_eval.total_score,
|
||||
}
|
||||
)
|
||||
else:
|
||||
self._log(
|
||||
f"Iter {i}: REJECTED "
|
||||
f"({new_eval.total_score:.2f} <= {current_eval.total_score:.2f})"
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "rejected",
|
||||
"old_score": current_eval.total_score,
|
||||
"new_score": new_eval.total_score,
|
||||
}
|
||||
)
|
||||
# Update best_candidate from state after successful iteration
|
||||
best_candidate = state.best_candidate # type: ignore[assignment]
|
||||
consecutive_failures = 0
|
||||
|
||||
except Exception as exc:
|
||||
self._log(f"Iter {i}: ERROR — {exc}. Skipping iteration.")
|
||||
consecutive_failures += 1
|
||||
self._log(
|
||||
f"Iter {i}: ERROR ({consecutive_failures} consecutive) — {exc}"
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "error",
|
||||
"error": str(exc),
|
||||
"consecutive_failures": consecutive_failures,
|
||||
}
|
||||
)
|
||||
|
||||
# Check circuit breaker
|
||||
if consecutive_failures >= self._circuit_breaker_threshold:
|
||||
self._log(
|
||||
f"Circuit breaker tripped after {consecutive_failures} "
|
||||
f"consecutive failures."
|
||||
)
|
||||
state.history.append(
|
||||
{
|
||||
"iteration": i,
|
||||
"event": "circuit_breaker",
|
||||
"consecutive_failures": consecutive_failures,
|
||||
}
|
||||
)
|
||||
if self._error_strategy == "abort":
|
||||
raise CircuitBreakerOpen(
|
||||
f"Circuit breaker tripped after "
|
||||
f"{consecutive_failures} consecutive failures"
|
||||
) from exc
|
||||
# skip / retry strategies: stop the loop gracefully
|
||||
break
|
||||
|
||||
if self._error_strategy == "abort":
|
||||
raise
|
||||
# skip / retry: continue to next iteration
|
||||
continue
|
||||
|
||||
return state
|
||||
|
||||
def _run_iteration(
    self,
    i: int,
    state: OptimizationState,
    best_candidate: Candidate,
    synthetic_pool: list[SyntheticExample],
    task_description: str,
) -> None:
    """Run one propose-and-compare step and record its outcome on *state*.

    The current best prompt is scored on a freshly sampled minibatch. Unless
    every score already reaches the perfect score, the proposer is asked for
    a new prompt, which is scored on the same minibatch and then accepted or
    rejected via ``should_accept``.

    Nothing is returned. ``state.total_llm_calls`` and ``state.history`` are
    updated in place; on acceptance ``state.best_candidate`` and
    ``state.candidates`` are updated too. Exceptions raised by the
    collaborators are not caught here.
    """
    # Each evaluation is counted as two LLM calls per minibatch slot
    # (presumably one execution plus one judge call -- confirm with evaluator).
    calls_per_eval = 2 * self._minibatch_size

    # 1. Fresh minibatch for this iteration.
    batch = self._bootstrap.sample_minibatch(synthetic_pool, self._minibatch_size)

    # 2. Score the incumbent prompt.
    current_eval = self._evaluator.evaluate(
        best_candidate.prompt, batch, task_description
    )
    state.total_llm_calls += calls_per_eval

    # 3. Nothing to improve on this batch: record and stop early.
    if all(score >= self._perfect_score for score in current_eval.scores):
        self._log(f"Iter {i}: All scores perfect, skipping.")
        state.history.append(
            {
                "iteration": i,
                "event": "skip_perfect",
                "current_score": current_eval.total_score,
            }
        )
        return

    # 4. Ask the proposer for a mutated prompt (a single LLM call).
    new_prompt = self._proposer.propose(
        best_candidate.prompt,
        current_eval.trajectories,
        task_description,
    )
    state.total_llm_calls += 1

    # 5. Score the challenger on the very same minibatch.
    new_eval = self._evaluator.evaluate(new_prompt, batch, task_description)
    state.total_llm_calls += calls_per_eval

    # 6a. Rejection path.
    if not should_accept(current_eval, new_eval):
        self._log(
            f"Iter {i}: REJECTED ({new_eval.total_score:.2f} <= {current_eval.total_score:.2f})"
        )
        state.history.append(
            {
                "iteration": i,
                "event": "rejected",
                "old_score": current_eval.total_score,
                "new_score": new_eval.total_score,
            }
        )
        return

    # 6b. Acceptance path: promote the challenger.
    promoted = Candidate(
        prompt=new_prompt,
        best_score=new_eval.total_score,
        generation=i,
        parent_id=id(best_candidate),
    )
    state.best_candidate = promoted
    state.candidates.append(promoted)
    self._log(
        f"Iter {i}: ACCEPTED ({current_eval.total_score:.2f} -> {new_eval.total_score:.2f})"
    )
    state.history.append(
        {
            "iteration": i,
            "event": "accepted",
            "old_score": current_eval.total_score,
            "new_score": new_eval.total_score,
            "improvement": new_eval.total_score - current_eval.total_score,
        }
    )
|
||||
|
||||
def _log(self, msg: str) -> None:
    """Log *msg* at INFO level with a ``[PROMETHEUS]`` prefix, only in verbose mode."""
    if not self._verbose:
        return
    logger.info("[PROMETHEUS] %s", msg)
|
||||
|
||||
@@ -51,6 +51,8 @@ class OptimizePromptUseCase:
|
||||
minibatch_size=config.minibatch_size,
|
||||
perfect_score=config.perfect_score,
|
||||
verbose=config.verbose,
|
||||
circuit_breaker_threshold=config.circuit_breaker_threshold,
|
||||
error_strategy=config.error_strategy,
|
||||
)
|
||||
seed_prompt = Prompt(text=config.seed_prompt)
|
||||
state = loop.run(seed_prompt, synthetic_pool, config.task_description)
|
||||
|
||||
@@ -56,6 +56,16 @@ def optimize(
|
||||
"--verbose",
|
||||
help="Print detailed progress.",
|
||||
),
|
||||
max_retries: int = typer.Option(
|
||||
3,
|
||||
"--max-retries",
|
||||
help="Max retry attempts for transient LLM errors (429, timeout, 5xx).",
|
||||
),
|
||||
error_strategy: str = typer.Option(
|
||||
"retry",
|
||||
"--error-strategy",
|
||||
help="How to handle errors: skip | retry | abort.",
|
||||
),
|
||||
) -> None:
|
||||
"""Optimize a prompt without any reference data.
|
||||
|
||||
@@ -115,6 +125,10 @@ def optimize(
|
||||
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
|
||||
minibatch_size=raw_config.get("minibatch_size", 5),
|
||||
seed=raw_config.get("seed", 42),
|
||||
max_retries=raw_config.get("max_retries", max_retries),
|
||||
retry_delay_base=raw_config.get("retry_delay_base", 1.0),
|
||||
circuit_breaker_threshold=raw_config.get("circuit_breaker_threshold", 5),
|
||||
error_strategy=raw_config.get("error_strategy", error_strategy),
|
||||
output_path=output,
|
||||
verbose=verbose,
|
||||
)
|
||||
@@ -139,11 +153,23 @@ def optimize(
|
||||
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env),
|
||||
)
|
||||
|
||||
# 3. Build adapters (Dependency Injection — each gets its own LM)
|
||||
# 3. Build adapters (Dependency Injection — each gets its own LM + retry config)
|
||||
synth_adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||
llm_adapter = DSPyLLMAdapter(lm=task_lm)
|
||||
judge_adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||
proposer_adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||
llm_adapter = DSPyLLMAdapter(
|
||||
lm=task_lm,
|
||||
max_retries=config.max_retries,
|
||||
retry_delay_base=config.retry_delay_base,
|
||||
)
|
||||
judge_adapter = DSPyJudgeAdapter(
|
||||
lm=judge_lm,
|
||||
max_retries=config.max_retries,
|
||||
retry_delay_base=config.retry_delay_base,
|
||||
)
|
||||
proposer_adapter = DSPyProposerAdapter(
|
||||
lm=proposer_lm,
|
||||
max_retries=config.max_retries,
|
||||
retry_delay_base=config.retry_delay_base,
|
||||
)
|
||||
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
||||
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter)
|
||||
use_case = OptimizePromptUseCase(
|
||||
|
||||
@@ -5,21 +5,34 @@ Implements the JudgePort via the DSPy OutputJudge module.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import dspy
|
||||
|
||||
from prometheus.domain.ports import JudgePort
|
||||
from prometheus.infrastructure.dspy_modules import OutputJudge
|
||||
from prometheus.infrastructure.retry import retry_with_backoff
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DSPyJudgeAdapter(JudgePort):
|
||||
"""Evaluates a batch of (input, output) pairs by calling the Judge for each.
|
||||
|
||||
Sequential for MVP. Future: parallelize via dspy.Parallel.
|
||||
Per-call isolation: a failure on one item returns a zero-score sentinel
|
||||
instead of crashing the whole batch.
|
||||
"""
|
||||
|
||||
def __init__(self, lm: dspy.LM) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
lm: dspy.LM,
|
||||
max_retries: int = 3,
|
||||
retry_delay_base: float = 1.0,
|
||||
) -> None:
|
||||
self._lm = lm
|
||||
self._judge = OutputJudge()
|
||||
self._max_retries = max_retries
|
||||
self._retry_delay_base = retry_delay_base
|
||||
|
||||
def judge_batch(
|
||||
self,
|
||||
@@ -29,10 +42,26 @@ class DSPyJudgeAdapter(JudgePort):
|
||||
results: list[tuple[float, str]] = []
|
||||
with dspy.context(lm=self._lm):
|
||||
for input_text, output_text in pairs:
|
||||
pred = self._judge(
|
||||
results.append(self._judge_single(task_description, input_text, output_text))
|
||||
return results
|
||||
|
||||
def _judge_single(
|
||||
self,
|
||||
task_description: str,
|
||||
input_text: str,
|
||||
output_text: str,
|
||||
) -> tuple[float, str]:
|
||||
try:
|
||||
pred = retry_with_backoff(
|
||||
lambda: self._judge(
|
||||
task_description=task_description,
|
||||
input_text=input_text,
|
||||
output_text=output_text,
|
||||
)
|
||||
results.append((pred.score, pred.feedback))
|
||||
return results
|
||||
),
|
||||
max_retries=self._max_retries,
|
||||
retry_delay_base=self._retry_delay_base,
|
||||
)
|
||||
return (pred.score, pred.feedback)
|
||||
except Exception as exc:
|
||||
logger.warning("Judge call failed for input '%s…': %s", input_text[:40], exc)
|
||||
return (0.0, f"[judge error: {exc}]")
|
||||
|
||||
@@ -9,6 +9,7 @@ import dspy
|
||||
|
||||
from prometheus.domain.entities import Prompt
|
||||
from prometheus.domain.ports import LLMPort
|
||||
from prometheus.infrastructure.retry import retry_with_backoff
|
||||
|
||||
|
||||
class DSPyLLMAdapter(LLMPort):
|
||||
@@ -21,14 +22,28 @@ class DSPyLLMAdapter(LLMPort):
|
||||
input_text: str = dspy.InputField(desc="The input to process.")
|
||||
output: str = dspy.OutputField(desc="The response following the instruction.")
|
||||
|
||||
def __init__(self, lm: dspy.LM) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
lm: dspy.LM,
|
||||
max_retries: int = 3,
|
||||
retry_delay_base: float = 1.0,
|
||||
) -> None:
|
||||
self._lm = lm
|
||||
self._predictor = dspy.Predict(self._ExecuteSignature)
|
||||
self._max_retries = max_retries
|
||||
self._retry_delay_base = retry_delay_base
|
||||
|
||||
def execute(self, prompt: Prompt, input_text: str) -> str:
|
||||
with dspy.context(lm=self._lm):
|
||||
result = self._predictor(
|
||||
instruction=prompt.text,
|
||||
input_text=input_text,
|
||||
)
|
||||
return str(result.output)
|
||||
def _call() -> str:
|
||||
with dspy.context(lm=self._lm):
|
||||
result = self._predictor(
|
||||
instruction=prompt.text,
|
||||
input_text=input_text,
|
||||
)
|
||||
return str(result.output)
|
||||
|
||||
return retry_with_backoff(
|
||||
_call,
|
||||
max_retries=self._max_retries,
|
||||
retry_delay_base=self._retry_delay_base,
|
||||
)
|
||||
|
||||
@@ -11,14 +11,22 @@ import dspy
|
||||
from prometheus.domain.entities import Prompt, Trajectory
|
||||
from prometheus.domain.ports import ProposerPort
|
||||
from prometheus.infrastructure.dspy_modules import InstructionProposer
|
||||
from prometheus.infrastructure.retry import retry_with_backoff
|
||||
|
||||
|
||||
class DSPyProposerAdapter(ProposerPort):
|
||||
"""Uses evaluation trajectories to build a failure report and propose a new prompt."""
|
||||
|
||||
def __init__(self, lm: dspy.LM) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
lm: dspy.LM,
|
||||
max_retries: int = 3,
|
||||
retry_delay_base: float = 1.0,
|
||||
) -> None:
|
||||
self._lm = lm
|
||||
self._proposer = InstructionProposer()
|
||||
self._max_retries = max_retries
|
||||
self._retry_delay_base = retry_delay_base
|
||||
|
||||
def propose(
|
||||
self,
|
||||
@@ -27,13 +35,21 @@ class DSPyProposerAdapter(ProposerPort):
|
||||
task_description: str,
|
||||
) -> Prompt:
|
||||
failure_examples = self._format_failures(trajectories)
|
||||
with dspy.context(lm=self._lm):
|
||||
pred = self._proposer(
|
||||
current_instruction=current_prompt.text,
|
||||
task_description=task_description,
|
||||
failure_examples=failure_examples,
|
||||
)
|
||||
return Prompt(text=pred.new_instruction)
|
||||
|
||||
def _call() -> Prompt:
|
||||
with dspy.context(lm=self._lm):
|
||||
pred = self._proposer(
|
||||
current_instruction=current_prompt.text,
|
||||
task_description=task_description,
|
||||
failure_examples=failure_examples,
|
||||
)
|
||||
return Prompt(text=pred.new_instruction)
|
||||
|
||||
return retry_with_backoff(
|
||||
_call,
|
||||
max_retries=self._max_retries,
|
||||
retry_delay_base=self._retry_delay_base,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_failures(trajectories: list[Trajectory]) -> str:
|
||||
|
||||
73
src/prometheus/infrastructure/retry.py
Normal file
73
src/prometheus/infrastructure/retry.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Retry with exponential backoff for transient LLM errors."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Status codes / keywords that indicate a transient error worth retrying.
|
||||
_TRANSIENT_PATTERNS = (
|
||||
"429",
|
||||
"rate limit",
|
||||
"rate_limit",
|
||||
"too many requests",
|
||||
"500",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
"timeout",
|
||||
"timed out",
|
||||
"connection error",
|
||||
"connection refused",
|
||||
"overloaded",
|
||||
)
|
||||
|
||||
|
||||
def is_transient_error(exc: Exception) -> bool:
|
||||
"""Return True if the exception looks like a transient LLM/API error."""
|
||||
msg = str(exc).lower()
|
||||
if any(p in msg for p in _TRANSIENT_PATTERNS):
|
||||
return True
|
||||
if isinstance(exc, (ConnectionError, TimeoutError, OSError)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class TransientError(RuntimeError):
    """Error for a transient failure that was still unresolved once retrying ended."""
|
||||
|
||||
|
||||
def retry_with_backoff(
    fn: Callable[..., T],
    *args: Any,
    max_retries: int = 3,
    retry_delay_base: float = 1.0,
    **kwargs: Any,
) -> T:
    """Call *fn* with exponential-backoff retry on transient errors.

    Delay per attempt: ``retry_delay_base * 2 ** attempt`` seconds, where
    ``attempt`` starts at 0 for the wait after the first failure.

    Args:
        fn: Callable to invoke as ``fn(*args, **kwargs)``.
        *args: Positional arguments forwarded to *fn*.
        max_retries: Retries allowed after the first call. Negative values
            are treated as 0 so that *fn* is always called at least once.
        retry_delay_base: Base delay in seconds. Negative values are treated
            as 0 because ``time.sleep`` rejects negative durations.
        **kwargs: Keyword arguments forwarded to *fn*.

    Returns:
        The value returned by the first successful call of *fn*.

    Raises:
        Exception: Whatever *fn* raised, re-raised unchanged, when the error
            is not transient (see ``is_transient_error``) or when the last
            allowed attempt fails.
    """
    retries = max(0, max_retries)
    base_delay = max(0.0, retry_delay_base)

    last_exc: Exception | None = None
    for attempt in range(retries + 1):
        try:
            return fn(*args, **kwargs)
        except Exception as exc:
            last_exc = exc
            # Give up immediately on the final attempt or on errors that
            # retrying cannot fix; the caller sees the original exception.
            if attempt == retries or not is_transient_error(exc):
                raise
            delay = base_delay * (2 ** attempt)
            logger.warning(
                "Transient error (attempt %d/%d): %s — retrying in %.1fs",
                attempt + 1,
                retries + 1,
                exc,
                delay,
            )
            time.sleep(delay)
    # Should not reach here (the loop always returns or raises), but satisfy
    # the type-checker.
    raise TransientError(str(last_exc)) from last_exc
|
||||
302
tests/unit/test_error_handling.py
Normal file
302
tests/unit/test_error_handling.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""Unit tests for error handling: retry, circuit breaker, per-call isolation."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||
from prometheus.application.evaluator import PromptEvaluator
|
||||
from prometheus.application.evolution import CircuitBreakerOpen, EvolutionLoop
|
||||
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory
|
||||
from prometheus.infrastructure.retry import is_transient_error, retry_with_backoff
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Retry utility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsTransientError:
    """Classification checks for ``is_transient_error``."""

    def test_rate_limit_429(self):
        """A 429 / rate-limit message is transient."""
        err = RuntimeError("HTTP 429: rate limit exceeded")
        assert is_transient_error(err)

    def test_server_error_503(self):
        """A 503 message is transient."""
        err = RuntimeError("503 Service Unavailable")
        assert is_transient_error(err)

    def test_timeout(self):
        """A TimeoutError is transient."""
        err = TimeoutError("request timed out")
        assert is_transient_error(err)

    def test_connection_error(self):
        """A ConnectionError is transient."""
        err = ConnectionError("connection refused")
        assert is_transient_error(err)

    def test_non_transient(self):
        """A plain ValueError with an unrelated message is not transient."""
        err = ValueError("bad input")
        assert not is_transient_error(err)

    def test_os_error(self):
        """A generic OSError is transient."""
        err = OSError("network unreachable")
        assert is_transient_error(err)
|
||||
|
||||
|
||||
class TestRetryWithBackoff:
    """Checks for ``retry_with_backoff``: call counts, propagation and delays."""

    # Patch target that keeps the tests from actually sleeping.
    _SLEEP = "prometheus.infrastructure.retry.time.sleep"

    def test_succeeds_on_first_try(self):
        """A callable that succeeds is invoked exactly once."""
        target = MagicMock(return_value="ok")
        outcome = retry_with_backoff(target, max_retries=3, retry_delay_base=0)
        assert outcome == "ok"
        assert target.call_count == 1

    def test_retries_on_transient_then_succeeds(self):
        """Two transient failures followed by a success yield the success."""
        failures = [RuntimeError("429 rate limit") for _ in range(2)]
        target = MagicMock(side_effect=[*failures, "ok"])
        with patch(self._SLEEP):
            outcome = retry_with_backoff(target, max_retries=3, retry_delay_base=0)
        assert outcome == "ok"
        assert target.call_count == 3

    def test_raises_after_max_retries(self):
        """The original error propagates once the retry budget is used up."""
        target = MagicMock(side_effect=RuntimeError("503 overloaded"))
        with patch(self._SLEEP), pytest.raises(RuntimeError, match="503"):
            retry_with_backoff(target, max_retries=2, retry_delay_base=0)
        assert target.call_count == 3  # 1 initial + 2 retries

    def test_non_transient_not_retried(self):
        """A non-transient error propagates after a single call."""
        target = MagicMock(side_effect=ValueError("bad"))
        with pytest.raises(ValueError, match="bad"):
            retry_with_backoff(target, max_retries=3, retry_delay_base=0)
        assert target.call_count == 1

    def test_exponential_backoff_delays(self):
        """The first wait equals the base delay (base * 2**0)."""
        target = MagicMock(side_effect=[RuntimeError("timeout"), "ok"])
        with patch(self._SLEEP) as sleep_mock:
            retry_with_backoff(target, max_retries=3, retry_delay_base=2.0)
        sleep_mock.assert_called_once_with(2.0)  # 2.0 * 2^0 = 2.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Circuit breaker (EvolutionLoop)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_eval_result(scores, feedbacks=None):
    """Build an EvalResult whose trajectories pair up *scores* and *feedbacks*.

    When *feedbacks* is missing or empty, every score gets the feedback "ok".
    """
    if not feedbacks:
        feedbacks = ["ok"] * len(scores)

    trajectories = []
    for idx, (score, feedback) in enumerate(zip(scores, feedbacks)):
        trajectories.append(
            Trajectory(f"input{idx}", f"output{idx}", score, feedback, "prompt")
        )

    return EvalResult(scores=scores, feedbacks=feedbacks, trajectories=trajectories)
|
||||
|
||||
|
||||
class TestCircuitBreaker:
    """Behaviour of EvolutionLoop's consecutive-failure circuit breaker."""

    @staticmethod
    def _build_loop(evaluate_stub, *, max_iterations, error_strategy, proposer=None, bootstrap=None):
        """Create an EvolutionLoop around mocks, with threshold 3 and minibatch size 2."""
        evaluator = MagicMock()
        evaluator.evaluate.side_effect = evaluate_stub
        if proposer is None:
            proposer = MagicMock()
        if bootstrap is None:
            bootstrap = MagicMock(spec=SyntheticBootstrap)
        return EvolutionLoop(
            evaluator=evaluator,
            proposer=proposer,
            bootstrap=bootstrap,
            max_iterations=max_iterations,
            minibatch_size=2,
            circuit_breaker_threshold=3,
            error_strategy=error_strategy,
        )

    @staticmethod
    def _pool():
        """Two-item synthetic pool shared by every test."""
        return [SyntheticExample("in", id=0), SyntheticExample("in2", id=1)]

    @staticmethod
    def _events(state, name):
        """Return the history entries of *state* whose event equals *name*."""
        return [entry for entry in state.history if entry.get("event") == name]

    @staticmethod
    def _seed_then_down(seed_eval):
        """Return an evaluate stub: first call yields *seed_eval*, later calls raise."""
        seen = 0

        def _evaluate(*args, **kwargs):
            nonlocal seen
            seen += 1
            if seen == 1:
                return seed_eval  # seed eval succeeds
            raise RuntimeError("LLM down")

        return _evaluate

    def test_trips_on_consecutive_failures(self):
        """Loop stops when consecutive failures reach the threshold."""
        loop = self._build_loop(
            self._seed_then_down(_make_eval_result([0.3, 0.4])),
            max_iterations=10,
            error_strategy="skip",
        )
        with patch.object(loop, "_log"):
            state = loop.run(Prompt("test"), self._pool(), "task")

        assert len(self._events(state, "error")) == 3
        assert len(self._events(state, "circuit_breaker")) == 1
        assert state.iteration < 10  # stopped early

    def test_abort_raises_on_first_error(self):
        """With error_strategy=abort, the first error raises immediately."""
        loop = self._build_loop(
            self._seed_then_down(_make_eval_result([0.3, 0.4])),
            max_iterations=10,
            error_strategy="abort",
        )
        with patch.object(loop, "_log"):
            with pytest.raises(RuntimeError, match="LLM down"):
                loop.run(Prompt("test"), self._pool(), "task")

    def test_resets_on_success(self):
        """Consecutive failure counter resets after a successful iteration."""
        initial_eval = _make_eval_result([0.3, 0.4])
        good_eval = _make_eval_result([0.8, 0.9])
        seen = 0

        def _evaluate(*args, **kwargs):
            # call 1: seed eval -> succeed
            # calls 2-3: current eval of iterations 1 and 2 -> fail
            # call 4: iteration 3 current eval -> initial_eval
            # call 5: iteration 3 new eval -> good_eval (accepted)
            # call 6+: keep alternating current eval / new eval
            nonlocal seen
            seen += 1
            if seen == 1:
                return initial_eval
            if seen in (2, 3):
                raise RuntimeError("timeout")
            return initial_eval if seen % 2 == 0 else good_eval

        proposer = MagicMock()
        proposer.propose.return_value = Prompt("better prompt")
        bootstrap = MagicMock(spec=SyntheticBootstrap)
        bootstrap.sample_minibatch.return_value = [
            SyntheticExample(f"in{i}", id=i) for i in range(2)
        ]

        loop = self._build_loop(
            _evaluate,
            max_iterations=6,
            error_strategy="skip",
            proposer=proposer,
            bootstrap=bootstrap,
        )
        with patch.object(loop, "_log"):
            state = loop.run(Prompt("test"), self._pool(), "task")

        # Should NOT have tripped — 2 fails, then success reset the counter
        assert len(self._events(state, "circuit_breaker")) == 0
        assert state.iteration == 6  # ran all iterations
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-call isolation (Evaluator)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPerCallIsolation:
    """Per-item failure isolation in PromptEvaluator and DSPyJudgeAdapter."""

    def test_evaluator_isolates_execution_failure(self):
        """A failing execution produces a sentinel output, not a crash."""
        executor = MagicMock()
        executor.execute.side_effect = [
            "good output",
            RuntimeError("API error"),
            "another good output",
        ]
        judge = MagicMock()
        judge.judge_batch.return_value = [
            (0.8, "good"),
            (0.0, "[judge error]"),
            (0.7, "ok"),
        ]

        evaluator = PromptEvaluator(executor, judge)
        result = evaluator.evaluate(
            Prompt("test"),
            [
                SyntheticExample("in0", id=0),
                SyntheticExample("in1", id=1),
                SyntheticExample("in2", id=2),
            ],
            "task",
        )

        assert len(result.scores) == 3
        assert result.scores[1] == 0.0  # failed item got zero score
        assert "execution error" in result.trajectories[1].output_text
        assert result.scores[0] == 0.8  # other items unaffected

    def test_judge_adapter_isolates_single_failure(self):
        """DSPyJudgeAdapter returns sentinel for a failed item, not crash."""
        from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter

        # Bypass __init__ so no real DSPy modules are constructed.
        adapter = DSPyJudgeAdapter.__new__(DSPyJudgeAdapter)
        adapter._lm = MagicMock()
        adapter._max_retries = 1
        adapter._retry_delay_base = 0

        # Fake judge: raises on its first invocation only, succeeds afterwards.
        call_count = 0

        class FakePred:
            def __init__(self):
                self.score = 0.9
                self.feedback = "good"

        def fake_judge(**kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise RuntimeError("judge failure")
            return FakePred()

        adapter._judge = fake_judge

        with patch("prometheus.infrastructure.judge_adapter.dspy.context"):
            with patch("prometheus.infrastructure.retry.time.sleep"):
                results = adapter.judge_batch(
                    "task", [("input1", "output1"), ("input2", "output2")]
                )

        assert len(results) == 2
        # First item: "judge failure" is not a transient error, so it is NOT
        # retried (a retry would have hit the succeeding second call) and the
        # adapter falls back to the zero-score sentinel.
        assert results[0] == (0.0, "[judge error: judge failure]")
        # Second item succeeded
        assert results[1] == (0.9, "good")
        # One call per item confirms that no retry took place.
        assert call_count == 2
|
||||
Reference in New Issue
Block a user