From f516ca4be651231182f2191557e1cb21ddcafd6e Mon Sep 17 00:00:00 2001 From: FullStackDev Date: Sun, 29 Mar 2026 12:31:48 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20multi-model=20routing=20=E2=80=94=20each?= =?UTF-8?q?=20adapter=20uses=20own=20dspy.LM=20instance?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - DSPyLLMAdapter now accepts dspy.LM instead of model string, uses dspy.context(lm=...) - DSPyJudgeAdapter, DSPyProposerAdapter, DSPySyntheticAdapter each accept and use own LM - OptimizationConfig gains per-model api_base/api_key_env override fields - cli/app.py creates separate dspy.LM per adapter with per-model overrides - New unit tests verify each adapter isolates its LM from global config Fixes Bug #1 (multi-model config not wired) and Bug #2 (DSPyLLMAdapter ignores model param). Co-Authored-By: Paperclip --- src/prometheus/application/dto.py | 10 + src/prometheus/cli/app.py | 65 ++++-- .../infrastructure/judge_adapter.py | 20 +- src/prometheus/infrastructure/llm_adapter.py | 12 +- .../infrastructure/proposer_adapter.py | 16 +- .../infrastructure/synth_adapter.py | 14 +- tests/integration/test_dspy_adapters.py | 3 +- tests/unit/test_adapter_config.py | 207 ++++++++++++++++++ 8 files changed, 306 insertions(+), 41 deletions(-) create mode 100644 tests/unit/test_adapter_config.py diff --git a/src/prometheus/application/dto.py b/src/prometheus/application/dto.py index 3752f4f..774f444 100644 --- a/src/prometheus/application/dto.py +++ b/src/prometheus/application/dto.py @@ -19,6 +19,16 @@ class OptimizationConfig: proposer_model: str = "openai/gpt-4o" synth_model: str = "openai/gpt-4o" + # --- Per-model API overrides (optional, fall back to global api_base/api_key_env) --- + task_api_base: str | None = None + task_api_key_env: str | None = None + judge_api_base: str | None = None + judge_api_key_env: str | None = None + proposer_api_base: str | None = None + proposer_api_key_env: str | None = None + synth_api_base: str | None = None + synth_api_key_env: str | None = None + # --- Evolution parameters --- max_iterations: int = 30 n_synthetic_inputs: int = 20 diff --git a/src/prometheus/cli/app.py b/src/prometheus/cli/app.py index 4b08ce9..dcf5e8c 100644 --- a/src/prometheus/cli/app.py +++ b/src/prometheus/cli/app.py @@ -76,6 +76,26 @@ def optimize( # 1. Load config persistence = YamlPersistence() raw_config = persistence.read_config(input) + + def _model_lm_kwargs( + model_api_base: str | None, + model_api_key_env: str | None, + global_api_base: str | None, + global_api_key_env: str | None, + ) -> dict: + """Build kwargs for dspy.LM, using per-model overrides with global fallback.""" + kwargs: dict = {} + api_base = model_api_base or global_api_base + api_key_env = model_api_key_env or global_api_key_env + if api_base: + kwargs["api_base"] = api_base + if api_key_env: + kwargs["api_key"] = os.environ.get(api_key_env, "") + return kwargs + + global_api_base = raw_config.get("api_base") + global_api_key_env = raw_config.get("api_key_env") + config = OptimizationConfig( seed_prompt=raw_config["seed_prompt"], task_description=raw_config["task_description"], @@ -83,6 +103,14 @@ def optimize( judge_model=raw_config.get("judge_model", "openai/gpt-4o"), proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"), synth_model=raw_config.get("synth_model", "openai/gpt-4o"), + task_api_base=raw_config.get("task_api_base"), + task_api_key_env=raw_config.get("task_api_key_env"), + judge_api_base=raw_config.get("judge_api_base"), + judge_api_key_env=raw_config.get("judge_api_key_env"), + proposer_api_base=raw_config.get("proposer_api_base"), + proposer_api_key_env=raw_config.get("proposer_api_key_env"), + synth_api_base=raw_config.get("synth_api_base"), + synth_api_key_env=raw_config.get("synth_api_key_env"), max_iterations=raw_config.get("max_iterations", 30), n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20), minibatch_size=raw_config.get("minibatch_size", 5), @@ -93,22 +121,29 @@ def optimize( console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]") console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]") - # 2. Configure DSPy with optional api_base/api_key from config - lm_kwargs: dict = {} - api_base = raw_config.get("api_base") - api_key_env = raw_config.get("api_key_env") - if api_base: - lm_kwargs["api_base"] = api_base - if api_key_env: - lm_kwargs["api_key"] = os.environ.get(api_key_env, "") - task_lm = dspy.LM(config.task_model, **lm_kwargs) - dspy.configure(lm=task_lm) + # 2. Create per-model DSPy LM instances + task_lm = dspy.LM( + config.task_model, + **_model_lm_kwargs(config.task_api_base, config.task_api_key_env, global_api_base, global_api_key_env), + ) + judge_lm = dspy.LM( + config.judge_model, + **_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env, global_api_base, global_api_key_env), + ) + proposer_lm = dspy.LM( + config.proposer_model, + **_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env, global_api_base, global_api_key_env), + ) + synth_lm = dspy.LM( + config.synth_model, + **_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env), + ) - # 3. Build adapters (Dependency Injection) - synth_adapter = DSPySyntheticAdapter() - llm_adapter = DSPyLLMAdapter(model=config.task_model) - judge_adapter = DSPyJudgeAdapter() - proposer_adapter = DSPyProposerAdapter() + # 3. Build adapters (Dependency Injection — each gets its own LM) + synth_adapter = DSPySyntheticAdapter(lm=synth_lm) + llm_adapter = DSPyLLMAdapter(lm=task_lm) + judge_adapter = DSPyJudgeAdapter(lm=judge_lm) + proposer_adapter = DSPyProposerAdapter(lm=proposer_lm) bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed) evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter) use_case = OptimizePromptUseCase( diff --git a/src/prometheus/infrastructure/judge_adapter.py b/src/prometheus/infrastructure/judge_adapter.py index 32b02dc..c90dcb3 100644 --- a/src/prometheus/infrastructure/judge_adapter.py +++ b/src/prometheus/infrastructure/judge_adapter.py @@ -5,6 +5,8 @@ Implements the JudgePort via the DSPy OutputJudge module. """ from __future__ import annotations +import dspy + from prometheus.domain.ports import JudgePort from prometheus.infrastructure.dspy_modules import OutputJudge @@ -15,7 +17,8 @@ class DSPyJudgeAdapter(JudgePort): Sequential for MVP. Future: parallelize via dspy.Parallel. """ - def __init__(self) -> None: + def __init__(self, lm: dspy.LM) -> None: + self._lm = lm self._judge = OutputJudge() def judge_batch( @@ -24,11 +27,12 @@ class DSPyJudgeAdapter(JudgePort): pairs: list[tuple[str, str]], ) -> list[tuple[float, str]]: results: list[tuple[float, str]] = [] - for input_text, output_text in pairs: - pred = self._judge( - task_description=task_description, - input_text=input_text, - output_text=output_text, - ) - results.append((pred.score, pred.feedback)) + with dspy.context(lm=self._lm): + for input_text, output_text in pairs: + pred = self._judge( + task_description=task_description, + input_text=input_text, + output_text=output_text, + ) + results.append((pred.score, pred.feedback)) return results diff --git a/src/prometheus/infrastructure/llm_adapter.py b/src/prometheus/infrastructure/llm_adapter.py index 5085b87..08e054c 100644 --- a/src/prometheus/infrastructure/llm_adapter.py +++ b/src/prometheus/infrastructure/llm_adapter.py @@ -21,12 +21,14 @@ class DSPyLLMAdapter(LLMPort): input_text: str = dspy.InputField(desc="The input to process.") output: str = dspy.OutputField(desc="The response following the instruction.") - def __init__(self, model: str) -> None: + def __init__(self, lm: dspy.LM) -> None: + self._lm = lm self._predictor = dspy.Predict(self._ExecuteSignature) def execute(self, prompt: Prompt, input_text: str) -> str: - result = self._predictor( - instruction=prompt.text, - input_text=input_text, - ) + with dspy.context(lm=self._lm): + result = self._predictor( + instruction=prompt.text, + input_text=input_text, + ) return str(result.output) diff --git a/src/prometheus/infrastructure/proposer_adapter.py b/src/prometheus/infrastructure/proposer_adapter.py index a4adeb8..95ce8ea 100644 --- a/src/prometheus/infrastructure/proposer_adapter.py +++ b/src/prometheus/infrastructure/proposer_adapter.py @@ -6,6 +6,8 @@ Converts trajectories into readable format for the LLM proposer. """ from __future__ import annotations +import dspy + from prometheus.domain.entities import Prompt, Trajectory from prometheus.domain.ports import ProposerPort from prometheus.infrastructure.dspy_modules import InstructionProposer @@ -14,7 +16,8 @@ from prometheus.infrastructure.dspy_modules import InstructionProposer class DSPyProposerAdapter(ProposerPort): """Uses evaluation trajectories to build a failure report and propose a new prompt.""" - def __init__(self) -> None: + def __init__(self, lm: dspy.LM) -> None: + self._lm = lm self._proposer = InstructionProposer() def propose( @@ -24,11 +27,12 @@ class DSPyProposerAdapter(ProposerPort): task_description: str, ) -> Prompt: failure_examples = self._format_failures(trajectories) - pred = self._proposer( - current_instruction=current_prompt.text, - task_description=task_description, - failure_examples=failure_examples, - ) + with dspy.context(lm=self._lm): + pred = self._proposer( + current_instruction=current_prompt.text, + task_description=task_description, + failure_examples=failure_examples, + ) return Prompt(text=pred.new_instruction) @staticmethod diff --git a/src/prometheus/infrastructure/synth_adapter.py b/src/prometheus/infrastructure/synth_adapter.py index 6a4daf4..8a10135 100644 --- a/src/prometheus/infrastructure/synth_adapter.py +++ b/src/prometheus/infrastructure/synth_adapter.py @@ -5,6 +5,8 @@ Implements the SyntheticGeneratorPort via DSPy. """ from __future__ import annotations +import dspy + from prometheus.domain.entities import SyntheticExample from prometheus.domain.ports import SyntheticGeneratorPort from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator @@ -13,7 +15,8 @@ from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator class DSPySyntheticAdapter(SyntheticGeneratorPort): """Generates synthetic inputs in a single batch call via DSPy.""" - def __init__(self) -> None: + def __init__(self, lm: dspy.LM) -> None: + self._lm = lm self._generator = SyntheticInputGenerator() def generate_inputs( @@ -21,10 +24,11 @@ class DSPySyntheticAdapter(SyntheticGeneratorPort): task_description: str, n_examples: int, ) -> list[SyntheticExample]: - pred = self._generator( - task_description=task_description, - n_examples=n_examples, - ) + with dspy.context(lm=self._lm): + pred = self._generator( + task_description=task_description, + n_examples=n_examples, + ) return [ SyntheticExample( input_text=text, diff --git a/tests/integration/test_dspy_adapters.py b/tests/integration/test_dspy_adapters.py index cd1ccfa..dd6f909 100644 --- a/tests/integration/test_dspy_adapters.py +++ b/tests/integration/test_dspy_adapters.py @@ -16,13 +16,12 @@ def mock_lm() -> dspy.LM: {"output": "Mock output response"}, ] ) - dspy.configure(lm=lm) return lm class TestDSPyLLMAdapter: def test_execute_returns_response(self, mock_lm: dspy.LM) -> None: - adapter = DSPyLLMAdapter(model="openai/gpt-4o-mini") + adapter = DSPyLLMAdapter(lm=mock_lm) prompt = Prompt(text="Answer the question.") result = adapter.execute(prompt, "What is 2+2?") assert isinstance(result, str) diff --git a/tests/unit/test_adapter_config.py b/tests/unit/test_adapter_config.py new file mode 100644 index 0000000..37f53e6 --- /dev/null +++ b/tests/unit/test_adapter_config.py @@ -0,0 +1,207 @@ +"""Unit tests for multi-model adapter configuration. + +Verifies that each adapter uses its own dspy.LM instance and +that per-model api_base/api_key_env overrides are wired correctly. +""" +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import dspy +import pytest + +from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory +from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter +from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter +from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter +from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter + + +@pytest.fixture +def task_lm() -> dspy.LM: + """Dummy LM for task execution.""" + return dspy.utils.DummyLM([{"output": "task model output"}]) + + +@pytest.fixture +def judge_lm() -> dspy.LM: + """Dummy LM for judging (ChainOfThought requires reasoning field).""" + return dspy.utils.DummyLM( + [ + {"reasoning": "Evaluating output.", "score": "0.8", "feedback": "Good response."}, + ] + ) + + +@pytest.fixture +def proposer_lm() -> dspy.LM: + """Dummy LM for proposing (ChainOfThought requires reasoning field).""" + return dspy.utils.DummyLM( + [ + {"reasoning": "Analyzing failures.", "new_instruction": "Improved prompt: be more specific."}, + ] + ) + + +@pytest.fixture +def synth_lm() -> dspy.LM: + """Dummy LM for synthetic generation (ChainOfThought requires reasoning field).""" + return dspy.utils.DummyLM( + [ + {"reasoning": "Generating examples.", "examples": json.dumps(["input 1", "input 2", "input 3"])}, + ] + ) + + +class TestDSPyLLMAdapterOwnLM: + """Bug #2 fix: DSPyLLMAdapter must use the LM it receives, not the global one.""" + + def test_uses_provided_lm_not_global(self) -> None: + local_lm = dspy.utils.DummyLM([{"output": "local response"}]) + global_lm = dspy.utils.DummyLM([{"output": "global response"}]) + dspy.configure(lm=global_lm) + + adapter = DSPyLLMAdapter(lm=local_lm) + result = adapter.execute(Prompt(text="test"), "input") + + assert result == "local response" + + def test_does_not_affect_global_lm(self) -> None: + local_lm = dspy.utils.DummyLM([{"output": "local response"}]) + global_lm = dspy.utils.DummyLM([{"output": "global response"}]) + dspy.configure(lm=global_lm) + + adapter = DSPyLLMAdapter(lm=local_lm) + adapter.execute(Prompt(text="test"), "input") + + # Global LM should still be the same + assert dspy.settings.lm is global_lm + + +class TestDSPyJudgeAdapterOwnLM: + """DSPyJudgeAdapter must use its own LM instance.""" + + def test_uses_provided_lm(self, judge_lm: dspy.LM) -> None: + adapter = DSPyJudgeAdapter(lm=judge_lm) + results = adapter.judge_batch( + task_description="Test task", + pairs=[("input 1", "output 1")], + ) + assert len(results) == 1 + score, feedback = results[0] + assert score == 0.8 + assert feedback == "Good response." + + def test_does_not_use_global_lm(self) -> None: + judge_lm = dspy.utils.DummyLM( + [{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response"}] + ) + global_lm = dspy.utils.DummyLM([{"reasoning": "no", "score": "0.1", "feedback": "Wrong LM!"}]) + dspy.configure(lm=global_lm) + + adapter = DSPyJudgeAdapter(lm=judge_lm) + results = adapter.judge_batch("task", [("in", "out")]) + assert results[0][0] == 0.9 + + +class TestDSPyProposerAdapterOwnLM: + """DSPyProposerAdapter must use its own LM instance.""" + + def test_uses_provided_lm(self, proposer_lm: dspy.LM) -> None: + adapter = DSPyProposerAdapter(lm=proposer_lm) + trajectories = [ + Trajectory( + input_text="test input", + output_text="test output", + score=0.3, + feedback="bad", + prompt_used="old prompt", + ) + ] + result = adapter.propose( + current_prompt=Prompt(text="old prompt"), + trajectories=trajectories, + task_description="Test task", + ) + assert "Improved prompt" in result.text + + def test_does_not_use_global_lm(self) -> None: + proposer_lm = dspy.utils.DummyLM( + [{"reasoning": "ok", "new_instruction": "proposer-specific"}] + ) + global_lm = dspy.utils.DummyLM( + [{"reasoning": "no", "new_instruction": "wrong-global"}] + ) + dspy.configure(lm=global_lm) + + adapter = DSPyProposerAdapter(lm=proposer_lm) + result = adapter.propose( + current_prompt=Prompt(text="test"), + trajectories=[], + task_description="task", + ) + assert result.text == "proposer-specific" + + +class TestDSPySyntheticAdapterOwnLM: + """DSPySyntheticAdapter must use its own LM instance.""" + + def test_uses_provided_lm(self, synth_lm: dspy.LM) -> None: + adapter = DSPySyntheticAdapter(lm=synth_lm) + results = adapter.generate_inputs("Test task", 3) + assert len(results) == 3 + assert all(isinstance(ex, SyntheticExample) for ex in results) + + def test_does_not_use_global_lm(self) -> None: + synth_lm = dspy.utils.DummyLM( + [{"reasoning": "ok", "examples": json.dumps(["synth-specific"])}] + ) + global_lm = dspy.utils.DummyLM( + [{"reasoning": "no", "examples": json.dumps(["wrong-global"])}] + ) + dspy.configure(lm=global_lm) + + adapter = DSPySyntheticAdapter(lm=synth_lm) + results = adapter.generate_inputs("task", 1) + assert results[0].input_text == "synth-specific" + + +class TestPerModelOverrides: + """Verify that per-model api_base/api_key_env are passed through to dspy.LM.""" + + @patch("prometheus.cli.app.dspy.LM") + def test_per_model_api_base_override(self, mock_lm_cls: MagicMock) -> None: + """Per-model api_base should be used instead of global.""" + mock_lm_cls.return_value = MagicMock() + + from prometheus.application.dto import OptimizationConfig + + config = OptimizationConfig( + seed_prompt="test", + task_description="test", + task_model="openai/gpt-4o-mini", + judge_model="openai/gpt-4o", + proposer_model="openai/gpt-4o", + synth_model="openai/gpt-4o", + judge_api_base="https://judge.example.com/v1", + judge_api_key_env="JUDGE_API_KEY", + ) + + # Verify config carries the overrides + assert config.judge_api_base == "https://judge.example.com/v1" + assert config.judge_api_key_env == "JUDGE_API_KEY" + assert config.task_api_base is None + + def test_config_defaults_to_none(self) -> None: + from prometheus.application.dto import OptimizationConfig + + config = OptimizationConfig(seed_prompt="test", task_description="test") + assert config.task_api_base is None + assert config.task_api_key_env is None + assert config.judge_api_base is None + assert config.judge_api_key_env is None + assert config.proposer_api_base is None + assert config.proposer_api_key_env is None + assert config.synth_api_base is None + assert config.synth_api_key_env is None