mirror of
https://github.com/affaan-m/everything-claude-code.git
synced 2026-05-13 16:13:03 +08:00
142 lines
5.2 KiB
Python
142 lines
5.2 KiB
Python
from types import SimpleNamespace
|
|
|
|
from llm.core.types import LLMInput, Message, ProviderType, Role, ToolDefinition, ToolCall
|
|
from llm.providers.astraflow import ASTRAFLOW_BASE_URL, ASTRAFLOW_CN_BASE_URL, AstraflowCNProvider, AstraflowProvider
|
|
|
|
|
|
def _tool() -> ToolDefinition:
    """Return a minimal ``search`` tool definition with one string ``query`` property."""
    query_schema = {"type": "string"}
    return ToolDefinition(
        name="search",
        description="Search",
        parameters={"type": "object", "properties": {"query": query_schema}},
    )
|
|
|
|
|
|
class _Completions:
|
|
def __init__(self, response: SimpleNamespace) -> None:
|
|
self.params = None
|
|
self.response = response
|
|
|
|
def create(self, **params):
|
|
self.params = params
|
|
return self.response
|
|
|
|
|
|
class _Client:
    """Fake client exposing ``chat.completions.create`` backed by a ``_Completions`` stub."""

    def __init__(self, response: SimpleNamespace) -> None:
        stub = _Completions(response)
        # Exposed directly so tests can read the recorded request params.
        self.completions = stub
        self.chat = SimpleNamespace(completions=stub)
|
|
|
|
|
|
def _response(**overrides) -> SimpleNamespace:
|
|
message = SimpleNamespace(content="ok", tool_calls=None)
|
|
choice = SimpleNamespace(message=message, finish_reason="stop")
|
|
defaults = {
|
|
"choices": [choice],
|
|
"model": "gpt-4o-mini",
|
|
"usage": SimpleNamespace(prompt_tokens=1, completion_tokens=2, total_tokens=3),
|
|
}
|
|
defaults.update(overrides)
|
|
return SimpleNamespace(**defaults)
|
|
|
|
|
|
def test_astraflow_provider_defaults_to_global_umodelverse_endpoint(monkeypatch):
    """With no ASTRAFLOW_* overrides the provider uses the global defaults and is unconfigured."""
    for var in ("ASTRAFLOW_API_KEY", "ASTRAFLOW_BASE_URL", "ASTRAFLOW_MODEL"):
        monkeypatch.delenv(var, raising=False)

    provider = AstraflowProvider()

    assert provider.provider_type == ProviderType.ASTRAFLOW
    assert provider.base_url == ASTRAFLOW_BASE_URL
    assert provider.get_default_model() == "gpt-4o-mini"
    # No API key is present in the environment, so validation must fail.
    assert provider.validate_config() is False
|
|
|
|
|
|
def test_astraflow_cn_provider_uses_cn_endpoint_and_model_fallback(monkeypatch):
    """The CN provider targets the CN endpoint and falls back to ASTRAFLOW_MODEL."""
    env = {
        "ASTRAFLOW_API_KEY": "global-key",
        "ASTRAFLOW_MODEL": "deepseek-ai/DeepSeek-V3-0324",
        "ASTRAFLOW_CN_API_KEY": "cn-key",
    }
    for name, value in env.items():
        monkeypatch.setenv(name, value)
    # Remove CN-specific overrides so the fallbacks are exercised.
    for name in ("ASTRAFLOW_CN_MODEL", "ASTRAFLOW_CN_BASE_URL"):
        monkeypatch.delenv(name, raising=False)

    provider = AstraflowCNProvider()

    assert provider.provider_type == ProviderType.ASTRAFLOW_CN
    assert provider.base_url == ASTRAFLOW_CN_BASE_URL
    assert provider.get_default_model() == "deepseek-ai/DeepSeek-V3-0324"
    assert provider.validate_config() is True
|
|
|
|
|
|
def test_astraflow_provider_generates_openai_compatible_chat_completion():
    """generate() sends an OpenAI-style request and maps the response back to LLM output."""
    model = "deepseek-ai/DeepSeek-V3-0324"
    provider = AstraflowProvider(api_key="test", default_model=model)
    client = _Client(_response(model=model))
    provider.client = client

    request = LLMInput(
        messages=[Message(role=Role.USER, content="hi")],
        max_tokens=128,
        tools=[_tool()],
    )
    output = provider.generate(request)

    assert output.content == "ok"
    assert output.model == model
    assert output.usage == {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}

    sent = client.completions.params
    assert sent["model"] == model
    assert sent["max_tokens"] == 128
    # No temperature was given on the input, so none is sent.
    assert "temperature" not in sent

    expected_tool = {
        "type": "function",
        "function": {
            "name": "search",
            "description": "Search",
            "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
            "strict": True,
        },
    }
    assert sent["tools"] == [expected_tool]
|
|
|
|
|
|
def test_astraflow_provider_forwards_non_default_temperature():
    """An explicitly set temperature is passed through to the completion request."""
    client = _Client(_response())
    provider = AstraflowProvider(api_key="test")
    provider.client = client

    request = LLMInput(messages=[Message(role=Role.USER, content="hi")], temperature=0.2)
    provider.generate(request)

    sent = client.completions.params
    assert sent["temperature"] == 0.2
|
|
|
|
|
|
def test_astraflow_provider_parses_tool_calls():
    """JSON tool-call arguments are decoded, and a missing usage block yields ``usage is None``."""
    provider = AstraflowProvider(api_key="test")

    call = SimpleNamespace(
        id="call_1",
        function=SimpleNamespace(name="search", arguments='{"query":"ucloud"}'),
    )
    choice = SimpleNamespace(
        message=SimpleNamespace(content="", tool_calls=[call]),
        finish_reason="tool_calls",
    )
    client = _Client(_response(choices=[choice], usage=None))
    provider.client = client

    output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))

    expected = ToolCall(id="call_1", name="search", arguments={"query": "ucloud"})
    assert output.tool_calls == [expected]
    assert output.usage is None
|
|
|
|
|
|
def test_astraflow_provider_preserves_malformed_tool_arguments():
    """Tool-call arguments that are not valid JSON are kept verbatim under the ``raw`` key."""
    provider = AstraflowProvider(api_key="test")

    bad_args = "{not-json"
    call = SimpleNamespace(
        id="call_1",
        function=SimpleNamespace(name="search", arguments=bad_args),
    )
    choice = SimpleNamespace(
        message=SimpleNamespace(content="", tool_calls=[call]),
        finish_reason="tool_calls",
    )
    client = _Client(_response(choices=[choice]))
    provider.client = client

    output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))

    expected = ToolCall(id="call_1", name="search", arguments={"raw": bad_args})
    assert output.tool_calls == [expected]
|