test_api_client.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. API Client tests
  5. """
  6. import asyncio
  7. import json
  8. import pytest
  9. from data_modules.config import DataModulesConfig
  10. from data_modules.api_client import (
  11. EmbeddingAPIClient,
  12. RerankAPIClient,
  13. ModalAPIClient,
  14. get_client,
  15. )
  16. class FakeResponse:
  17. def __init__(self, status, json_data=None, text_data=""):
  18. self.status = status
  19. self._json = json_data
  20. if text_data:
  21. self._text = text_data
  22. elif json_data is not None:
  23. self._text = json.dumps(json_data, ensure_ascii=False)
  24. else:
  25. self._text = ""
  26. async def __aenter__(self):
  27. return self
  28. async def __aexit__(self, exc_type, exc, tb):
  29. return False
  30. async def json(self):
  31. return self._json
  32. async def text(self):
  33. return self._text
  34. class FakeSession:
  35. def __init__(self, responses):
  36. self._responses = list(responses)
  37. self.closed = False
  38. def post(self, *args, **kwargs):
  39. if not self._responses:
  40. raise AssertionError("No more responses")
  41. resp = self._responses.pop(0)
  42. if isinstance(resp, Exception):
  43. raise resp
  44. return resp
  45. async def close(self):
  46. self.closed = True
  47. @pytest.mark.asyncio
  48. async def test_embedding_client_success_and_retry(tmp_path, monkeypatch):
  49. config = DataModulesConfig.from_project_root(tmp_path)
  50. config.embed_api_type = "openai"
  51. config.api_max_retries = 2
  52. client = EmbeddingAPIClient(config)
  53. responses = [
  54. FakeResponse(500, text_data="err"),
  55. FakeResponse(
  56. 200,
  57. json_data={
  58. "data": [
  59. {"embedding": [0.1, 0.2], "index": 1},
  60. {"embedding": [0.3, 0.4], "index": 0},
  61. ]
  62. },
  63. ),
  64. ]
  65. fake_session = FakeSession(responses)
  66. async def fake_get_session():
  67. return fake_session
  68. monkeypatch.setattr(client, "_get_session", fake_get_session)
  69. result = await client.embed(["a", "b"])
  70. assert result == [[0.3, 0.4], [0.1, 0.2]]
  71. assert client.stats.total_calls == 1
  72. assert client.stats.errors == 0
  73. @pytest.mark.asyncio
  74. async def test_embedding_client_timeout_and_error(tmp_path, monkeypatch):
  75. config = DataModulesConfig.from_project_root(tmp_path)
  76. config.embed_api_type = "openai"
  77. config.api_max_retries = 1
  78. client = EmbeddingAPIClient(config)
  79. responses = [asyncio.TimeoutError()]
  80. fake_session = FakeSession(responses)
  81. async def fake_get_session():
  82. return fake_session
  83. monkeypatch.setattr(client, "_get_session", fake_get_session)
  84. result = await client.embed(["x"])
  85. assert result is None
  86. assert client.stats.errors == 1
  87. @pytest.mark.asyncio
  88. async def test_embedding_batch(tmp_path, monkeypatch):
  89. config = DataModulesConfig.from_project_root(tmp_path)
  90. config.embed_batch_size = 2
  91. client = EmbeddingAPIClient(config)
  92. async def fake_embed(texts):
  93. if len(texts) == 2:
  94. return [[1.0, 0.0], [0.0, 1.0]]
  95. return None
  96. monkeypatch.setattr(client, "embed", fake_embed)
  97. result = await client.embed_batch(["a", "b", "c"], skip_failures=True)
  98. assert result[0] is not None
  99. assert result[2] is None
  100. result_fail = await client.embed_batch(["a", "b", "c"], skip_failures=False)
  101. assert result_fail == []
  102. def test_embedding_build_url_and_payload(tmp_path):
  103. config = DataModulesConfig.from_project_root(tmp_path)
  104. config.embed_api_type = "openai"
  105. config.embed_base_url = "https://api.example.com"
  106. client = EmbeddingAPIClient(config)
  107. assert client._build_url().endswith("/v1/embeddings")
  108. payload = client._build_payload(["hi"])
  109. assert payload["model"] == config.embed_model
  110. config.embed_base_url = "https://api.example.com/v1"
  111. assert client._build_url().endswith("/v1/embeddings")
  112. config.embed_base_url = "https://api.example.com/v1/embeddings"
  113. assert client._build_url().endswith("/v1/embeddings")
  114. config.embed_api_type = "modal"
  115. config.embed_base_url = "https://modal.example.com/embed"
  116. assert client._build_url() == "https://modal.example.com/embed"
  117. payload = client._build_payload(["hi"])
  118. assert "encoding_format" not in payload
  119. @pytest.mark.asyncio
  120. async def test_rerank_client_success(tmp_path, monkeypatch):
  121. config = DataModulesConfig.from_project_root(tmp_path)
  122. config.rerank_api_type = "openai"
  123. config.api_max_retries = 1
  124. client = RerankAPIClient(config)
  125. responses = [
  126. FakeResponse(
  127. 200,
  128. json_data={"results": [{"index": 0, "relevance_score": 0.9}]},
  129. )
  130. ]
  131. fake_session = FakeSession(responses)
  132. async def fake_get_session():
  133. return fake_session
  134. monkeypatch.setattr(client, "_get_session", fake_get_session)
  135. result = await client.rerank("q", ["doc1"], top_n=1)
  136. assert result[0]["index"] == 0
  137. assert client.stats.total_calls == 1
  138. @pytest.mark.asyncio
  139. async def test_rerank_retry_and_empty(tmp_path, monkeypatch):
  140. config = DataModulesConfig.from_project_root(tmp_path)
  141. config.rerank_api_type = "openai"
  142. config.api_max_retries = 2
  143. client = RerankAPIClient(config)
  144. responses = [
  145. FakeResponse(503, text_data="err"),
  146. FakeResponse(
  147. 200,
  148. json_data={"results": [{"index": 0, "relevance_score": 0.8}]},
  149. ),
  150. ]
  151. fake_session = FakeSession(responses)
  152. async def fake_get_session():
  153. return fake_session
  154. monkeypatch.setattr(client, "_get_session", fake_get_session)
  155. result = await client.rerank("q", ["doc1"], top_n=1)
  156. assert result[0]["relevance_score"] == 0.8
  157. assert await client.rerank("q", []) == []
  158. @pytest.mark.asyncio
  159. async def test_modal_client_warmup_and_passthrough(tmp_path, monkeypatch):
  160. config = DataModulesConfig.from_project_root(tmp_path)
  161. client = ModalAPIClient(config)
  162. async def fake_warmup():
  163. return None
  164. async def fake_embed(texts):
  165. return [[0.1, 0.2] for _ in texts]
  166. async def fake_rerank(query, documents, top_n=None):
  167. return [{"index": 0, "relevance_score": 1.0}]
  168. monkeypatch.setattr(client._embed_client, "warmup", fake_warmup)
  169. monkeypatch.setattr(client._rerank_client, "warmup", fake_warmup)
  170. monkeypatch.setattr(client._embed_client, "embed", fake_embed)
  171. monkeypatch.setattr(client._rerank_client, "rerank", fake_rerank)
  172. await client.warmup()
  173. assert client._warmed_up["embed"] is True
  174. assert client._warmed_up["rerank"] is True
  175. emb = await client.embed(["hi"])
  176. assert emb[0] == [0.1, 0.2]
  177. rr = await client.rerank("q", ["doc"])
  178. assert rr[0]["index"] == 0
  179. def test_get_client_singleton(tmp_path):
  180. cfg = DataModulesConfig.from_project_root(tmp_path)
  181. client1 = get_client(cfg)
  182. client2 = get_client()
  183. assert client1 is client2
  184. client3 = get_client(cfg)
  185. assert client3 is not client1
  186. @pytest.mark.asyncio
  187. async def test_embedding_empty_and_error_paths(tmp_path, monkeypatch):
  188. config = DataModulesConfig.from_project_root(tmp_path)
  189. config.embed_api_key = "sk-test"
  190. config.api_max_retries = 1
  191. client = EmbeddingAPIClient(config)
  192. assert await client.embed([]) == []
  193. headers = client._build_headers()
  194. assert headers["Authorization"] == "Bearer sk-test"
  195. fake_session = FakeSession([FakeResponse(400, text_data="bad request")])
  196. async def fake_get_session():
  197. return fake_session
  198. monkeypatch.setattr(client, "_get_session", fake_get_session)
  199. result = await client.embed(["x"])
  200. assert result is None
  201. assert client.stats.errors == 1
  202. @pytest.mark.asyncio
  203. async def test_embedding_exception_and_close(tmp_path, monkeypatch):
  204. config = DataModulesConfig.from_project_root(tmp_path)
  205. config.api_max_retries = 1
  206. client = EmbeddingAPIClient(config)
  207. class BoomSession:
  208. def __init__(self):
  209. self.closed = False
  210. def post(self, *args, **kwargs):
  211. raise RuntimeError("boom")
  212. async def close(self):
  213. self.closed = True
  214. session = BoomSession()
  215. async def fake_get_session():
  216. return session
  217. monkeypatch.setattr(client, "_get_session", fake_get_session)
  218. result = await client.embed(["x"])
  219. assert result is None
  220. assert client.stats.errors == 1
  221. client._session = session
  222. await client.close()
  223. assert session.closed is True
  224. def test_rerank_headers_payload_and_stats(tmp_path, capsys):
  225. config = DataModulesConfig.from_project_root(tmp_path)
  226. config.rerank_api_key = "rk-test"
  227. client = RerankAPIClient(config)
  228. headers = client._build_headers()
  229. assert headers["Authorization"] == "Bearer rk-test"
  230. payload = client._build_payload("q", ["doc"], top_n=2)
  231. assert payload["top_n"] == 2
  232. modal = ModalAPIClient(config)
  233. modal._embed_client.stats.total_calls = 1
  234. modal._embed_client.stats.total_time = 2.0
  235. modal.print_stats()
  236. output = capsys.readouterr().out
  237. assert "EMBED" in output
  238. @pytest.mark.asyncio
  239. async def test_rerank_non_retry_error(tmp_path, monkeypatch):
  240. config = DataModulesConfig.from_project_root(tmp_path)
  241. config.api_max_retries = 1
  242. client = RerankAPIClient(config)
  243. fake_session = FakeSession([FakeResponse(400, text_data="bad request")])
  244. async def fake_get_session():
  245. return fake_session
  246. monkeypatch.setattr(client, "_get_session", fake_get_session)
  247. result = await client.rerank("q", ["doc"])
  248. assert result is None
  249. assert client.stats.errors == 1
  250. @pytest.mark.asyncio
  251. async def test_embedding_session_parse_and_retry_paths(tmp_path, monkeypatch):
  252. config = DataModulesConfig.from_project_root(tmp_path)
  253. config.embed_api_type = "modal"
  254. config.api_max_retries = 2
  255. config.api_retry_delay = 0
  256. client = EmbeddingAPIClient(config)
  257. session = await client._get_session()
  258. assert session is not None
  259. await client.close()
  260. assert client._parse_response({}) is None
  261. parsed = client._parse_response({"data": [{"embedding": [1.0, 2.0]}]})
  262. assert parsed == [[1.0, 2.0]]
  263. responses = [
  264. asyncio.TimeoutError(),
  265. FakeResponse(200, text_data=json.dumps({"data": [{"embedding": [0.1], "index": 0}]})),
  266. ]
  267. fake_session = FakeSession(responses)
  268. async def fake_get_session():
  269. return fake_session
  270. monkeypatch.setattr(client, "_get_session", fake_get_session)
  271. result = await client.embed(["x"])
  272. assert result == [[0.1]]
  273. @pytest.mark.asyncio
  274. async def test_embedding_exception_retry_and_batch(tmp_path, monkeypatch):
  275. config = DataModulesConfig.from_project_root(tmp_path)
  276. config.api_max_retries = 2
  277. config.api_retry_delay = 0
  278. client = EmbeddingAPIClient(config)
  279. responses = [
  280. RuntimeError("boom"),
  281. FakeResponse(200, text_data=json.dumps({"data": [{"embedding": [0.2], "index": 0}]})),
  282. ]
  283. fake_session = FakeSession(responses)
  284. async def fake_get_session():
  285. return fake_session
  286. monkeypatch.setattr(client, "_get_session", fake_get_session)
  287. result = await client.embed(["x"])
  288. assert result == [[0.2]]
  289. assert await client.embed_batch([]) == []
  290. async def fake_embed(texts):
  291. return [[0.0] for _ in texts]
  292. monkeypatch.setattr(client, "embed", fake_embed)
  293. await client.warmup()
  294. assert client._warmed_up is True
  295. @pytest.mark.asyncio
  296. async def test_rerank_modal_retry_and_warmup(tmp_path, monkeypatch):
  297. config = DataModulesConfig.from_project_root(tmp_path)
  298. config.rerank_api_type = "modal"
  299. config.rerank_base_url = "https://modal.example.com/rerank"
  300. config.api_max_retries = 2
  301. config.api_retry_delay = 0
  302. client = RerankAPIClient(config)
  303. session = await client._get_session()
  304. assert session is not None
  305. await client.close()
  306. payload = client._build_payload("q", ["doc"], top_n=1)
  307. assert payload["top_n"] == 1
  308. assert client._build_url() == "https://modal.example.com/rerank"
  309. assert client._parse_response({"results": [{"index": 0}]}) == [{"index": 0}]
  310. responses = [
  311. asyncio.TimeoutError(),
  312. FakeResponse(200, json_data={"results": [{"index": 0, "relevance_score": 1.0}]}),
  313. ]
  314. fake_session = FakeSession(responses)
  315. async def fake_get_session():
  316. return fake_session
  317. monkeypatch.setattr(client, "_get_session", fake_get_session)
  318. result = await client.rerank("q", ["doc"])
  319. assert result[0]["index"] == 0
  320. responses = [
  321. RuntimeError("boom"),
  322. FakeResponse(200, json_data={"results": [{"index": 0, "relevance_score": 0.5}]}),
  323. ]
  324. fake_session = FakeSession(responses)
  325. async def fake_get_session2():
  326. return fake_session
  327. monkeypatch.setattr(client, "_get_session", fake_get_session2)
  328. result = await client.rerank("q", ["doc"])
  329. assert result[0]["relevance_score"] == 0.5
  330. async def fake_rerank(query, docs, top_n=None):
  331. return [{"index": 0, "relevance_score": 1.0}]
  332. monkeypatch.setattr(client, "rerank", fake_rerank)
  333. await client.warmup()
  334. assert client._warmed_up is True
  335. @pytest.mark.asyncio
  336. async def test_modal_client_helpers(tmp_path, monkeypatch, capsys):
  337. config = DataModulesConfig.from_project_root(tmp_path)
  338. client = ModalAPIClient(config)
  339. async def fake_embed_batch(texts, skip_failures=True):
  340. return [[0.1] for _ in texts]
  341. monkeypatch.setattr(client._embed_client, "embed_batch", fake_embed_batch)
  342. result = await client.embed_batch(["a", "b"])
  343. assert result[0] == [0.1]
  344. async def fail_warmup():
  345. raise RuntimeError("fail")
  346. async def ok_warmup():
  347. return None
  348. monkeypatch.setattr(client, "_warmup_embed", fail_warmup)
  349. monkeypatch.setattr(client, "_warmup_rerank", ok_warmup)
  350. await client.warmup()
  351. output = capsys.readouterr().out
  352. assert "[FAIL]" in output
  353. async def fake_get_session():
  354. return FakeSession([])
  355. monkeypatch.setattr(client._embed_client, "_get_session", fake_get_session)
  356. session = await client._get_session()
  357. assert session is not None
  358. closed = {"embed": False, "rerank": False}
  359. async def close_embed():
  360. closed["embed"] = True
  361. async def close_rerank():
  362. closed["rerank"] = True
  363. monkeypatch.setattr(client._embed_client, "close", close_embed)
  364. monkeypatch.setattr(client._rerank_client, "close", close_rerank)
  365. await client.close()
  366. assert closed["embed"] and closed["rerank"]