mirror of
https://github.com/affaan-m/everything-claude-code.git
synced 2026-05-12 07:37:24 +08:00
76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
"""Provider factory and resolver."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from llm.core.interface import LLMProvider
|
|
from llm.core.types import ProviderType
|
|
from llm.providers.claude import ClaudeProvider
|
|
from llm.providers.openai import OpenAIProvider
|
|
from llm.providers.ollama import OllamaProvider
|
|
|
|
|
|
# Registry of provider classes keyed by provider type. get_provider() looks
# entries up here and register_provider() adds or replaces them at runtime.
_PROVIDER_MAP: dict[ProviderType, type[LLMProvider]] = {
    ProviderType.CLAUDE: ClaudeProvider,
    ProviderType.OPENAI: OpenAIProvider,
    ProviderType.OLLAMA: OllamaProvider,
}

# Default path of the saved settings file parsed by _read_saved_llm_config().
LLM_ENV_FILE = ".llm.env"
|
|
|
|
|
|
def _strip_env_value(value: str) -> str:
|
|
value = value.strip()
|
|
if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}:
|
|
return value[1:-1]
|
|
return value
|
|
|
|
|
|
def _read_saved_llm_config(env_path: str | Path = LLM_ENV_FILE) -> dict[str, str]:
    """Parse ``KEY=VALUE`` pairs from a saved LLM env file.

    Args:
        env_path: Location of the env file. Defaults to ``LLM_ENV_FILE``.

    Returns:
        A mapping of whitespace-stripped keys to values cleaned by
        ``_strip_env_value``. The mapping is empty when ``env_path`` is not an
        existing regular file. Blank lines, lines starting with ``#`` and
        lines without ``=`` are skipped; when a key repeats, the last
        occurrence wins.
    """
    path = Path(env_path)
    if not path.is_file():
        return {}

    # Decode explicitly rather than with the locale's preferred encoding,
    # which differs between platforms. "utf-8-sig" reads plain UTF-8 and also
    # drops a leading byte-order mark, which would otherwise be glued onto
    # the first key and make it impossible to look up.
    text = path.read_text(encoding="utf-8-sig")

    config: dict[str, str] = {}
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#") or "=" not in stripped:
            continue
        # Split on the first "=" only, so values may themselves contain "=".
        key, value = stripped.split("=", 1)
        config[key.strip()] = _strip_env_value(value)
    return config
|
|
|
|
|
|
def _resolve_provider_type(provider_type: ProviderType | str | None) -> ProviderType | str:
    """Pick the provider to use when the caller may not have named one.

    Resolution order:
        1. ``provider_type`` itself, returned unchanged, when it is not ``None``.
        2. The ``LLM_PROVIDER`` environment variable, cleaned with
           ``_strip_env_value``.
        3. The ``LLM_PROVIDER`` entry returned by ``_read_saved_llm_config()``.
        4. The literal default ``"claude"``.

    Values taken from steps 2 and 3 are lower-cased. A value that is empty
    after cleaning is treated as unset, so resolution moves on to the next
    source instead of returning an empty provider name.

    Args:
        provider_type: Explicit choice from the caller, or ``None``.

    Returns:
        The caller's value unchanged, or a lower-cased provider name string.
    """
    if provider_type is not None:
        return provider_type

    # Clean before testing for emptiness: a variable holding only whitespace
    # or an empty pair of quotes is just as unset as a missing one.
    env_provider = _strip_env_value(os.environ.get("LLM_PROVIDER", ""))
    if env_provider:
        return env_provider.lower()

    # The saved file is only read when the environment did not decide.
    saved_provider = _read_saved_llm_config().get("LLM_PROVIDER", "")
    if saved_provider:
        return saved_provider.lower()

    return "claude"
|
|
|
|
|
|
def get_provider(provider_type: ProviderType | str | None = None, **kwargs: str) -> LLMProvider:
    """Instantiate the provider registered for ``provider_type``.

    Args:
        provider_type: A ``ProviderType`` member, a string naming one, or
            ``None`` to let ``_resolve_provider_type`` choose.
        **kwargs: Forwarded unchanged to the provider class constructor.

    Returns:
        A new instance of the registered provider class.

    Raises:
        ValueError: If a string does not name a ``ProviderType`` member, or
            if no class is registered for the resolved type.
    """
    resolved = _resolve_provider_type(provider_type)

    if isinstance(resolved, str):
        try:
            resolved = ProviderType(resolved)
        except ValueError as exc:
            # Names read from the environment are lower-cased during
            # resolution; give explicitly passed names the same leniency
            # (surrounding whitespace, letter case) before giving up. The
            # exact spelling is tried first, so names that already worked
            # keep resolving the same way.
            normalized = resolved.strip().lower()
            try:
                resolved = ProviderType(normalized)
            except ValueError:
                valid = [p.value for p in ProviderType]
                raise ValueError(
                    f"Unknown provider type: {resolved}. Valid types: {valid}"
                ) from exc

    provider_cls = _PROVIDER_MAP.get(resolved)
    # Compare against None explicitly: the question is whether an entry
    # exists, not whether the registered class happens to be truthy.
    if provider_cls is None:
        raise ValueError(f"No provider registered for type: {resolved}")

    return provider_cls(**kwargs)
|
|
|
|
|
|
def register_provider(provider_type: ProviderType, provider_cls: type[LLMProvider]) -> None:
    """Bind ``provider_cls`` to ``provider_type`` in the module-level registry.

    An existing entry for the same ``provider_type`` is replaced.
    """
    _PROVIDER_MAP.update({provider_type: provider_cls})
|