# File: everything-claude-code/tests/test_astraflow_provider.py
# Timestamp: 2026-05-11 23:21:46 -04:00
# Size: 142 lines, 5.2 KiB (Python)

from types import SimpleNamespace
from llm.core.types import LLMInput, Message, ProviderType, Role, ToolDefinition, ToolCall
from llm.providers.astraflow import ASTRAFLOW_BASE_URL, ASTRAFLOW_CN_BASE_URL, AstraflowCNProvider, AstraflowProvider
def _tool() -> ToolDefinition:
    """Build the ``search`` tool definition shared by the tests.

    Returns:
        A ``ToolDefinition`` named ``"search"`` whose JSON-schema style
        ``parameters`` declare a single string property, ``query``.
    """
    return ToolDefinition(
        name="search",
        description="Search",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}},
    )
class _Completions:
    """Test double standing in for a client's ``chat.completions`` object.

    It remembers the keyword arguments of the most recent ``create`` call and
    always answers with the canned response supplied at construction time.
    """

    def __init__(self, response: SimpleNamespace) -> None:
        """Store the canned ``response``; no call has been recorded yet."""
        # ``None`` lets a test tell "never called" apart from "called with no
        # keyword arguments" (which would record an empty dict).
        self.params = None
        self.response = response

    def create(self, **params):
        """Record ``params`` and return the canned response unchanged."""
        self.params = params
        return self.response
class _Client:
    """Minimal fake client exposing ``chat.completions.create``.

    The same ``_Completions`` instance is reachable both as
    ``client.completions`` (convenient for assertions in tests) and as
    ``client.chat.completions`` (the attribute path a provider would call).
    """

    def __init__(self, response: SimpleNamespace) -> None:
        """Wire a ``_Completions`` fake that always returns ``response``."""
        self.completions = _Completions(response)
        self.chat = SimpleNamespace(completions=self.completions)
def _response(**overrides) -> SimpleNamespace:
    """Build a fake chat-completion response object.

    The default response has one choice (message content ``"ok"``, no tool
    calls, finish reason ``"stop"``), model ``"gpt-4o-mini"`` and a usage
    record of 1 prompt / 2 completion / 3 total tokens.

    Args:
        **overrides: Top-level attributes (``choices``, ``model``, ``usage``
            or any extra name) that replace or extend the defaults.

    Returns:
        A ``SimpleNamespace`` carrying the merged attributes.
    """
    message = SimpleNamespace(content="ok", tool_calls=None)
    choice = SimpleNamespace(message=message, finish_reason="stop")
    defaults = {
        "choices": [choice],
        "model": "gpt-4o-mini",
        "usage": SimpleNamespace(prompt_tokens=1, completion_tokens=2, total_tokens=3),
    }
    # Caller-supplied values win over the defaults.
    defaults.update(overrides)
    return SimpleNamespace(**defaults)
def test_astraflow_provider_defaults_to_global_umodelverse_endpoint(monkeypatch):
    """Check the provider's defaults when its environment variables are unset.

    With ``ASTRAFLOW_API_KEY``, ``ASTRAFLOW_BASE_URL`` and ``ASTRAFLOW_MODEL``
    removed, the provider must report the ``ASTRAFLOW`` provider type, the
    ``ASTRAFLOW_BASE_URL`` constant, the ``"gpt-4o-mini"`` default model, and
    an invalid configuration.
    """
    # ``raising=False`` keeps the test independent of the caller's environment.
    monkeypatch.delenv("ASTRAFLOW_API_KEY", raising=False)
    monkeypatch.delenv("ASTRAFLOW_BASE_URL", raising=False)
    monkeypatch.delenv("ASTRAFLOW_MODEL", raising=False)

    provider = AstraflowProvider()

    assert provider.provider_type == ProviderType.ASTRAFLOW
    assert provider.base_url == ASTRAFLOW_BASE_URL
    assert provider.get_default_model() == "gpt-4o-mini"
    # NOTE(review): presumably False because no API key is configured.
    assert provider.validate_config() is False
def test_astraflow_cn_provider_uses_cn_endpoint_and_model_fallback(monkeypatch):
    """Check the CN provider's endpoint and its default-model fallback.

    Only the CN API key is set among the CN-specific variables; the CN model
    and CN base URL are removed. The provider must then use the
    ``ASTRAFLOW_CN_BASE_URL`` constant and report the model configured in the
    global ``ASTRAFLOW_MODEL`` variable as its default.
    """
    monkeypatch.setenv("ASTRAFLOW_API_KEY", "global-key")
    monkeypatch.setenv("ASTRAFLOW_MODEL", "deepseek-ai/DeepSeek-V3-0324")
    monkeypatch.setenv("ASTRAFLOW_CN_API_KEY", "cn-key")
    monkeypatch.delenv("ASTRAFLOW_CN_MODEL", raising=False)
    monkeypatch.delenv("ASTRAFLOW_CN_BASE_URL", raising=False)

    provider = AstraflowCNProvider()

    assert provider.provider_type == ProviderType.ASTRAFLOW_CN
    assert provider.base_url == ASTRAFLOW_CN_BASE_URL
    # No ASTRAFLOW_CN_MODEL is set, so the global model name is expected.
    assert provider.get_default_model() == "deepseek-ai/DeepSeek-V3-0324"
    assert provider.validate_config() is True
def test_astraflow_provider_generates_openai_compatible_chat_completion():
    """Check the request sent to, and the output parsed from, the fake client.

    The provider's ``client`` is replaced with ``_Client`` so no network call
    is made. The test verifies the parsed output (content, model, usage) and
    the keyword arguments recorded by the fake ``create`` call: model,
    ``max_tokens``, the absence of ``temperature`` when the input does not set
    one, and the function-tool payload built from the tool definition.
    """
    provider = AstraflowProvider(api_key="test", default_model="deepseek-ai/DeepSeek-V3-0324")
    client = _Client(_response(model="deepseek-ai/DeepSeek-V3-0324"))
    provider.client = client

    output = provider.generate(
        LLMInput(
            messages=[Message(role=Role.USER, content="hi")],
            max_tokens=128,
            tools=[_tool()],
        )
    )

    assert output.content == "ok"
    assert output.model == "deepseek-ai/DeepSeek-V3-0324"
    assert output.usage == {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}
    assert client.completions.params["model"] == "deepseek-ai/DeepSeek-V3-0324"
    assert client.completions.params["max_tokens"] == 128
    assert "temperature" not in client.completions.params
    assert client.completions.params["tools"] == [
        {
            "type": "function",
            "function": {
                "name": "search",
                "description": "Search",
                "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
                # NOTE(review): the tool built by ``_tool`` does not pass
                # ``strict``; presumably it is a default — confirm.
                "strict": True,
            },
        }
    ]
def test_astraflow_provider_forwards_non_default_temperature():
    """Check that an explicit ``temperature`` reaches the client request.

    The input sets ``temperature=0.2``; the keyword arguments recorded by the
    fake ``create`` call must contain that exact value.
    """
    provider = AstraflowProvider(api_key="test")
    client = _Client(_response())
    provider.client = client

    provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")], temperature=0.2))

    assert client.completions.params["temperature"] == 0.2
def test_astraflow_provider_parses_tool_calls():
    """Check that tool calls in the response become ``ToolCall`` objects.

    The fake response carries one tool call whose ``arguments`` is a JSON
    string and has ``usage=None``. The output must expose the call with its
    arguments decoded into a dict, and ``usage`` must be ``None``.
    """
    provider = AstraflowProvider(api_key="test")
    tool_call = SimpleNamespace(
        id="call_1",
        function=SimpleNamespace(name="search", arguments='{"query":"ucloud"}'),
    )
    message = SimpleNamespace(content="", tool_calls=[tool_call])
    client = _Client(_response(choices=[SimpleNamespace(message=message, finish_reason="tool_calls")], usage=None))
    provider.client = client

    output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))

    assert output.tool_calls == [ToolCall(id="call_1", name="search", arguments={"query": "ucloud"})]
    assert output.usage is None
def test_astraflow_provider_preserves_malformed_tool_arguments():
    """Check that tool-call arguments that are not valid JSON are kept.

    The fake response carries a tool call whose ``arguments`` string is
    ``"{not-json"``. The output must still expose the call, with the original
    text stored under the ``"raw"`` key of its arguments.
    """
    provider = AstraflowProvider(api_key="test")
    tool_call = SimpleNamespace(
        id="call_1",
        function=SimpleNamespace(name="search", arguments="{not-json"),
    )
    message = SimpleNamespace(content="", tool_calls=[tool_call])
    client = _Client(_response(choices=[SimpleNamespace(message=message, finish_reason="tool_calls")]))
    provider.client = client

    output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))

    assert output.tool_calls == [ToolCall(id="call_1", name="search", arguments={"raw": "{not-json"})]