mirror of
https://github.com/affaan-m/everything-claude-code.git
synced 2026-05-14 08:28:39 +08:00
fix: port LLM provider config and tool schemas
This commit is contained in:
committed by
Affaan Mustafa
parent
f442bac8c9
commit
7fa1e5b6db
@@ -57,6 +57,24 @@ class ToolDefinition:
|
||||
"strict": self.strict,
|
||||
}
|
||||
|
||||
def to_openai_tool(self) -> dict[str, Any]:
    """Serialize this tool definition into the OpenAI function-calling schema.

    Returns a ``{"type": "function", "function": {...}}`` payload built from
    the tool's name, description, JSON-schema parameters, and strict flag.
    """
    function_payload = {
        "name": self.name,
        "description": self.description,
        "parameters": self.parameters,
        "strict": self.strict,
    }
    return {"type": "function", "function": function_payload}
|
||||
|
||||
def to_anthropic_tool(self) -> dict[str, Any]:
    """Serialize this tool definition into the Anthropic tool-use schema.

    Anthropic expects a flat object whose JSON schema lives under
    ``input_schema`` rather than the OpenAI-style nested ``function``
    wrapper; the ``strict`` flag is not part of this payload.
    """
    payload: dict[str, Any] = {"name": self.name}
    payload["description"] = self.description
    payload["input_schema"] = self.parameters
    return payload
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolCall:
|
||||
|
||||
@@ -60,7 +60,7 @@ class ClaudeProvider(LLMProvider):
|
||||
else:
|
||||
params["max_tokens"] = 8192 # required by Anthropic API
|
||||
if input.tools:
|
||||
params["tools"] = [tool.to_dict() for tool in input.tools]
|
||||
params["tools"] = [tool.to_anthropic_tool() for tool in input.tools]
|
||||
|
||||
response = self.client.messages.create(**params)
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ class OpenAIProvider(LLMProvider):
|
||||
if input.max_tokens:
|
||||
params["max_tokens"] = input.max_tokens
|
||||
if input.tools:
|
||||
params["tools"] = [tool.to_dict() for tool in input.tools]
|
||||
params["tools"] = [tool.to_openai_tool() for tool in input.tools]
|
||||
|
||||
response = self.client.chat.completions.create(**params)
|
||||
choice = response.choices[0]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from llm.core.interface import LLMProvider
|
||||
from llm.core.types import ProviderType
|
||||
@@ -17,10 +18,45 @@ _PROVIDER_MAP: dict[ProviderType, type[LLMProvider]] = {
|
||||
ProviderType.OLLAMA: OllamaProvider,
|
||||
}
|
||||
|
||||
LLM_ENV_FILE = ".llm.env"
|
||||
|
||||
|
||||
def _strip_env_value(value: str) -> str:
|
||||
value = value.strip()
|
||||
if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}:
|
||||
return value[1:-1]
|
||||
return value
|
||||
|
||||
|
||||
def _read_saved_llm_config(env_path: str | Path = LLM_ENV_FILE) -> dict[str, str]:
    """Parse a dotenv-style file into a key/value mapping.

    Blank lines, ``#`` comment lines, and lines with no ``=`` are skipped.
    Each value is cleaned via ``_strip_env_value`` (whitespace plus one pair
    of matching quotes removed). Later duplicate keys overwrite earlier ones.

    Args:
        env_path: Path to the env file; defaults to ``LLM_ENV_FILE``.

    Returns:
        The parsed configuration, or an empty dict when the file is absent.
    """
    path = Path(env_path)
    if not path.is_file():
        return {}

    config: dict[str, str] = {}
    # read_text() with no encoding falls back to the locale default, which
    # mis-decodes the file on non-UTF-8 systems; env files are conventionally
    # UTF-8, so pin the encoding explicitly.
    for line in path.read_text(encoding="utf-8").splitlines():
        stripped = line.strip()
        # Skip blanks, comments, and malformed lines lacking a separator.
        if not stripped or stripped.startswith("#") or "=" not in stripped:
            continue
        key, value = stripped.split("=", 1)
        config[key.strip()] = _strip_env_value(value)
    return config
|
||||
|
||||
|
||||
def _resolve_provider_type(provider_type: ProviderType | str | None) -> ProviderType | str:
    """Decide which LLM provider should be used.

    Resolution order: an explicitly passed value wins; otherwise the
    ``LLM_PROVIDER`` environment variable is consulted; otherwise the value
    persisted in the saved config file; finally the default ``"claude"``.
    """
    # An explicit argument short-circuits all environment/config lookups.
    if provider_type is not None:
        return provider_type

    from_env = os.environ.get("LLM_PROVIDER")
    if from_env:
        return _strip_env_value(from_env).lower()

    # Fall back to the on-disk config, defaulting to Claude.
    return _read_saved_llm_config().get("LLM_PROVIDER", "claude").lower()
|
||||
|
||||
|
||||
def get_provider(provider_type: ProviderType | str | None = None, **kwargs: str) -> LLMProvider:
|
||||
if provider_type is None:
|
||||
provider_type = os.environ.get("LLM_PROVIDER", "claude").lower()
|
||||
provider_type = _resolve_provider_type(provider_type)
|
||||
|
||||
if isinstance(provider_type, str):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user