fix: port LLM provider config and tool schemas

This commit is contained in:
Affaan Mustafa
2026-05-11 04:04:14 -04:00
committed by Affaan Mustafa
parent f442bac8c9
commit 7fa1e5b6db
7 changed files with 211 additions and 4 deletions

View File

@@ -0,0 +1,88 @@
from types import SimpleNamespace
from llm.core.types import LLMInput, Message, Role, ToolDefinition
from llm.providers.claude import ClaudeProvider
from llm.providers.openai import OpenAIProvider
def _tool() -> ToolDefinition:
    """Build a minimal 'search' tool definition shared by the provider tests."""
    schema = {"type": "object", "properties": {"query": {"type": "string"}}}
    return ToolDefinition(name="search", description="Search", parameters=schema)
class _OpenAICompletions:
def __init__(self) -> None:
self.params = None
def create(self, **params):
self.params = params
return SimpleNamespace(
choices=[SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")],
model=params["model"],
usage=SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2),
)
class _OpenAIClient:
    """Minimal OpenAI client double exposing client.chat.completions.create."""

    def __init__(self) -> None:
        completions = _OpenAICompletions()
        # Expose the stub both directly and via the chat.completions path
        # that the provider under test actually traverses.
        self.completions = completions
        self.chat = SimpleNamespace(completions=completions)
class _AnthropicMessages:
def __init__(self) -> None:
self.params = None
def create(self, **params):
self.params = params
return SimpleNamespace(
content=[SimpleNamespace(text="ok", type="text")],
model=params["model"],
usage=SimpleNamespace(input_tokens=1, output_tokens=1),
stop_reason="end_turn",
)
class _AnthropicClient:
    """Minimal Anthropic client double exposing client.messages.create."""

    def __init__(self) -> None:
        # The provider reads both .messages and .api_key from the client.
        self.messages = _AnthropicMessages()
        self.api_key = "test"
def test_openai_provider_serializes_tools_for_chat_completions():
    """Tools passed to OpenAI must be wrapped in the function-calling envelope."""
    fake_client = _OpenAIClient()
    provider = OpenAIProvider(api_key="test")
    provider.client = fake_client

    request = LLMInput(
        messages=[Message(role=Role.USER, content="hi")],
        tools=[_tool()],
    )
    provider.generate(request)

    expected = [
        {
            "type": "function",
            "function": {
                "name": "search",
                "description": "Search",
                "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
                "strict": True,
            },
        }
    ]
    assert fake_client.completions.params["tools"] == expected
def test_claude_provider_serializes_tools_for_messages_api():
    """Tools passed to Claude must use the flat input_schema shape."""
    fake_client = _AnthropicClient()
    provider = ClaudeProvider(api_key="test")
    provider.client = fake_client

    request = LLMInput(
        messages=[Message(role=Role.USER, content="hi")],
        tools=[_tool()],
    )
    provider.generate(request)

    expected = [
        {
            "name": "search",
            "description": "Search",
            "input_schema": {"type": "object", "properties": {"query": {"type": "string"}}},
        }
    ]
    assert fake_client.messages.params["tools"] == expected

View File

@@ -26,3 +26,37 @@ class TestGetProvider:
def test_invalid_provider_raises(self):
    """An unrecognized provider name must raise ValueError."""
    bad_name = "invalid"
    with pytest.raises(ValueError, match="Unknown provider type"):
        get_provider(bad_name)
def test_saved_llm_env_selects_provider(self, monkeypatch, tmp_path):
    """With no LLM_PROVIDER env var, a .llm.env in the cwd picks the provider."""
    monkeypatch.delenv("LLM_PROVIDER", raising=False)
    monkeypatch.chdir(tmp_path)
    env_file = tmp_path / ".llm.env"
    env_file.write_text("LLM_PROVIDER=ollama\nLLM_MODEL=llama3.2\n")

    assert isinstance(get_provider(), OllamaProvider)
def test_env_provider_overrides_saved_llm_env(self, monkeypatch, tmp_path):
    """The LLM_PROVIDER env var takes precedence over the saved .llm.env file."""
    monkeypatch.setenv("LLM_PROVIDER", "ollama")
    monkeypatch.chdir(tmp_path)
    env_file = tmp_path / ".llm.env"
    env_file.write_text("LLM_PROVIDER=openai\n")

    assert isinstance(get_provider(), OllamaProvider)
def test_env_provider_is_normalized(self, monkeypatch):
    """An upper-case provider name from the environment still resolves."""
    monkeypatch.setenv("LLM_PROVIDER", "OLLAMA")
    assert isinstance(get_provider(), OllamaProvider)
def test_explicit_provider_overrides_saved_llm_env(self, monkeypatch, tmp_path):
    """A provider named explicitly in the call beats the saved .llm.env file."""
    monkeypatch.delenv("LLM_PROVIDER", raising=False)
    monkeypatch.chdir(tmp_path)
    env_file = tmp_path / ".llm.env"
    env_file.write_text("LLM_PROVIDER=openai\n")

    assert isinstance(get_provider("ollama"), OllamaProvider)

View File

@@ -63,6 +63,37 @@ class TestToolDefinition:
assert result["name"] == "search"
assert result["strict"] is True
def test_tool_to_openai_tool(self):
    """to_openai_tool() emits the function-calling wrapper, keeping strict=False."""
    tool = ToolDefinition(
        name="search",
        description="Search",
        parameters={"type": "object"},
        strict=False,
    )

    expected = {
        "type": "function",
        "function": {
            "name": "search",
            "description": "Search",
            "parameters": {"type": "object"},
            "strict": False,
        },
    }
    assert tool.to_openai_tool() == expected
def test_tool_to_anthropic_tool(self):
    """to_anthropic_tool() emits the flat shape with parameters as input_schema."""
    tool = ToolDefinition(
        name="search",
        description="Search",
        parameters={"type": "object"},
    )

    expected = {
        "name": "search",
        "description": "Search",
        "input_schema": {"type": "object"},
    }
    assert tool.to_anthropic_tool() == expected
class TestToolCall:
def test_create_tool_call(self):