fix: multi-model routing — each adapter uses its own dspy.LM instance

- DSPyLLMAdapter now accepts dspy.LM instead of model string, uses dspy.context(lm=...)
- DSPyJudgeAdapter, DSPyProposerAdapter, DSPySyntheticAdapter each accept and use their own LM
- OptimizationConfig gains per-model api_base/api_key_env override fields
- cli/app.py creates separate dspy.LM per adapter with per-model overrides
- New unit tests verify each adapter isolates its LM from global config

Fixes Bug #1 (multi-model config not wired) and Bug #2 (DSPyLLMAdapter ignores model param).

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
FullStackDev
2026-03-29 12:31:48 +00:00
parent 837a44970f
commit f516ca4be6
8 changed files with 306 additions and 41 deletions

View File

@@ -19,6 +19,16 @@ class OptimizationConfig:
proposer_model: str = "openai/gpt-4o"
synth_model: str = "openai/gpt-4o"
# --- Per-model API overrides (optional, fall back to global api_base/api_key_env) ---
task_api_base: str | None = None
task_api_key_env: str | None = None
judge_api_base: str | None = None
judge_api_key_env: str | None = None
proposer_api_base: str | None = None
proposer_api_key_env: str | None = None
synth_api_base: str | None = None
synth_api_key_env: str | None = None
# --- Evolution parameters ---
max_iterations: int = 30
n_synthetic_inputs: int = 20

View File

@@ -76,6 +76,26 @@ def optimize(
# 1. Load config
persistence = YamlPersistence()
raw_config = persistence.read_config(input)
def _model_lm_kwargs(
model_api_base: str | None,
model_api_key_env: str | None,
global_api_base: str | None,
global_api_key_env: str | None,
) -> dict:
"""Build kwargs for dspy.LM, using per-model overrides with global fallback."""
kwargs: dict = {}
api_base = model_api_base or global_api_base
api_key_env = model_api_key_env or global_api_key_env
if api_base:
kwargs["api_base"] = api_base
if api_key_env:
kwargs["api_key"] = os.environ.get(api_key_env, "")
return kwargs
global_api_base = raw_config.get("api_base")
global_api_key_env = raw_config.get("api_key_env")
config = OptimizationConfig(
seed_prompt=raw_config["seed_prompt"],
task_description=raw_config["task_description"],
@@ -83,6 +103,14 @@ def optimize(
judge_model=raw_config.get("judge_model", "openai/gpt-4o"),
proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"),
synth_model=raw_config.get("synth_model", "openai/gpt-4o"),
task_api_base=raw_config.get("task_api_base"),
task_api_key_env=raw_config.get("task_api_key_env"),
judge_api_base=raw_config.get("judge_api_base"),
judge_api_key_env=raw_config.get("judge_api_key_env"),
proposer_api_base=raw_config.get("proposer_api_base"),
proposer_api_key_env=raw_config.get("proposer_api_key_env"),
synth_api_base=raw_config.get("synth_api_base"),
synth_api_key_env=raw_config.get("synth_api_key_env"),
max_iterations=raw_config.get("max_iterations", 30),
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
minibatch_size=raw_config.get("minibatch_size", 5),
@@ -93,22 +121,29 @@ def optimize(
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
# 2. Configure DSPy with optional api_base/api_key from config
lm_kwargs: dict = {}
api_base = raw_config.get("api_base")
api_key_env = raw_config.get("api_key_env")
if api_base:
lm_kwargs["api_base"] = api_base
if api_key_env:
lm_kwargs["api_key"] = os.environ.get(api_key_env, "")
task_lm = dspy.LM(config.task_model, **lm_kwargs)
dspy.configure(lm=task_lm)
# 2. Create per-model DSPy LM instances
task_lm = dspy.LM(
config.task_model,
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env, global_api_base, global_api_key_env),
)
judge_lm = dspy.LM(
config.judge_model,
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env, global_api_base, global_api_key_env),
)
proposer_lm = dspy.LM(
config.proposer_model,
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env, global_api_base, global_api_key_env),
)
synth_lm = dspy.LM(
config.synth_model,
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env),
)
# 3. Build adapters (Dependency Injection)
synth_adapter = DSPySyntheticAdapter()
llm_adapter = DSPyLLMAdapter(model=config.task_model)
judge_adapter = DSPyJudgeAdapter()
proposer_adapter = DSPyProposerAdapter()
# 3. Build adapters (Dependency Injection — each gets its own LM)
synth_adapter = DSPySyntheticAdapter(lm=synth_lm)
llm_adapter = DSPyLLMAdapter(lm=task_lm)
judge_adapter = DSPyJudgeAdapter(lm=judge_lm)
proposer_adapter = DSPyProposerAdapter(lm=proposer_lm)
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter)
use_case = OptimizePromptUseCase(

View File

@@ -5,6 +5,8 @@ Implements the JudgePort via the DSPy OutputJudge module.
"""
from __future__ import annotations
import dspy
from prometheus.domain.ports import JudgePort
from prometheus.infrastructure.dspy_modules import OutputJudge
@@ -15,7 +17,8 @@ class DSPyJudgeAdapter(JudgePort):
Sequential for MVP. Future: parallelize via dspy.Parallel.
"""
def __init__(self) -> None:
def __init__(self, lm: dspy.LM) -> None:
self._lm = lm
self._judge = OutputJudge()
def judge_batch(
@@ -24,11 +27,12 @@ class DSPyJudgeAdapter(JudgePort):
pairs: list[tuple[str, str]],
) -> list[tuple[float, str]]:
results: list[tuple[float, str]] = []
for input_text, output_text in pairs:
pred = self._judge(
task_description=task_description,
input_text=input_text,
output_text=output_text,
)
results.append((pred.score, pred.feedback))
with dspy.context(lm=self._lm):
for input_text, output_text in pairs:
pred = self._judge(
task_description=task_description,
input_text=input_text,
output_text=output_text,
)
results.append((pred.score, pred.feedback))
return results

View File

@@ -21,12 +21,14 @@ class DSPyLLMAdapter(LLMPort):
input_text: str = dspy.InputField(desc="The input to process.")
output: str = dspy.OutputField(desc="The response following the instruction.")
def __init__(self, model: str) -> None:
def __init__(self, lm: dspy.LM) -> None:
self._lm = lm
self._predictor = dspy.Predict(self._ExecuteSignature)
def execute(self, prompt: Prompt, input_text: str) -> str:
result = self._predictor(
instruction=prompt.text,
input_text=input_text,
)
with dspy.context(lm=self._lm):
result = self._predictor(
instruction=prompt.text,
input_text=input_text,
)
return str(result.output)

View File

@@ -6,6 +6,8 @@ Converts trajectories into readable format for the LLM proposer.
"""
from __future__ import annotations
import dspy
from prometheus.domain.entities import Prompt, Trajectory
from prometheus.domain.ports import ProposerPort
from prometheus.infrastructure.dspy_modules import InstructionProposer
@@ -14,7 +16,8 @@ from prometheus.infrastructure.dspy_modules import InstructionProposer
class DSPyProposerAdapter(ProposerPort):
"""Uses evaluation trajectories to build a failure report and propose a new prompt."""
def __init__(self) -> None:
def __init__(self, lm: dspy.LM) -> None:
self._lm = lm
self._proposer = InstructionProposer()
def propose(
@@ -24,11 +27,12 @@ class DSPyProposerAdapter(ProposerPort):
task_description: str,
) -> Prompt:
failure_examples = self._format_failures(trajectories)
pred = self._proposer(
current_instruction=current_prompt.text,
task_description=task_description,
failure_examples=failure_examples,
)
with dspy.context(lm=self._lm):
pred = self._proposer(
current_instruction=current_prompt.text,
task_description=task_description,
failure_examples=failure_examples,
)
return Prompt(text=pred.new_instruction)
@staticmethod

View File

@@ -5,6 +5,8 @@ Implements the SyntheticGeneratorPort via DSPy.
"""
from __future__ import annotations
import dspy
from prometheus.domain.entities import SyntheticExample
from prometheus.domain.ports import SyntheticGeneratorPort
from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator
@@ -13,7 +15,8 @@ from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator
class DSPySyntheticAdapter(SyntheticGeneratorPort):
"""Generates synthetic inputs in a single batch call via DSPy."""
def __init__(self) -> None:
def __init__(self, lm: dspy.LM) -> None:
self._lm = lm
self._generator = SyntheticInputGenerator()
def generate_inputs(
@@ -21,10 +24,11 @@ class DSPySyntheticAdapter(SyntheticGeneratorPort):
task_description: str,
n_examples: int,
) -> list[SyntheticExample]:
pred = self._generator(
task_description=task_description,
n_examples=n_examples,
)
with dspy.context(lm=self._lm):
pred = self._generator(
task_description=task_description,
n_examples=n_examples,
)
return [
SyntheticExample(
input_text=text,