From 7fa1e5b6db87313f87fbfff91da91f448d75aca9 Mon Sep 17 00:00:00 2001 From: Affaan Mustafa Date: Mon, 11 May 2026 04:04:14 -0400 Subject: [PATCH] fix: port LLM provider config and tool schemas --- src/llm/core/types.py | 18 +++++++ src/llm/providers/claude.py | 2 +- src/llm/providers/openai.py | 2 +- src/llm/providers/resolver.py | 40 +++++++++++++++- tests/test_provider_tools.py | 88 +++++++++++++++++++++++++++++++++++ tests/test_resolver.py | 34 ++++++++++++++ tests/test_types.py | 31 ++++++++++++ 7 files changed, 211 insertions(+), 4 deletions(-) create mode 100644 tests/test_provider_tools.py diff --git a/src/llm/core/types.py b/src/llm/core/types.py index b53588dc..6b06adce 100644 --- a/src/llm/core/types.py +++ b/src/llm/core/types.py @@ -57,6 +57,24 @@ class ToolDefinition: "strict": self.strict, } + def to_openai_tool(self) -> dict[str, Any]: + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": self.parameters, + "strict": self.strict, + }, + } + + def to_anthropic_tool(self) -> dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "input_schema": self.parameters, + } + @dataclass(frozen=True) class ToolCall: diff --git a/src/llm/providers/claude.py b/src/llm/providers/claude.py index 975f036b..f344a19f 100644 --- a/src/llm/providers/claude.py +++ b/src/llm/providers/claude.py @@ -60,7 +60,7 @@ class ClaudeProvider(LLMProvider): else: params["max_tokens"] = 8192 # required by Anthropic API if input.tools: - params["tools"] = [tool.to_dict() for tool in input.tools] + params["tools"] = [tool.to_anthropic_tool() for tool in input.tools] response = self.client.messages.create(**params) diff --git a/src/llm/providers/openai.py b/src/llm/providers/openai.py index 019696cf..e4e7f895 100644 --- a/src/llm/providers/openai.py +++ b/src/llm/providers/openai.py @@ -67,7 +67,7 @@ class OpenAIProvider(LLMProvider): if input.max_tokens: params["max_tokens"] = input.max_tokens if input.tools: - params["tools"] = [tool.to_dict() for tool in input.tools] + params["tools"] = [tool.to_openai_tool() for tool in input.tools] response = self.client.chat.completions.create(**params) choice = response.choices[0] diff --git a/src/llm/providers/resolver.py b/src/llm/providers/resolver.py index 655c8775..5967523e 100644 --- a/src/llm/providers/resolver.py +++ b/src/llm/providers/resolver.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from pathlib import Path from llm.core.interface import LLMProvider from llm.core.types import ProviderType @@ -17,10 +18,45 @@ _PROVIDER_MAP: dict[ProviderType, type[LLMProvider]] = { ProviderType.OLLAMA: OllamaProvider, } +LLM_ENV_FILE = ".llm.env" + + +def _strip_env_value(value: str) -> str: + value = value.strip() + if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}: + return value[1:-1] + return value + + +def _read_saved_llm_config(env_path: str | Path = LLM_ENV_FILE) -> dict[str, str]: + path = Path(env_path) + if not path.is_file(): + return {} + + config: dict[str, str] = {} + for line in path.read_text().splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#") or "=" not in stripped: + continue + key, value = stripped.split("=", 1) + config[key.strip()] = _strip_env_value(value) + return config + + +def _resolve_provider_type(provider_type: ProviderType | str | None) -> ProviderType | str: + if provider_type is not None: + return provider_type + + env_provider = os.environ.get("LLM_PROVIDER") + if env_provider: + return _strip_env_value(env_provider).lower() + + saved_config = _read_saved_llm_config() + return saved_config.get("LLM_PROVIDER", "claude").lower() + def get_provider(provider_type: ProviderType | str | None = None, **kwargs: str) -> LLMProvider: - if provider_type is None: - provider_type = os.environ.get("LLM_PROVIDER", "claude").lower() + provider_type = _resolve_provider_type(provider_type) if isinstance(provider_type, str): try: diff --git a/tests/test_provider_tools.py b/tests/test_provider_tools.py new file mode 100644 index 00000000..e12f7aa6 --- /dev/null +++ b/tests/test_provider_tools.py @@ -0,0 +1,88 @@ +from types import SimpleNamespace + +from llm.core.types import LLMInput, Message, Role, ToolDefinition +from llm.providers.claude import ClaudeProvider +from llm.providers.openai import OpenAIProvider + + +def _tool() -> ToolDefinition: + return ToolDefinition( + name="search", + description="Search", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + ) + + +class _OpenAICompletions: + def __init__(self) -> None: + self.params = None + + def create(self, **params): + self.params = params + return SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")], + model=params["model"], + usage=SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2), + ) + + +class _OpenAIClient: + def __init__(self) -> None: + self.completions = _OpenAICompletions() + self.chat = SimpleNamespace(completions=self.completions) + + +class _AnthropicMessages: + def __init__(self) -> None: + self.params = None + + def create(self, **params): + self.params = params + return SimpleNamespace( + content=[SimpleNamespace(text="ok", type="text")], + model=params["model"], + usage=SimpleNamespace(input_tokens=1, output_tokens=1), + stop_reason="end_turn", + ) + + +class _AnthropicClient: + def __init__(self) -> None: + self.messages = _AnthropicMessages() + self.api_key = "test" + + +def test_openai_provider_serializes_tools_for_chat_completions(): + provider = OpenAIProvider(api_key="test") + client = _OpenAIClient() + provider.client = client + + provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")], tools=[_tool()])) + + assert client.completions.params["tools"] == [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": {"type": "object", "properties": {"query": {"type": "string"}}}, + "strict": True, + }, + } + ] + + +def test_claude_provider_serializes_tools_for_messages_api(): + provider = ClaudeProvider(api_key="test") + client = _AnthropicClient() + provider.client = client + + provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")], tools=[_tool()])) + + assert client.messages.params["tools"] == [ + { + "name": "search", + "description": "Search", + "input_schema": {"type": "object", "properties": {"query": {"type": "string"}}}, + } + ] diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 6676c20e..7a8b9b63 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -26,3 +26,37 @@ class TestGetProvider: def test_invalid_provider_raises(self): with pytest.raises(ValueError, match="Unknown provider type"): get_provider("invalid") + + def test_saved_llm_env_selects_provider(self, monkeypatch, tmp_path): + monkeypatch.delenv("LLM_PROVIDER", raising=False) + monkeypatch.chdir(tmp_path) + tmp_path.joinpath(".llm.env").write_text("LLM_PROVIDER=ollama\nLLM_MODEL=llama3.2\n") + + provider = get_provider() + + assert isinstance(provider, OllamaProvider) + + def test_env_provider_overrides_saved_llm_env(self, monkeypatch, tmp_path): + monkeypatch.setenv("LLM_PROVIDER", "ollama") + monkeypatch.chdir(tmp_path) + tmp_path.joinpath(".llm.env").write_text("LLM_PROVIDER=openai\n") + + provider = get_provider() + + assert isinstance(provider, OllamaProvider) + + def test_env_provider_is_normalized(self, monkeypatch): + monkeypatch.setenv("LLM_PROVIDER", "OLLAMA") + + provider = get_provider() + + assert isinstance(provider, OllamaProvider) + + def test_explicit_provider_overrides_saved_llm_env(self, monkeypatch, tmp_path): + monkeypatch.delenv("LLM_PROVIDER", raising=False) + monkeypatch.chdir(tmp_path) + tmp_path.joinpath(".llm.env").write_text("LLM_PROVIDER=openai\n") + + provider = get_provider("ollama") + + assert isinstance(provider, OllamaProvider) diff --git a/tests/test_types.py b/tests/test_types.py index a24b6119..a065cfc8 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -63,6 +63,37 @@ class TestToolDefinition: assert result["name"] == "search" assert result["strict"] is True + def test_tool_to_openai_tool(self): + tool = ToolDefinition( + name="search", + description="Search", + parameters={"type": "object"}, + strict=False, + ) + + assert tool.to_openai_tool() == { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": {"type": "object"}, + "strict": False, + }, + } + + def test_tool_to_anthropic_tool(self): + tool = ToolDefinition( + name="search", + description="Search", + parameters={"type": "object"}, + ) + + assert tool.to_anthropic_tool() == { + "name": "search", + "description": "Search", + "input_schema": {"type": "object"}, + } + class TestToolCall: def test_create_tool_call(self):