fix: multi-model routing — each adapter uses own dspy.LM instance
- DSPyLLMAdapter now accepts dspy.LM instead of model string, uses dspy.context(lm=...) - DSPyJudgeAdapter, DSPyProposerAdapter, DSPySyntheticAdapter each accept and use own LM - OptimizationConfig gains per-model api_base/api_key_env override fields - cli/app.py creates separate dspy.LM per adapter with per-model overrides - New unit tests verify each adapter isolates its LM from global config Fixes Bug #1 (multi-model config not wired) and Bug #2 (DSPyLLMAdapter ignores model param). Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -19,6 +19,16 @@ class OptimizationConfig:
|
|||||||
proposer_model: str = "openai/gpt-4o"
|
proposer_model: str = "openai/gpt-4o"
|
||||||
synth_model: str = "openai/gpt-4o"
|
synth_model: str = "openai/gpt-4o"
|
||||||
|
|
||||||
|
# --- Per-model API overrides (optional, fall back to global api_base/api_key_env) ---
|
||||||
|
task_api_base: str | None = None
|
||||||
|
task_api_key_env: str | None = None
|
||||||
|
judge_api_base: str | None = None
|
||||||
|
judge_api_key_env: str | None = None
|
||||||
|
proposer_api_base: str | None = None
|
||||||
|
proposer_api_key_env: str | None = None
|
||||||
|
synth_api_base: str | None = None
|
||||||
|
synth_api_key_env: str | None = None
|
||||||
|
|
||||||
# --- Evolution parameters ---
|
# --- Evolution parameters ---
|
||||||
max_iterations: int = 30
|
max_iterations: int = 30
|
||||||
n_synthetic_inputs: int = 20
|
n_synthetic_inputs: int = 20
|
||||||
|
|||||||
@@ -76,6 +76,26 @@ def optimize(
|
|||||||
# 1. Load config
|
# 1. Load config
|
||||||
persistence = YamlPersistence()
|
persistence = YamlPersistence()
|
||||||
raw_config = persistence.read_config(input)
|
raw_config = persistence.read_config(input)
|
||||||
|
|
||||||
|
def _model_lm_kwargs(
|
||||||
|
model_api_base: str | None,
|
||||||
|
model_api_key_env: str | None,
|
||||||
|
global_api_base: str | None,
|
||||||
|
global_api_key_env: str | None,
|
||||||
|
) -> dict:
|
||||||
|
"""Build kwargs for dspy.LM, using per-model overrides with global fallback."""
|
||||||
|
kwargs: dict = {}
|
||||||
|
api_base = model_api_base or global_api_base
|
||||||
|
api_key_env = model_api_key_env or global_api_key_env
|
||||||
|
if api_base:
|
||||||
|
kwargs["api_base"] = api_base
|
||||||
|
if api_key_env:
|
||||||
|
kwargs["api_key"] = os.environ.get(api_key_env, "")
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
global_api_base = raw_config.get("api_base")
|
||||||
|
global_api_key_env = raw_config.get("api_key_env")
|
||||||
|
|
||||||
config = OptimizationConfig(
|
config = OptimizationConfig(
|
||||||
seed_prompt=raw_config["seed_prompt"],
|
seed_prompt=raw_config["seed_prompt"],
|
||||||
task_description=raw_config["task_description"],
|
task_description=raw_config["task_description"],
|
||||||
@@ -83,6 +103,14 @@ def optimize(
|
|||||||
judge_model=raw_config.get("judge_model", "openai/gpt-4o"),
|
judge_model=raw_config.get("judge_model", "openai/gpt-4o"),
|
||||||
proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"),
|
proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"),
|
||||||
synth_model=raw_config.get("synth_model", "openai/gpt-4o"),
|
synth_model=raw_config.get("synth_model", "openai/gpt-4o"),
|
||||||
|
task_api_base=raw_config.get("task_api_base"),
|
||||||
|
task_api_key_env=raw_config.get("task_api_key_env"),
|
||||||
|
judge_api_base=raw_config.get("judge_api_base"),
|
||||||
|
judge_api_key_env=raw_config.get("judge_api_key_env"),
|
||||||
|
proposer_api_base=raw_config.get("proposer_api_base"),
|
||||||
|
proposer_api_key_env=raw_config.get("proposer_api_key_env"),
|
||||||
|
synth_api_base=raw_config.get("synth_api_base"),
|
||||||
|
synth_api_key_env=raw_config.get("synth_api_key_env"),
|
||||||
max_iterations=raw_config.get("max_iterations", 30),
|
max_iterations=raw_config.get("max_iterations", 30),
|
||||||
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
|
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
|
||||||
minibatch_size=raw_config.get("minibatch_size", 5),
|
minibatch_size=raw_config.get("minibatch_size", 5),
|
||||||
@@ -93,22 +121,29 @@ def optimize(
|
|||||||
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
|
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
|
||||||
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
|
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
|
||||||
|
|
||||||
# 2. Configure DSPy with optional api_base/api_key from config
|
# 2. Create per-model DSPy LM instances
|
||||||
lm_kwargs: dict = {}
|
task_lm = dspy.LM(
|
||||||
api_base = raw_config.get("api_base")
|
config.task_model,
|
||||||
api_key_env = raw_config.get("api_key_env")
|
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env, global_api_base, global_api_key_env),
|
||||||
if api_base:
|
)
|
||||||
lm_kwargs["api_base"] = api_base
|
judge_lm = dspy.LM(
|
||||||
if api_key_env:
|
config.judge_model,
|
||||||
lm_kwargs["api_key"] = os.environ.get(api_key_env, "")
|
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env, global_api_base, global_api_key_env),
|
||||||
task_lm = dspy.LM(config.task_model, **lm_kwargs)
|
)
|
||||||
dspy.configure(lm=task_lm)
|
proposer_lm = dspy.LM(
|
||||||
|
config.proposer_model,
|
||||||
|
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env, global_api_base, global_api_key_env),
|
||||||
|
)
|
||||||
|
synth_lm = dspy.LM(
|
||||||
|
config.synth_model,
|
||||||
|
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env),
|
||||||
|
)
|
||||||
|
|
||||||
# 3. Build adapters (Dependency Injection)
|
# 3. Build adapters (Dependency Injection — each gets its own LM)
|
||||||
synth_adapter = DSPySyntheticAdapter()
|
synth_adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||||
llm_adapter = DSPyLLMAdapter(model=config.task_model)
|
llm_adapter = DSPyLLMAdapter(lm=task_lm)
|
||||||
judge_adapter = DSPyJudgeAdapter()
|
judge_adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||||
proposer_adapter = DSPyProposerAdapter()
|
proposer_adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||||
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
||||||
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter)
|
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter)
|
||||||
use_case = OptimizePromptUseCase(
|
use_case = OptimizePromptUseCase(
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ Implements the JudgePort via the DSPy OutputJudge module.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
|
||||||
from prometheus.domain.ports import JudgePort
|
from prometheus.domain.ports import JudgePort
|
||||||
from prometheus.infrastructure.dspy_modules import OutputJudge
|
from prometheus.infrastructure.dspy_modules import OutputJudge
|
||||||
|
|
||||||
@@ -15,7 +17,8 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
Sequential for MVP. Future: parallelize via dspy.Parallel.
|
Sequential for MVP. Future: parallelize via dspy.Parallel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, lm: dspy.LM) -> None:
|
||||||
|
self._lm = lm
|
||||||
self._judge = OutputJudge()
|
self._judge = OutputJudge()
|
||||||
|
|
||||||
def judge_batch(
|
def judge_batch(
|
||||||
@@ -24,11 +27,12 @@ class DSPyJudgeAdapter(JudgePort):
|
|||||||
pairs: list[tuple[str, str]],
|
pairs: list[tuple[str, str]],
|
||||||
) -> list[tuple[float, str]]:
|
) -> list[tuple[float, str]]:
|
||||||
results: list[tuple[float, str]] = []
|
results: list[tuple[float, str]] = []
|
||||||
for input_text, output_text in pairs:
|
with dspy.context(lm=self._lm):
|
||||||
pred = self._judge(
|
for input_text, output_text in pairs:
|
||||||
task_description=task_description,
|
pred = self._judge(
|
||||||
input_text=input_text,
|
task_description=task_description,
|
||||||
output_text=output_text,
|
input_text=input_text,
|
||||||
)
|
output_text=output_text,
|
||||||
results.append((pred.score, pred.feedback))
|
)
|
||||||
|
results.append((pred.score, pred.feedback))
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -21,12 +21,14 @@ class DSPyLLMAdapter(LLMPort):
|
|||||||
input_text: str = dspy.InputField(desc="The input to process.")
|
input_text: str = dspy.InputField(desc="The input to process.")
|
||||||
output: str = dspy.OutputField(desc="The response following the instruction.")
|
output: str = dspy.OutputField(desc="The response following the instruction.")
|
||||||
|
|
||||||
def __init__(self, model: str) -> None:
|
def __init__(self, lm: dspy.LM) -> None:
|
||||||
|
self._lm = lm
|
||||||
self._predictor = dspy.Predict(self._ExecuteSignature)
|
self._predictor = dspy.Predict(self._ExecuteSignature)
|
||||||
|
|
||||||
def execute(self, prompt: Prompt, input_text: str) -> str:
|
def execute(self, prompt: Prompt, input_text: str) -> str:
|
||||||
result = self._predictor(
|
with dspy.context(lm=self._lm):
|
||||||
instruction=prompt.text,
|
result = self._predictor(
|
||||||
input_text=input_text,
|
instruction=prompt.text,
|
||||||
)
|
input_text=input_text,
|
||||||
|
)
|
||||||
return str(result.output)
|
return str(result.output)
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ Converts trajectories into readable format for the LLM proposer.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
|
||||||
from prometheus.domain.entities import Prompt, Trajectory
|
from prometheus.domain.entities import Prompt, Trajectory
|
||||||
from prometheus.domain.ports import ProposerPort
|
from prometheus.domain.ports import ProposerPort
|
||||||
from prometheus.infrastructure.dspy_modules import InstructionProposer
|
from prometheus.infrastructure.dspy_modules import InstructionProposer
|
||||||
@@ -14,7 +16,8 @@ from prometheus.infrastructure.dspy_modules import InstructionProposer
|
|||||||
class DSPyProposerAdapter(ProposerPort):
|
class DSPyProposerAdapter(ProposerPort):
|
||||||
"""Uses evaluation trajectories to build a failure report and propose a new prompt."""
|
"""Uses evaluation trajectories to build a failure report and propose a new prompt."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, lm: dspy.LM) -> None:
|
||||||
|
self._lm = lm
|
||||||
self._proposer = InstructionProposer()
|
self._proposer = InstructionProposer()
|
||||||
|
|
||||||
def propose(
|
def propose(
|
||||||
@@ -24,11 +27,12 @@ class DSPyProposerAdapter(ProposerPort):
|
|||||||
task_description: str,
|
task_description: str,
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
failure_examples = self._format_failures(trajectories)
|
failure_examples = self._format_failures(trajectories)
|
||||||
pred = self._proposer(
|
with dspy.context(lm=self._lm):
|
||||||
current_instruction=current_prompt.text,
|
pred = self._proposer(
|
||||||
task_description=task_description,
|
current_instruction=current_prompt.text,
|
||||||
failure_examples=failure_examples,
|
task_description=task_description,
|
||||||
)
|
failure_examples=failure_examples,
|
||||||
|
)
|
||||||
return Prompt(text=pred.new_instruction)
|
return Prompt(text=pred.new_instruction)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ Implements the SyntheticGeneratorPort via DSPy.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
|
||||||
from prometheus.domain.entities import SyntheticExample
|
from prometheus.domain.entities import SyntheticExample
|
||||||
from prometheus.domain.ports import SyntheticGeneratorPort
|
from prometheus.domain.ports import SyntheticGeneratorPort
|
||||||
from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator
|
from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator
|
||||||
@@ -13,7 +15,8 @@ from prometheus.infrastructure.dspy_modules import SyntheticInputGenerator
|
|||||||
class DSPySyntheticAdapter(SyntheticGeneratorPort):
|
class DSPySyntheticAdapter(SyntheticGeneratorPort):
|
||||||
"""Generates synthetic inputs in a single batch call via DSPy."""
|
"""Generates synthetic inputs in a single batch call via DSPy."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, lm: dspy.LM) -> None:
|
||||||
|
self._lm = lm
|
||||||
self._generator = SyntheticInputGenerator()
|
self._generator = SyntheticInputGenerator()
|
||||||
|
|
||||||
def generate_inputs(
|
def generate_inputs(
|
||||||
@@ -21,10 +24,11 @@ class DSPySyntheticAdapter(SyntheticGeneratorPort):
|
|||||||
task_description: str,
|
task_description: str,
|
||||||
n_examples: int,
|
n_examples: int,
|
||||||
) -> list[SyntheticExample]:
|
) -> list[SyntheticExample]:
|
||||||
pred = self._generator(
|
with dspy.context(lm=self._lm):
|
||||||
task_description=task_description,
|
pred = self._generator(
|
||||||
n_examples=n_examples,
|
task_description=task_description,
|
||||||
)
|
n_examples=n_examples,
|
||||||
|
)
|
||||||
return [
|
return [
|
||||||
SyntheticExample(
|
SyntheticExample(
|
||||||
input_text=text,
|
input_text=text,
|
||||||
|
|||||||
@@ -16,13 +16,12 @@ def mock_lm() -> dspy.LM:
|
|||||||
{"output": "Mock output response"},
|
{"output": "Mock output response"},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
dspy.configure(lm=lm)
|
|
||||||
return lm
|
return lm
|
||||||
|
|
||||||
|
|
||||||
class TestDSPyLLMAdapter:
|
class TestDSPyLLMAdapter:
|
||||||
def test_execute_returns_response(self, mock_lm: dspy.LM) -> None:
|
def test_execute_returns_response(self, mock_lm: dspy.LM) -> None:
|
||||||
adapter = DSPyLLMAdapter(model="openai/gpt-4o-mini")
|
adapter = DSPyLLMAdapter(lm=mock_lm)
|
||||||
prompt = Prompt(text="Answer the question.")
|
prompt = Prompt(text="Answer the question.")
|
||||||
result = adapter.execute(prompt, "What is 2+2?")
|
result = adapter.execute(prompt, "What is 2+2?")
|
||||||
assert isinstance(result, str)
|
assert isinstance(result, str)
|
||||||
|
|||||||
207
tests/unit/test_adapter_config.py
Normal file
207
tests/unit/test_adapter_config.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""Unit tests for multi-model adapter configuration.
|
||||||
|
|
||||||
|
Verifies that each adapter uses its own dspy.LM instance and
|
||||||
|
that per-model api_base/api_key_env overrides are wired correctly.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from prometheus.domain.entities import Prompt, SyntheticExample, Trajectory
|
||||||
|
from prometheus.infrastructure.judge_adapter import DSPyJudgeAdapter
|
||||||
|
from prometheus.infrastructure.llm_adapter import DSPyLLMAdapter
|
||||||
|
from prometheus.infrastructure.proposer_adapter import DSPyProposerAdapter
|
||||||
|
from prometheus.infrastructure.synth_adapter import DSPySyntheticAdapter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def task_lm() -> dspy.LM:
|
||||||
|
"""Dummy LM for task execution."""
|
||||||
|
return dspy.utils.DummyLM([{"output": "task model output"}])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def judge_lm() -> dspy.LM:
|
||||||
|
"""Dummy LM for judging (ChainOfThought requires reasoning field)."""
|
||||||
|
return dspy.utils.DummyLM(
|
||||||
|
[
|
||||||
|
{"reasoning": "Evaluating output.", "score": "0.8", "feedback": "Good response."},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def proposer_lm() -> dspy.LM:
|
||||||
|
"""Dummy LM for proposing (ChainOfThought requires reasoning field)."""
|
||||||
|
return dspy.utils.DummyLM(
|
||||||
|
[
|
||||||
|
{"reasoning": "Analyzing failures.", "new_instruction": "Improved prompt: be more specific."},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def synth_lm() -> dspy.LM:
|
||||||
|
"""Dummy LM for synthetic generation (ChainOfThought requires reasoning field)."""
|
||||||
|
return dspy.utils.DummyLM(
|
||||||
|
[
|
||||||
|
{"reasoning": "Generating examples.", "examples": json.dumps(["input 1", "input 2", "input 3"])},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDSPyLLMAdapterOwnLM:
|
||||||
|
"""Bug #2 fix: DSPyLLMAdapter must use the LM it receives, not the global one."""
|
||||||
|
|
||||||
|
def test_uses_provided_lm_not_global(self) -> None:
|
||||||
|
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
|
||||||
|
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
|
||||||
|
dspy.configure(lm=global_lm)
|
||||||
|
|
||||||
|
adapter = DSPyLLMAdapter(lm=local_lm)
|
||||||
|
result = adapter.execute(Prompt(text="test"), "input")
|
||||||
|
|
||||||
|
assert result == "local response"
|
||||||
|
|
||||||
|
def test_does_not_affect_global_lm(self) -> None:
|
||||||
|
local_lm = dspy.utils.DummyLM([{"output": "local response"}])
|
||||||
|
global_lm = dspy.utils.DummyLM([{"output": "global response"}])
|
||||||
|
dspy.configure(lm=global_lm)
|
||||||
|
|
||||||
|
adapter = DSPyLLMAdapter(lm=local_lm)
|
||||||
|
adapter.execute(Prompt(text="test"), "input")
|
||||||
|
|
||||||
|
# Global LM should still be the same
|
||||||
|
assert dspy.settings.lm is global_lm
|
||||||
|
|
||||||
|
|
||||||
|
class TestDSPyJudgeAdapterOwnLM:
|
||||||
|
"""DSPyJudgeAdapter must use its own LM instance."""
|
||||||
|
|
||||||
|
def test_uses_provided_lm(self, judge_lm: dspy.LM) -> None:
|
||||||
|
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||||
|
results = adapter.judge_batch(
|
||||||
|
task_description="Test task",
|
||||||
|
pairs=[("input 1", "output 1")],
|
||||||
|
)
|
||||||
|
assert len(results) == 1
|
||||||
|
score, feedback = results[0]
|
||||||
|
assert score == 0.8
|
||||||
|
assert feedback == "Good response."
|
||||||
|
|
||||||
|
def test_does_not_use_global_lm(self) -> None:
|
||||||
|
judge_lm = dspy.utils.DummyLM(
|
||||||
|
[{"reasoning": "ok", "score": "0.9", "feedback": "Judge-specific response"}]
|
||||||
|
)
|
||||||
|
global_lm = dspy.utils.DummyLM([{"reasoning": "no", "score": "0.1", "feedback": "Wrong LM!"}])
|
||||||
|
dspy.configure(lm=global_lm)
|
||||||
|
|
||||||
|
adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||||
|
results = adapter.judge_batch("task", [("in", "out")])
|
||||||
|
assert results[0][0] == 0.9
|
||||||
|
|
||||||
|
|
||||||
|
class TestDSPyProposerAdapterOwnLM:
|
||||||
|
"""DSPyProposerAdapter must use its own LM instance."""
|
||||||
|
|
||||||
|
def test_uses_provided_lm(self, proposer_lm: dspy.LM) -> None:
|
||||||
|
adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||||
|
trajectories = [
|
||||||
|
Trajectory(
|
||||||
|
input_text="test input",
|
||||||
|
output_text="test output",
|
||||||
|
score=0.3,
|
||||||
|
feedback="bad",
|
||||||
|
prompt_used="old prompt",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
result = adapter.propose(
|
||||||
|
current_prompt=Prompt(text="old prompt"),
|
||||||
|
trajectories=trajectories,
|
||||||
|
task_description="Test task",
|
||||||
|
)
|
||||||
|
assert "Improved prompt" in result.text
|
||||||
|
|
||||||
|
def test_does_not_use_global_lm(self) -> None:
|
||||||
|
proposer_lm = dspy.utils.DummyLM(
|
||||||
|
[{"reasoning": "ok", "new_instruction": "proposer-specific"}]
|
||||||
|
)
|
||||||
|
global_lm = dspy.utils.DummyLM(
|
||||||
|
[{"reasoning": "no", "new_instruction": "wrong-global"}]
|
||||||
|
)
|
||||||
|
dspy.configure(lm=global_lm)
|
||||||
|
|
||||||
|
adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||||
|
result = adapter.propose(
|
||||||
|
current_prompt=Prompt(text="test"),
|
||||||
|
trajectories=[],
|
||||||
|
task_description="task",
|
||||||
|
)
|
||||||
|
assert result.text == "proposer-specific"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDSPySyntheticAdapterOwnLM:
|
||||||
|
"""DSPySyntheticAdapter must use its own LM instance."""
|
||||||
|
|
||||||
|
def test_uses_provided_lm(self, synth_lm: dspy.LM) -> None:
|
||||||
|
adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||||
|
results = adapter.generate_inputs("Test task", 3)
|
||||||
|
assert len(results) == 3
|
||||||
|
assert all(isinstance(ex, SyntheticExample) for ex in results)
|
||||||
|
|
||||||
|
def test_does_not_use_global_lm(self) -> None:
|
||||||
|
synth_lm = dspy.utils.DummyLM(
|
||||||
|
[{"reasoning": "ok", "examples": json.dumps(["synth-specific"])}]
|
||||||
|
)
|
||||||
|
global_lm = dspy.utils.DummyLM(
|
||||||
|
[{"reasoning": "no", "examples": json.dumps(["wrong-global"])}]
|
||||||
|
)
|
||||||
|
dspy.configure(lm=global_lm)
|
||||||
|
|
||||||
|
adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||||
|
results = adapter.generate_inputs("task", 1)
|
||||||
|
assert results[0].input_text == "synth-specific"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerModelOverrides:
|
||||||
|
"""Verify that per-model api_base/api_key_env are passed through to dspy.LM."""
|
||||||
|
|
||||||
|
@patch("prometheus.cli.app.dspy.LM")
|
||||||
|
def test_per_model_api_base_override(self, mock_lm_cls: MagicMock) -> None:
|
||||||
|
"""Per-model api_base should be used instead of global."""
|
||||||
|
mock_lm_cls.return_value = MagicMock()
|
||||||
|
|
||||||
|
from prometheus.application.dto import OptimizationConfig
|
||||||
|
|
||||||
|
config = OptimizationConfig(
|
||||||
|
seed_prompt="test",
|
||||||
|
task_description="test",
|
||||||
|
task_model="openai/gpt-4o-mini",
|
||||||
|
judge_model="openai/gpt-4o",
|
||||||
|
proposer_model="openai/gpt-4o",
|
||||||
|
synth_model="openai/gpt-4o",
|
||||||
|
judge_api_base="https://judge.example.com/v1",
|
||||||
|
judge_api_key_env="JUDGE_API_KEY",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify config carries the overrides
|
||||||
|
assert config.judge_api_base == "https://judge.example.com/v1"
|
||||||
|
assert config.judge_api_key_env == "JUDGE_API_KEY"
|
||||||
|
assert config.task_api_base is None
|
||||||
|
|
||||||
|
def test_config_defaults_to_none(self) -> None:
|
||||||
|
from prometheus.application.dto import OptimizationConfig
|
||||||
|
|
||||||
|
config = OptimizationConfig(seed_prompt="test", task_description="test")
|
||||||
|
assert config.task_api_base is None
|
||||||
|
assert config.task_api_key_env is None
|
||||||
|
assert config.judge_api_base is None
|
||||||
|
assert config.judge_api_key_env is None
|
||||||
|
assert config.proposer_api_base is None
|
||||||
|
assert config.proposer_api_key_env is None
|
||||||
|
assert config.synth_api_base is None
|
||||||
|
assert config.synth_api_key_env is None
|
||||||
Reference in New Issue
Block a user