fix: multi-model routing — each adapter uses own dspy.LM instance
- DSPyLLMAdapter now accepts dspy.LM instead of model string, uses dspy.context(lm=...) - DSPyJudgeAdapter, DSPyProposerAdapter, DSPySyntheticAdapter each accept and use own LM - OptimizationConfig gains per-model api_base/api_key_env override fields - cli/app.py creates separate dspy.LM per adapter with per-model overrides - New unit tests verify each adapter isolates its LM from global config Fixes Bug #1 (multi-model config not wired) and Bug #2 (DSPyLLMAdapter ignores model param). Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
207
tests/unit/test_adapter_config.py
Normal file
207
tests/unit/test_adapter_config.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Unit tests for multi-model adapter configuration.
|
||||
|
||||
Verifies that each adapter uses its own dspy.LM instance and
|
||||
that per-model api_base/api_key_env overrides are wired correctly.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import dspy
|
||||
import pytest
|
||||
|
||||
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory
|
||||
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
||||
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
||||
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_lm() -> dspy.LM:
|
||||
"""Dummy LM for task execution."""
|
||||
return dspy.utils.DummyLM([{"output": "task model output"}])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def judge_lm() -> dspy.LM:
|
||||
"""Dummy LM for judging (ChainOfThought requires reasoning field)."""
|
||||
return dspy.utils.DummyLM(
|
||||
[
|
||||
{"reasoning": "Evaluating output.", "score": "0.8", "feedback": "Good response."},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def proposer_lm() -> dspy.LM:
|
||||
"""Dummy LM for proposing (ChainOfThought requires reasoning field)."""
|
||||
return dspy.utils.DummyLM(
|
||||
[
|
||||
{"reasoning": "Analyzing failures.", "new_instruction": "Improved prompt: be more specific."},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def synth_lm() -> dspy.LM:
|
||||
"""Dummy LM for synthetic generation (ChainOfThought requires reasoning field)."""
|
||||
return dspy.utils.DummyLM(
|
||||
[
|
||||
{"reasoning": "Generating examples.", "examples": json.dumps(["input 1", "input 2", "input 3"])},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestDSPyLLMAdapterOwnLM:
|
||||
"""Bug #2 fix: DSPyLLMAdapter must use the LM it receives, not the global one."""
|
||||
|
||||
def test_uses_provided_lm_not_global(self) -> None:
|
||||
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
|
||||
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
|
||||
dspy.configure(lm=global_lm)
|
||||
|
||||
adapter = DSPyLLMAdapter(lm=local_lm)
|
||||
result = adapter.execute(Prompt(text="test"), "input")
|
||||
|
||||
assert result == "local response"
|
||||
|
||||
def test_does_not_affect_global_lm(self) -> None:
|
||||
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
|
||||
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
|
||||
dspy.configure(lm=global_lm)
|
||||
|
||||
adapter = DSPyLLMAdapter(lm=local_lm)
|
||||
adapter.execute(Prompt(text="test"), "input")
|
||||
|
||||
# Global LM should still be the same
|
||||
assert dspy.settings.lm is global_lm
|
||||
|
||||
|
||||
class TestDSPyJudgeAdapterOwnLM:
|
||||
"""DSPyJudgeAdapter must use its own LM instance."""
|
||||
|
||||
def test_uses_provided_lm(self, judge_lm: dspy.LM) -> None:
|
||||
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||
results = adapter.judge_batch(
|
||||
task_description="Test task",
|
||||
pairs=[("input 1", "output 1")],
|
||||
)
|
||||
assert len(results) == 1
|
||||
score, feedback = results[0]
|
||||
assert score == 0.8
|
||||
assert feedback == "Good response."
|
||||
|
||||
def test_does_not_use_global_lm(self) -> None:
|
||||
judge_lm = dspy.utils.DummyLM(
|
||||
[{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response"}]
|
||||
)
|
||||
global_lm = dspy.utils.DummyLM([{"reasoning": "no", "score": "0.1", "feedback": "Wrong LM!"}])
|
||||
dspy.configure(lm=global_lm)
|
||||
|
||||
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||
results = adapter.judge_batch("task", [("in", "out")])
|
||||
assert results[0][0] == 0.9
|
||||
|
||||
|
||||
class TestDSPyProposerAdapterOwnLM:
|
||||
"""DSPyProposerAdapter must use its own LM instance."""
|
||||
|
||||
def test_uses_provided_lm(self, proposer_lm: dspy.LM) -> None:
|
||||
adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||
trajectories = [
|
||||
Trajectory(
|
||||
input_text="test input",
|
||||
output_text="test output",
|
||||
score=0.3,
|
||||
feedback="bad",
|
||||
prompt_used="old prompt",
|
||||
)
|
||||
]
|
||||
result = adapter.propose(
|
||||
current_prompt=Prompt(text="old prompt"),
|
||||
trajectories=trajectories,
|
||||
task_description="Test task",
|
||||
)
|
||||
assert "Improved prompt" in result.text
|
||||
|
||||
def test_does_not_use_global_lm(self) -> None:
|
||||
proposer_lm = dspy.utils.DummyLM(
|
||||
[{"reasoning": "ok", "new_instruction": "proposer-specific"}]
|
||||
)
|
||||
global_lm = dspy.utils.DummyLM(
|
||||
[{"reasoning": "no", "new_instruction": "wrong-global"}]
|
||||
)
|
||||
dspy.configure(lm=global_lm)
|
||||
|
||||
adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||
result = adapter.propose(
|
||||
current_prompt=Prompt(text="test"),
|
||||
trajectories=[],
|
||||
task_description="task",
|
||||
)
|
||||
assert result.text == "proposer-specific"
|
||||
|
||||
|
||||
class TestDSPySyntheticAdapterOwnLM:
|
||||
"""DSPySyntheticAdapter must use its own LM instance."""
|
||||
|
||||
def test_uses_provided_lm(self, synth_lm: dspy.LM) -> None:
|
||||
adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||
results = adapter.generate_inputs("Test task", 3)
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(ex, SyntheticExample) for ex in results)
|
||||
|
||||
def test_does_not_use_global_lm(self) -> None:
|
||||
synth_lm = dspy.utils.DummyLM(
|
||||
[{"reasoning": "ok", "examples": json.dumps(["synth-specific"])}]
|
||||
)
|
||||
global_lm = dspy.utils.DummyLM(
|
||||
[{"reasoning": "no", "examples": json.dumps(["wrong-global"])}]
|
||||
)
|
||||
dspy.configure(lm=global_lm)
|
||||
|
||||
adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||
results = adapter.generate_inputs("task", 1)
|
||||
assert results[0].input_text == "synth-specific"
|
||||
|
||||
|
||||
class TestPerModelOverrides:
|
||||
"""Verify that per-model api_base/api_key_env are passed through to dspy.LM."""
|
||||
|
||||
@patch("prometheus.cli.app.dspy.LM")
|
||||
def test_per_model_api_base_override(self, mock_lm_cls: MagicMock) -> None:
|
||||
"""Per-model api_base should be used instead of global."""
|
||||
mock_lm_cls.return_value = MagicMock()
|
||||
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
|
||||
config = OptimizationConfig(
|
||||
seed_prompt="test",
|
||||
task_description="test",
|
||||
task_model="openai/gpt-4o-mini",
|
||||
judge_model="openai/gpt-4o",
|
||||
proposer_model="openai/gpt-4o",
|
||||
synth_model="openai/gpt-4o",
|
||||
judge_api_base="https://judge.example.com/v1",
|
||||
judge_api_key_env="JUDGE_API_KEY",
|
||||
)
|
||||
|
||||
# Verify config carries the overrides
|
||||
assert config.judge_api_base == "https://judge.example.com/v1"
|
||||
assert config.judge_api_key_env == "JUDGE_API_KEY"
|
||||
assert config.task_api_base is None
|
||||
|
||||
def test_config_defaults_to_none(self) -> None:
|
||||
from prometheus.application.dto import OptimizationConfig
|
||||
|
||||
config = OptimizationConfig(seed_prompt="test", task_description="test")
|
||||
assert config.task_api_base is None
|
||||
assert config.task_api_key_env is None
|
||||
assert config.judge_api_base is None
|
||||
assert config.judge_api_key_env is None
|
||||
assert config.proposer_api_base is None
|
||||
assert config.proposer_api_key_env is None
|
||||
assert config.synth_api_base is None
|
||||
assert config.synth_api_key_env is None
|
||||
Reference in New Issue
Block a user