feat: error handling, retry with backoff, and circuit breaker
Add robust error handling to the evolution loop and LLM adapters: - Retry utility with exponential backoff for transient errors (429, 5xx, timeouts) - Per-call error isolation in evaluator and judge adapter - Circuit breaker in EvolutionLoop (trips after N consecutive failures) - CLI flags: --max-retries, --error-strategy (skip|retry|abort) - Config fields: max_retries, retry_delay_base, circuit_breaker_threshold, error_strategy - 16 new unit tests covering all error handling paths Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -38,6 +38,12 @@ class OptimizationConfig:
|
|||||||
# --- Reproducibility ---
|
# --- Reproducibility ---
|
||||||
seed: int = 42
|
seed: int = 42
|
||||||
|
|
||||||
|
# --- Error handling ---
|
||||||
|
max_retries: int = 3
|
||||||
|
retry_delay_base: float = 1.0
|
||||||
|
circuit_breaker_threshold: int = 5
|
||||||
|
error_strategy: str = "retry" # skip | retry | abort
|
||||||
|
|
||||||
# --- Output ---
|
# --- Output ---
|
||||||
output_path: str = "output.yaml"
|
output_path: str = "output.yaml"
|
||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ Combines candidate prompt execution + LLM-as-Judge evaluation.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from prometheus.domain.entities import (
|
from prometheus.domain.entities import (
|
||||||
EvalResult,
|
EvalResult,
|
||||||
Prompt,
|
Prompt,
|
||||||
@@ -14,6 +16,8 @@ from prometheus.domain.entities import (
|
|||||||
)
|
)
|
||||||
from prometheus.domain.ports import JudgePort, LLMPort
|
from prometheus.domain.ports import JudgePort, LLMPort
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PromptEvaluator:
|
class PromptEvaluator:
|
||||||
"""Evaluates a prompt on a minibatch of synthetic inputs.
|
"""Evaluates a prompt on a minibatch of synthetic inputs.
|
||||||
@@ -21,6 +25,9 @@ class PromptEvaluator:
|
|||||||
Pipeline: execute → judge → build trajectories.
|
Pipeline: execute → judge → build trajectories.
|
||||||
Replaces GEPA's EvaluatorFn. Instead of comparing to ground truth,
|
Replaces GEPA's EvaluatorFn. Instead of comparing to ground truth,
|
||||||
uses an LLM-as-Judge.
|
uses an LLM-as-Judge.
|
||||||
|
|
||||||
|
Per-call isolation: a failure on one minibatch item produces a
|
||||||
|
zero-score trajectory instead of crashing the whole batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, executor: LLMPort, judge: JudgePort):
|
def __init__(self, executor: LLMPort, judge: JudgePort):
|
||||||
@@ -40,13 +47,21 @@ class PromptEvaluator:
|
|||||||
2. Judge each (input, output) pair
|
2. Judge each (input, output) pair
|
||||||
3. Build trajectories with feedback
|
3. Build trajectories with feedback
|
||||||
"""
|
"""
|
||||||
# Step 1: Execution
|
# Step 1: Execution (per-item isolation)
|
||||||
outputs: list[str] = []
|
outputs: list[str] = []
|
||||||
for example in minibatch:
|
for example in minibatch:
|
||||||
|
try:
|
||||||
raw_output = self._executor.execute(prompt, example.input_text)
|
raw_output = self._executor.execute(prompt, example.input_text)
|
||||||
outputs.append(raw_output)
|
outputs.append(raw_output)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Execution failed for input '%s…': %s",
|
||||||
|
example.input_text[:40],
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
outputs.append(f"[execution error: {exc}]")
|
||||||
|
|
||||||
# Step 2: Judgement
|
# Step 2: Judgement (judge_adapter handles its own per-call isolation)
|
||||||
pairs = [(ex.input_text, out) for ex, out in zip(minibatch, outputs)]
|
pairs = [(ex.input_text, out) for ex, out in zip(minibatch, outputs)]
|
||||||
judge_results = self._judge.judge_batch(task_description, pairs)
|
judge_results = self._judge.judge_batch(task_description, pairs)
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ from prometheus.domain.scoring import should_accept
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitBreakerOpen(Exception):
|
||||||
|
"""Raised when the circuit breaker trips due to too many consecutive failures."""
|
||||||
|
|
||||||
|
|
||||||
class EvolutionLoop:
|
class EvolutionLoop:
|
||||||
"""Main evolution loop.
|
"""Main evolution loop.
|
||||||
|
|
||||||
@@ -29,6 +33,11 @@ class EvolutionLoop:
|
|||||||
- Keeps only the best candidate (no full population).
|
- Keeps only the best candidate (no full population).
|
||||||
- Simplifies vs GEPA (no Pareto, no merge).
|
- Simplifies vs GEPA (no Pareto, no merge).
|
||||||
- Population support deferred to v2.
|
- Population support deferred to v2.
|
||||||
|
|
||||||
|
Error handling:
|
||||||
|
- Transient errors are retried by adapters.
|
||||||
|
- Circuit breaker trips after N consecutive iteration failures.
|
||||||
|
- error_strategy controls what happens on non-transient errors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -40,6 +49,8 @@ class EvolutionLoop:
|
|||||||
minibatch_size: int = 5,
|
minibatch_size: int = 5,
|
||||||
perfect_score: float = 1.0,
|
perfect_score: float = 1.0,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
circuit_breaker_threshold: int = 5,
|
||||||
|
error_strategy: str = "retry",
|
||||||
):
|
):
|
||||||
self._evaluator = evaluator
|
self._evaluator = evaluator
|
||||||
self._proposer = proposer
|
self._proposer = proposer
|
||||||
@@ -48,6 +59,8 @@ class EvolutionLoop:
|
|||||||
self._minibatch_size = minibatch_size
|
self._minibatch_size = minibatch_size
|
||||||
self._perfect_score = perfect_score
|
self._perfect_score = perfect_score
|
||||||
self._verbose = verbose
|
self._verbose = verbose
|
||||||
|
self._circuit_breaker_threshold = circuit_breaker_threshold
|
||||||
|
self._error_strategy = error_strategy
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
@@ -57,6 +70,7 @@ class EvolutionLoop:
|
|||||||
) -> OptimizationState:
|
) -> OptimizationState:
|
||||||
"""Execute the complete evolution loop."""
|
"""Execute the complete evolution loop."""
|
||||||
state = OptimizationState()
|
state = OptimizationState()
|
||||||
|
consecutive_failures = 0
|
||||||
|
|
||||||
# Evaluate the seed
|
# Evaluate the seed
|
||||||
initial_batch = self._bootstrap.sample_minibatch(
|
initial_batch = self._bootstrap.sample_minibatch(
|
||||||
@@ -81,6 +95,64 @@ class EvolutionLoop:
|
|||||||
state.iteration = i
|
state.iteration = i
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
self._run_iteration(
|
||||||
|
i, state, best_candidate, synthetic_pool, task_description
|
||||||
|
)
|
||||||
|
# Update best_candidate from state after successful iteration
|
||||||
|
best_candidate = state.best_candidate # type: ignore[assignment]
|
||||||
|
consecutive_failures = 0
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
consecutive_failures += 1
|
||||||
|
self._log(
|
||||||
|
f"Iter {i}: ERROR ({consecutive_failures} consecutive) — {exc}"
|
||||||
|
)
|
||||||
|
state.history.append(
|
||||||
|
{
|
||||||
|
"iteration": i,
|
||||||
|
"event": "error",
|
||||||
|
"error": str(exc),
|
||||||
|
"consecutive_failures": consecutive_failures,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check circuit breaker
|
||||||
|
if consecutive_failures >= self._circuit_breaker_threshold:
|
||||||
|
self._log(
|
||||||
|
f"Circuit breaker tripped after {consecutive_failures} "
|
||||||
|
f"consecutive failures."
|
||||||
|
)
|
||||||
|
state.history.append(
|
||||||
|
{
|
||||||
|
"iteration": i,
|
||||||
|
"event": "circuit_breaker",
|
||||||
|
"consecutive_failures": consecutive_failures,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if self._error_strategy == "abort":
|
||||||
|
raise CircuitBreakerOpen(
|
||||||
|
f"Circuit breaker tripped after "
|
||||||
|
f"{consecutive_failures} consecutive failures"
|
||||||
|
) from exc
|
||||||
|
# skip / retry strategies: stop the loop gracefully
|
||||||
|
break
|
||||||
|
|
||||||
|
if self._error_strategy == "abort":
|
||||||
|
raise
|
||||||
|
# skip / retry: continue to next iteration
|
||||||
|
continue
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
def _run_iteration(
|
||||||
|
self,
|
||||||
|
i: int,
|
||||||
|
state: OptimizationState,
|
||||||
|
best_candidate: Candidate,
|
||||||
|
synthetic_pool: list[SyntheticExample],
|
||||||
|
task_description: str,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a single iteration. Mutates *state* in-place."""
|
||||||
# 1. Sample a fresh minibatch
|
# 1. Sample a fresh minibatch
|
||||||
batch = self._bootstrap.sample_minibatch(
|
batch = self._bootstrap.sample_minibatch(
|
||||||
synthetic_pool, self._minibatch_size
|
synthetic_pool, self._minibatch_size
|
||||||
@@ -102,7 +174,7 @@ class EvolutionLoop:
|
|||||||
"current_score": current_eval.total_score,
|
"current_score": current_eval.total_score,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
continue
|
return
|
||||||
|
|
||||||
# 4. Propose a new prompt (reflective mutation)
|
# 4. Propose a new prompt (reflective mutation)
|
||||||
new_prompt = self._proposer.propose(
|
new_prompt = self._proposer.propose(
|
||||||
@@ -120,14 +192,14 @@ class EvolutionLoop:
|
|||||||
|
|
||||||
# 6. Accept or reject
|
# 6. Accept or reject
|
||||||
if should_accept(current_eval, new_eval):
|
if should_accept(current_eval, new_eval):
|
||||||
best_candidate = Candidate(
|
new_candidate = Candidate(
|
||||||
prompt=new_prompt,
|
prompt=new_prompt,
|
||||||
best_score=new_eval.total_score,
|
best_score=new_eval.total_score,
|
||||||
generation=i,
|
generation=i,
|
||||||
parent_id=id(best_candidate),
|
parent_id=id(best_candidate),
|
||||||
)
|
)
|
||||||
state.best_candidate = best_candidate
|
state.best_candidate = new_candidate
|
||||||
state.candidates.append(best_candidate)
|
state.candidates.append(new_candidate)
|
||||||
self._log(
|
self._log(
|
||||||
f"Iter {i}: ACCEPTED "
|
f"Iter {i}: ACCEPTED "
|
||||||
f"({current_eval.total_score:.2f} -> {new_eval.total_score:.2f})"
|
f"({current_eval.total_score:.2f} -> {new_eval.total_score:.2f})"
|
||||||
@@ -156,19 +228,6 @@ class EvolutionLoop:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
self._log(f"Iter {i}: ERROR — {exc}. Skipping iteration.")
|
|
||||||
state.history.append(
|
|
||||||
{
|
|
||||||
"iteration": i,
|
|
||||||
"event": "error",
|
|
||||||
"error": str(exc),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
return state
|
|
||||||
|
|
||||||
def _log(self, msg: str) -> None:
|
def _log(self, msg: str) -> None:
|
||||||
if self._verbose:
|
if self._verbose:
|
||||||
logger.info("[PROMETHEUS] %s", msg)
|
logger.info("[PROMETHEUS] %s", msg)
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ class OptimizePromptUseCase:
|
|||||||
minibatch_size=config.minibatch_size,
|
minibatch_size=config.minibatch_size,
|
||||||
perfect_score=config.perfect_score,
|
perfect_score=config.perfect_score,
|
||||||
verbose=config.verbose,
|
verbose=config.verbose,
|
||||||
|
circuit_breaker_threshold=config.circuit_breaker_threshold,
|
||||||
|
error_strategy=config.error_strategy,
|
||||||
)
|
)
|
||||||
seed_prompt = Prompt(text=config.seed_prompt)
|
seed_prompt = Prompt(text=config.seed_prompt)
|
||||||
state = loop.run(seed_prompt, synthetic_pool, config.task_description)
|
state = loop.run(seed_prompt, synthetic_pool, config.task_description)
|
||||||
|
|||||||
@@ -56,6 +56,16 @@ def optimize(
|
|||||||
"--verbose",
|
"--verbose",
|
||||||
help="Print detailed progress.",
|
help="Print detailed progress.",
|
||||||
),
|
),
|
||||||
|
max_retries: int = typer.Option(
|
||||||
|
3,
|
||||||
|
"--max-retries",
|
||||||
|
help="Max retry attempts for transient LLM errors (429, timeout, 5xx).",
|
||||||
|
),
|
||||||
|
error_strategy: str = typer.Option(
|
||||||
|
"retry",
|
||||||
|
"--error-strategy",
|
||||||
|
help="How to handle errors: skip | retry | abort.",
|
||||||
|
),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Optimize a prompt without any reference data.
|
"""Optimize a prompt without any reference data.
|
||||||
|
|
||||||
@@ -115,6 +125,10 @@ def optimize(
|
|||||||
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
|
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
|
||||||
minibatch_size=raw_config.get("minibatch_size", 5),
|
minibatch_size=raw_config.get("minibatch_size", 5),
|
||||||
seed=raw_config.get("seed", 42),
|
seed=raw_config.get("seed", 42),
|
||||||
|
max_retries=raw_config.get("max_retries", max_retries),
|
||||||
|
retry_delay_base=raw_config.get("retry_delay_base", 1.0),
|
||||||
|
circuit_breaker_threshold=raw_config.get("circuit_breaker_threshold", 5),
|
||||||
|
error_strategy=raw_config.get("error_strategy", error_strategy),
|
||||||
output_path=output,
|
output_path=output,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
@@ -139,11 +153,23 @@ def optimize(
|
|||||||
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env),
|
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Build adapters (Dependency Injection — each gets its own LM)
|
# 3. Build adapters (Dependency Injection — each gets its own LM + retry config)
|
||||||
synth_adapter = DSPySyntheticAdapter(lm=synth_lm)
|
synth_adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||||
llm_adapter = DSPyLLMAdapter(lm=task_lm)
|
llm_adapter = DSPyLLMAdapter(
|
||||||
judge_adapter = DSPyJudgeAdapter(lm=judge_lm)
|
lm=task_lm,
|
||||||
proposer_adapter = DSPyProposerAdapter(lm=proposer_lm)
|
max_retries=config.max_retries,
|
||||||
|
retry_delay_base=config.retry_delay_base,
|
||||||
|
)
|
||||||
|
judge_adapter = DSPyJudgeAdapter(
|
||||||
|
lm=judge_lm,
|
||||||
|
max_retries=config.max_retries,
|
||||||
|
retry_delay_base=config.retry_delay_base,
|
||||||
|
)
|
||||||
|
proposer_adapter = DSPyProposerAdapter(
|
||||||
|
lm=proposer_lm,
|
||||||
|
max_retries=config.max_retries,
|
||||||
|
retry_delay_base=config.retry_delay_base,
|
||||||
|
)
|
||||||
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
||||||
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter)
|
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter)
|
||||||
use_case = OptimizePromptUseCase(
|
use_case = OptimizePromptUseCase(
|
||||||
|
|||||||
@@ -5,21 +5,34 @@ Implements the JudgePort via the DSPy OutputJudge module.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
import dspy
|
import dspy
|
||||||
|
|
||||||
from prometheus.domain.ports import JudgePort
|
from prometheus.domain.ports import JudgePort
|
||||||
from prometheus.infrastructure.dspy_modules import OutputJudge
|
from prometheus.infrastructure.dspy_modules import OutputJudge
|
||||||
|
from prometheus.infrastructure.retry import retry_with_backoff
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DSPyJudgeAdapter(JudgePort):
|
class DSPyJudgeAdapter(JudgePort):
|
||||||
"""Evaluates a batch of (input, output) pairs by calling the Judge for each.
|
"""Evaluates a batch of (input, output) pairs by calling the Judge for each.
|
||||||
|
|
||||||
Sequential for MVP. Future: parallelize via dspy.Parallel.
|
Per-call isolation: a failure on one item returns a zero-score sentinel
|
||||||
|
instead of crashing the whole batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, lm: dspy.LM) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
lm: dspy.LM,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay_base: float = 1.0,
|
||||||
|
) -> None:
|
||||||
self._lm = lm
|
self._lm = lm
|
||||||
self._judge = OutputJudge()
|
self._judge = OutputJudge()
|
||||||
|
self._max_retries = max_retries
|
||||||
|
self._retry_delay_base = retry_delay_base
|
||||||
|
|
||||||
def judge_batch(
|
def judge_batch(
|
||||||
self,
|
self,
|
||||||
@@ -29,10 +42,26 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
results: list[tuple[float, str]] = []
|
results: list[tuple[float, str]] = []
|
||||||
with dspy.context(lm=self._lm):
|
with dspy.context(lm=self._lm):
|
||||||
for input_text, output_text in pairs:
|
for input_text, output_text in pairs:
|
||||||
pred = self._judge(
|
results.append(self._judge_single(task_description, input_text, output_text))
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _judge_single(
|
||||||
|
self,
|
||||||
|
task_description: str,
|
||||||
|
input_text: str,
|
||||||
|
output_text: str,
|
||||||
|
) -> tuple[float, str]:
|
||||||
|
try:
|
||||||
|
pred = retry_with_backoff(
|
||||||
|
lambda: self._judge(
|
||||||
task_description=task_description,
|
task_description=task_description,
|
||||||
input_text=input_text,
|
input_text=input_text,
|
||||||
output_text=output_text,
|
output_text=output_text,
|
||||||
|
),
|
||||||
|
max_retries=self._max_retries,
|
||||||
|
retry_delay_base=self._retry_delay_base,
|
||||||
)
|
)
|
||||||
results.append((pred.score, pred.feedback))
|
return (pred.score, pred.feedback)
|
||||||
return results
|
except Exception as exc:
|
||||||
|
logger.warning("Judge call failed for input '%s…': %s", input_text[:40], exc)
|
||||||
|
return (0.0, f"[judge error: {exc}]")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import dspy
|
|||||||
|
|
||||||
from prometheus.domain.entities import Prompt
|
from prometheus.domain.entities import Prompt
|
||||||
from prometheus.domain.ports import LLMPort
|
from prometheus.domain.ports import LLMPort
|
||||||
|
from prometheus.infrastructure.retry import retry_with_backoff
|
||||||
|
|
||||||
|
|
||||||
class DSPyLLMAdapter(LLMPort):
|
class DSPyLLMAdapter(LLMPort):
|
||||||
@@ -21,14 +22,28 @@ class DSPyLLMAdapter(LLMPort):
|
|||||||
input_text: str = dspy.InputField(desc="The input to process.")
|
input_text: str = dspy.InputField(desc="The input to process.")
|
||||||
output: str = dspy.OutputField(desc="The response following the instruction.")
|
output: str = dspy.OutputField(desc="The response following the instruction.")
|
||||||
|
|
||||||
def __init__(self, lm: dspy.LM) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
lm: dspy.LM,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay_base: float = 1.0,
|
||||||
|
) -> None:
|
||||||
self._lm = lm
|
self._lm = lm
|
||||||
self._predictor = dspy.Predict(self._ExecuteSignature)
|
self._predictor = dspy.Predict(self._ExecuteSignature)
|
||||||
|
self._max_retries = max_retries
|
||||||
|
self._retry_delay_base = retry_delay_base
|
||||||
|
|
||||||
def execute(self, prompt: Prompt, input_text: str) -> str:
|
def execute(self, prompt: Prompt, input_text: str) -> str:
|
||||||
|
def _call() -> str:
|
||||||
with dspy.context(lm=self._lm):
|
with dspy.context(lm=self._lm):
|
||||||
result = self._predictor(
|
result = self._predictor(
|
||||||
instruction=prompt.text,
|
instruction=prompt.text,
|
||||||
input_text=input_text,
|
input_text=input_text,
|
||||||
)
|
)
|
||||||
return str(result.output)
|
return str(result.output)
|
||||||
|
|
||||||
|
return retry_with_backoff(
|
||||||
|
_call,
|
||||||
|
max_retries=self._max_retries,
|
||||||
|
retry_delay_base=self._retry_delay_base,
|
||||||
|
)
|
||||||
|
|||||||
@@ -11,14 +11,22 @@ import dspy
|
|||||||
from prometheus.domain.entities import Prompt, Trajectory
|
from prometheus.domain.entities import Prompt, Trajectory
|
||||||
from prometheus.domain.ports import ProposerPort
|
from prometheus.domain.ports import ProposerPort
|
||||||
from prometheus.infrastructure.dspy_modules import InstructionProposer
|
from prometheus.infrastructure.dspy_modules import InstructionProposer
|
||||||
|
from prometheus.infrastructure.retry import retry_with_backoff
|
||||||
|
|
||||||
|
|
||||||
class DSPyProposerAdapter(ProposerPort):
|
class DSPyProposerAdapter(ProposerPort):
|
||||||
"""Uses evaluation trajectories to build a failure report and propose a new prompt."""
|
"""Uses evaluation trajectories to build a failure report and propose a new prompt."""
|
||||||
|
|
||||||
def __init__(self, lm: dspy.LM) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
lm: dspy.LM,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay_base: float = 1.0,
|
||||||
|
) -> None:
|
||||||
self._lm = lm
|
self._lm = lm
|
||||||
self._proposer = InstructionProposer()
|
self._proposer = InstructionProposer()
|
||||||
|
self._max_retries = max_retries
|
||||||
|
self._retry_delay_base = retry_delay_base
|
||||||
|
|
||||||
def propose(
|
def propose(
|
||||||
self,
|
self,
|
||||||
@@ -27,6 +35,8 @@ class DSPyProposerAdapter(ProposerPort):
|
|||||||
task_description: str,
|
task_description: str,
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
failure_examples = self._format_failures(trajectories)
|
failure_examples = self._format_failures(trajectories)
|
||||||
|
|
||||||
|
def _call() -> Prompt:
|
||||||
with dspy.context(lm=self._lm):
|
with dspy.context(lm=self._lm):
|
||||||
pred = self._proposer(
|
pred = self._proposer(
|
||||||
current_instruction=current_prompt.text,
|
current_instruction=current_prompt.text,
|
||||||
@@ -35,6 +45,12 @@ class DSPyProposerAdapter(ProposerPort):
|
|||||||
)
|
)
|
||||||
return Prompt(text=pred.new_instruction)
|
return Prompt(text=pred.new_instruction)
|
||||||
|
|
||||||
|
return retry_with_backoff(
|
||||||
|
_call,
|
||||||
|
max_retries=self._max_retries,
|
||||||
|
retry_delay_base=self._retry_delay_base,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _format_failures(trajectories: list[Trajectory]) -> str:
|
def _format_failures(trajectories: list[Trajectory]) -> str:
|
||||||
"""Convert trajectories into a structured textual report."""
|
"""Convert trajectories into a structured textual report."""
|
||||||
|
|||||||
73
src/prometheus/infrastructure/retry.py
Normal file
73
src/prometheus/infrastructure/retry.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""Retry with exponential backoff for transient LLM errors."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, TypeVar
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
# Status codes / keywords that indicate a transient error worth retrying.
|
||||||
|
_TRANSIENT_PATTERNS = (
|
||||||
|
"429",
|
||||||
|
"rate limit",
|
||||||
|
"rate_limit",
|
||||||
|
"too many requests",
|
||||||
|
"500",
|
||||||
|
"502",
|
||||||
|
"503",
|
||||||
|
"504",
|
||||||
|
"timeout",
|
||||||
|
"timed out",
|
||||||
|
"connection error",
|
||||||
|
"connection refused",
|
||||||
|
"overloaded",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_transient_error(exc: Exception) -> bool:
|
||||||
|
"""Return True if the exception looks like a transient LLM/API error."""
|
||||||
|
msg = str(exc).lower()
|
||||||
|
if any(p in msg for p in _TRANSIENT_PATTERNS):
|
||||||
|
return True
|
||||||
|
if isinstance(exc, (ConnectionError, TimeoutError, OSError)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TransientError(RuntimeError):
|
||||||
|
"""Raised when all retry attempts are exhausted for a transient error."""
|
||||||
|
|
||||||
|
|
||||||
|
def retry_with_backoff(
|
||||||
|
fn: Callable[..., T],
|
||||||
|
*args: Any,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay_base: float = 1.0,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> T:
|
||||||
|
"""Call *fn* with exponential-backoff retry on transient errors.
|
||||||
|
|
||||||
|
Delay per attempt: ``retry_delay_base * 2 ** attempt`` seconds.
|
||||||
|
"""
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
if not is_transient_error(exc) or attempt == max_retries:
|
||||||
|
raise
|
||||||
|
delay = retry_delay_base * (2 ** attempt)
|
||||||
|
logger.warning(
|
||||||
|
"Transient error (attempt %d/%d): %s — retrying in %.1fs",
|
||||||
|
attempt + 1,
|
||||||
|
max_retries + 1,
|
||||||
|
exc,
|
||||||
|
delay,
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
# Should not reach here, but satisfy type-checker.
|
||||||
|
raise TransientError(str(last_exc)) from last_exc
|
||||||
302
tests/unit/test_error_handling.py
Normal file
302
tests/unit/test_error_handling.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
"""Unit tests for error handling: retry, circuit breaker, per-call isolation."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.application.bootstrap import SyntheticBootstrap
|
||||||
|
from prometheus.application.evaluator import PromptEvaluator
|
||||||
|
from prometheus.application.evolution import CircuitBreakerOpen, EvolutionLoop
|
||||||
|
from prometheus.domain.entities import EvalResult, Prompt, SyntheticExample, Trajectory
|
||||||
|
from prometheus.infrastructure.retry import is_transient_error, retry_with_backoff
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Retry utility
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsTransientError:
|
||||||
|
def test_rate_limit_429(self):
|
||||||
|
assert is_transient_error(RuntimeError("HTTP 429: rate limit exceeded"))
|
||||||
|
|
||||||
|
def test_server_error_503(self):
|
||||||
|
assert is_transient_error(RuntimeError("503 Service Unavailable"))
|
||||||
|
|
||||||
|
def test_timeout(self):
|
||||||
|
assert is_transient_error(TimeoutError("request timed out"))
|
||||||
|
|
||||||
|
def test_connection_error(self):
|
||||||
|
assert is_transient_error(ConnectionError("connection refused"))
|
||||||
|
|
||||||
|
def test_non_transient(self):
|
||||||
|
assert not is_transient_error(ValueError("bad input"))
|
||||||
|
|
||||||
|
def test_os_error(self):
|
||||||
|
assert is_transient_error(OSError("network unreachable"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetryWithBackoff:
|
||||||
|
def test_succeeds_on_first_try(self):
|
||||||
|
fn = MagicMock(return_value="ok")
|
||||||
|
result = retry_with_backoff(fn, max_retries=3, retry_delay_base=0)
|
||||||
|
assert result == "ok"
|
||||||
|
assert fn.call_count == 1
|
||||||
|
|
||||||
|
def test_retries_on_transient_then_succeeds(self):
|
||||||
|
fn = MagicMock(
|
||||||
|
side_effect=[
|
||||||
|
RuntimeError("429 rate limit"),
|
||||||
|
RuntimeError("429 rate limit"),
|
||||||
|
"ok",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
with patch("prometheus.infrastructure.retry.time.sleep"):
|
||||||
|
result = retry_with_backoff(fn, max_retries=3, retry_delay_base=0)
|
||||||
|
assert result == "ok"
|
||||||
|
assert fn.call_count == 3
|
||||||
|
|
||||||
|
def test_raises_after_max_retries(self):
|
||||||
|
fn = MagicMock(side_effect=RuntimeError("503 overloaded"))
|
||||||
|
with patch("prometheus.infrastructure.retry.time.sleep"):
|
||||||
|
with pytest.raises(RuntimeError, match="503"):
|
||||||
|
retry_with_backoff(fn, max_retries=2, retry_delay_base=0)
|
||||||
|
assert fn.call_count == 3 # 1 initial + 2 retries
|
||||||
|
|
||||||
|
def test_non_transient_not_retried(self):
|
||||||
|
fn = MagicMock(side_effect=ValueError("bad"))
|
||||||
|
with pytest.raises(ValueError, match="bad"):
|
||||||
|
retry_with_backoff(fn, max_retries=3, retry_delay_base=0)
|
||||||
|
assert fn.call_count == 1
|
||||||
|
|
||||||
|
def test_exponential_backoff_delays(self):
|
||||||
|
fn = MagicMock(side_effect=[RuntimeError("timeout"), "ok"])
|
||||||
|
with patch("prometheus.infrastructure.retry.time.sleep") as mock_sleep:
|
||||||
|
retry_with_backoff(fn, max_retries=3, retry_delay_base=2.0)
|
||||||
|
mock_sleep.assert_called_once_with(2.0) # 2.0 * 2^0 = 2.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Circuit breaker (EvolutionLoop)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_eval_result(scores, feedbacks=None):
|
||||||
|
"""Helper to create EvalResult with matching trajectories."""
|
||||||
|
feedbacks = feedbacks or ["ok"] * len(scores)
|
||||||
|
return EvalResult(
|
||||||
|
scores=scores,
|
||||||
|
feedbacks=feedbacks,
|
||||||
|
trajectories=[
|
||||||
|
Trajectory(f"input{i}", f"output{i}", s, f, "prompt")
|
||||||
|
for i, (s, f) in enumerate(zip(scores, feedbacks))
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCircuitBreaker:
|
||||||
|
def test_trips_on_consecutive_failures(self):
|
||||||
|
"""Loop stops when consecutive failures reach the threshold."""
|
||||||
|
initial_eval = _make_eval_result([0.3, 0.4])
|
||||||
|
evaluator = MagicMock()
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def _evaluate(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return initial_eval # seed eval succeeds
|
||||||
|
raise RuntimeError("LLM down")
|
||||||
|
|
||||||
|
evaluator.evaluate.side_effect = _evaluate
|
||||||
|
proposer = MagicMock()
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=10,
|
||||||
|
minibatch_size=2,
|
||||||
|
circuit_breaker_threshold=3,
|
||||||
|
error_strategy="skip",
|
||||||
|
)
|
||||||
|
with patch.object(loop, "_log"):
|
||||||
|
state = loop.run(
|
||||||
|
Prompt("test"),
|
||||||
|
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||||
|
"task",
|
||||||
|
)
|
||||||
|
|
||||||
|
error_events = [h for h in state.history if h.get("event") == "error"]
|
||||||
|
cb_events = [h for h in state.history if h.get("event") == "circuit_breaker"]
|
||||||
|
assert len(error_events) == 3
|
||||||
|
assert len(cb_events) == 1
|
||||||
|
assert state.iteration < 10 # stopped early
|
||||||
|
|
||||||
|
def test_abort_raises_on_first_error(self):
|
||||||
|
"""With error_strategy=abort, the first error raises immediately."""
|
||||||
|
initial_eval = _make_eval_result([0.3, 0.4])
|
||||||
|
evaluator = MagicMock()
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def _evaluate(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return initial_eval
|
||||||
|
raise RuntimeError("LLM down")
|
||||||
|
|
||||||
|
evaluator.evaluate.side_effect = _evaluate
|
||||||
|
proposer = MagicMock()
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=10,
|
||||||
|
minibatch_size=2,
|
||||||
|
circuit_breaker_threshold=3,
|
||||||
|
error_strategy="abort",
|
||||||
|
)
|
||||||
|
with patch.object(loop, "_log"):
|
||||||
|
with pytest.raises(RuntimeError, match="LLM down"):
|
||||||
|
loop.run(
|
||||||
|
Prompt("test"),
|
||||||
|
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||||
|
"task",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_resets_on_success(self):
|
||||||
|
"""Consecutive failure counter resets after a successful iteration."""
|
||||||
|
initial_eval = _make_eval_result([0.3, 0.4])
|
||||||
|
good_eval = _make_eval_result([0.8, 0.9])
|
||||||
|
|
||||||
|
evaluator = MagicMock()
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def _evaluate(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
# call 1: seed eval → succeed
|
||||||
|
# call 2: iter 1 current eval → fail
|
||||||
|
# call 3: iter 2 current eval → fail
|
||||||
|
# call 4: iter 3 current eval → succeed (returns initial_eval)
|
||||||
|
# call 5: iter 3 new eval → succeed (returns good_eval, accepted)
|
||||||
|
# call 6+: iter 4+ current eval → succeed
|
||||||
|
if call_count == 1:
|
||||||
|
return initial_eval
|
||||||
|
if call_count in (2, 3):
|
||||||
|
raise RuntimeError("timeout")
|
||||||
|
if call_count % 2 == 0:
|
||||||
|
return initial_eval # current eval
|
||||||
|
return good_eval # new eval
|
||||||
|
|
||||||
|
evaluator.evaluate.side_effect = _evaluate
|
||||||
|
proposer = MagicMock()
|
||||||
|
proposer.propose.return_value = Prompt("better prompt")
|
||||||
|
bootstrap = MagicMock(spec=SyntheticBootstrap)
|
||||||
|
bootstrap.sample_minibatch.return_value = [
|
||||||
|
SyntheticExample(f"in{i}", id=i) for i in range(2)
|
||||||
|
]
|
||||||
|
|
||||||
|
loop = EvolutionLoop(
|
||||||
|
evaluator=evaluator,
|
||||||
|
proposer=proposer,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
max_iterations=6,
|
||||||
|
minibatch_size=2,
|
||||||
|
circuit_breaker_threshold=3,
|
||||||
|
error_strategy="skip",
|
||||||
|
)
|
||||||
|
with patch.object(loop, "_log"):
|
||||||
|
state = loop.run(
|
||||||
|
Prompt("test"),
|
||||||
|
[SyntheticExample("in", id=0), SyntheticExample("in2", id=1)],
|
||||||
|
"task",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should NOT have tripped — 2 fails, then success reset the counter
|
||||||
|
cb_events = [h for h in state.history if h.get("event") == "circuit_breaker"]
|
||||||
|
assert len(cb_events) == 0
|
||||||
|
assert state.iteration == 6 # ran all iterations
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-call isolation (Evaluator)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerCallIsolation:
|
||||||
|
def test_evaluator_isolates_execution_failure(self):
|
||||||
|
"""A failing execution produces a sentinel output, not a crash."""
|
||||||
|
executor = MagicMock()
|
||||||
|
executor.execute.side_effect = [
|
||||||
|
"good output",
|
||||||
|
RuntimeError("API error"),
|
||||||
|
"another good output",
|
||||||
|
]
|
||||||
|
judge = MagicMock()
|
||||||
|
judge.judge_batch.return_value = [
|
||||||
|
(0.8, "good"),
|
||||||
|
(0.0, "[judge error]"),
|
||||||
|
(0.7, "ok"),
|
||||||
|
]
|
||||||
|
|
||||||
|
evaluator = PromptEvaluator(executor, judge)
|
||||||
|
result = evaluator.evaluate(
|
||||||
|
Prompt("test"),
|
||||||
|
[
|
||||||
|
SyntheticExample("in0", id=0),
|
||||||
|
SyntheticExample("in1", id=1),
|
||||||
|
SyntheticExample("in2", id=2),
|
||||||
|
],
|
||||||
|
"task",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result.scores) == 3
|
||||||
|
assert result.scores[1] == 0.0 # failed item got zero score
|
||||||
|
assert "execution error" in result.trajectories[1].output_text
|
||||||
|
assert result.scores[0] == 0.8 # other items unaffected
|
||||||
|
|
||||||
|
def test_judge_adapter_isolates_single_failure(self):
|
||||||
|
"""DSPyJudgeAdapter returns sentinel for a failed item, not crash."""
|
||||||
|
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||||
|
|
||||||
|
adapter = DSPyJudgeAdapter.__new__(DSPyJudgeAdapter)
|
||||||
|
adapter._lm = MagicMock()
|
||||||
|
adapter._max_retries = 1
|
||||||
|
adapter._retry_delay_base = 0
|
||||||
|
|
||||||
|
# Mock _judge to fail on first call, succeed on second
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
class FakePred:
|
||||||
|
def __init__(self):
|
||||||
|
self.score = 0.9
|
||||||
|
self.feedback = "good"
|
||||||
|
|
||||||
|
def fake_judge(**kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
raise RuntimeError("judge failure")
|
||||||
|
return FakePred()
|
||||||
|
|
||||||
|
adapter._judge = fake_judge
|
||||||
|
|
||||||
|
with patch("prometheus.infrastructure.judge_adapter.dspy.context"):
|
||||||
|
with patch(
|
||||||
|
"prometheus.infrastructure.retry.time.sleep"
|
||||||
|
):
|
||||||
|
results = adapter.judge_batch(
|
||||||
|
"task", [("input1", "output1"), ("input2", "output2")]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
# First item failed even after retry → sentinel
|
||||||
|
assert results[0] == (0.0, "[judge error: judge failure]")
|
||||||
|
# Second item succeeded
|
||||||
|
assert results[1] == (0.9, "good")
|
||||||
Reference in New Issue
Block a user