fix: multi-model routing — each adapter uses own dspy.LM instance
- DSPyLLMAdapter now accepts dspy.LM instead of model string, uses dspy.context(lm=...)
- DSPyJudgeAdapter, DSPyProposerAdapter, DSPySyntheticAdapter each accept and use own LM
- OptimizationConfig gains per-model api_base/api_key_env override fields
- cli/app.py creates separate dspy.LM per adapter with per-model overrides
- New unit tests verify each adapter isolates its LM from global config

Fixes Bug #1 (multi-model config not wired) and Bug #2 (DSPyLLMAdapter ignores model param).

Co-Authored-By: Paperclip <noreply@paperclip.ing>
This commit is contained in:
@@ -76,6 +76,26 @@ def optimize(
|
||||
# 1. Load config
|
||||
persistence = YamlPersistence()
|
||||
raw_config = persistence.read_config(input)
|
||||
|
||||
def _model_lm_kwargs(
|
||||
model_api_base: str | None,
|
||||
model_api_key_env: str | None,
|
||||
global_api_base: str | None,
|
||||
global_api_key_env: str | None,
|
||||
) -> dict:
|
||||
"""Build kwargs for dspy.LM, using per-model overrides with global fallback."""
|
||||
kwargs: dict = {}
|
||||
api_base = model_api_base or global_api_base
|
||||
api_key_env = model_api_key_env or global_api_key_env
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
if api_key_env:
|
||||
kwargs["api_key"] = os.environ.get(api_key_env, "")
|
||||
return kwargs
|
||||
|
||||
global_api_base = raw_config.get("api_base")
|
||||
global_api_key_env = raw_config.get("api_key_env")
|
||||
|
||||
config = OptimizationConfig(
|
||||
seed_prompt=raw_config["seed_prompt"],
|
||||
task_description=raw_config["task_description"],
|
||||
@@ -83,6 +103,14 @@ def optimize(
|
||||
judge_model=raw_config.get("judge_model", "openai/gpt-4o"),
|
||||
proposer_model=raw_config.get("proposer_model", "openai/gpt-4o"),
|
||||
synth_model=raw_config.get("synth_model", "openai/gpt-4o"),
|
||||
task_api_base=raw_config.get("task_api_base"),
|
||||
task_api_key_env=raw_config.get("task_api_key_env"),
|
||||
judge_api_base=raw_config.get("judge_api_base"),
|
||||
judge_api_key_env=raw_config.get("judge_api_key_env"),
|
||||
proposer_api_base=raw_config.get("proposer_api_base"),
|
||||
proposer_api_key_env=raw_config.get("proposer_api_key_env"),
|
||||
synth_api_base=raw_config.get("synth_api_base"),
|
||||
synth_api_key_env=raw_config.get("synth_api_key_env"),
|
||||
max_iterations=raw_config.get("max_iterations", 30),
|
||||
n_synthetic_inputs=raw_config.get("n_synthetic_inputs", 20),
|
||||
minibatch_size=raw_config.get("minibatch_size", 5),
|
||||
@@ -93,22 +121,29 @@ def optimize(
|
||||
console.print(f"[dim]Task: {config.task_description[:80]}...[/dim]")
|
||||
console.print(f"[dim]Seed prompt: {config.seed_prompt[:80]}...[/dim]")
|
||||
|
||||
# 2. Configure DSPy with optional api_base/api_key from config
|
||||
lm_kwargs: dict = {}
|
||||
api_base = raw_config.get("api_base")
|
||||
api_key_env = raw_config.get("api_key_env")
|
||||
if api_base:
|
||||
lm_kwargs["api_base"] = api_base
|
||||
if api_key_env:
|
||||
lm_kwargs["api_key"] = os.environ.get(api_key_env, "")
|
||||
task_lm = dspy.LM(config.task_model, **lm_kwargs)
|
||||
dspy.configure(lm=task_lm)
|
||||
# 2. Create per-model DSPy LM instances
|
||||
task_lm = dspy.LM(
|
||||
config.task_model,
|
||||
**_model_lm_kwargs(config.task_api_base, config.task_api_key_env, global_api_base, global_api_key_env),
|
||||
)
|
||||
judge_lm = dspy.LM(
|
||||
config.judge_model,
|
||||
**_model_lm_kwargs(config.judge_api_base, config.judge_api_key_env, global_api_base, global_api_key_env),
|
||||
)
|
||||
proposer_lm = dspy.LM(
|
||||
config.proposer_model,
|
||||
**_model_lm_kwargs(config.proposer_api_base, config.proposer_api_key_env, global_api_base, global_api_key_env),
|
||||
)
|
||||
synth_lm = dspy.LM(
|
||||
config.synth_model,
|
||||
**_model_lm_kwargs(config.synth_api_base, config.synth_api_key_env, global_api_base, global_api_key_env),
|
||||
)
|
||||
|
||||
# 3. Build adapters (Dependency Injection)
|
||||
synth_adapter = DSPySyntheticAdapter()
|
||||
llm_adapter = DSPyLLMAdapter(model=config.task_model)
|
||||
judge_adapter = DSPyJudgeAdapter()
|
||||
proposer_adapter = DSPyProposerAdapter()
|
||||
# 3. Build adapters (Dependency Injection — each gets its own LM)
|
||||
synth_adapter = DSPySyntheticAdapter(lm=synth_lm)
|
||||
llm_adapter = DSPyLLMAdapter(lm=task_lm)
|
||||
judge_adapter = DSPyJudgeAdapter(lm=judge_lm)
|
||||
proposer_adapter = DSPyProposerAdapter(lm=proposer_lm)
|
||||
bootstrap = SyntheticBootstrap(generator=synth_adapter, seed=config.seed)
|
||||
evaluator = PromptEvaluator(executor=llm_adapter, judge=judge_adapter)
|
||||
use_case = OptimizePromptUseCase(
|
||||
|
||||
Reference in New Issue
Block a user