api_client.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Data Modules - API 客户端 (v5.4,v5.0 OpenAI 兼容接口沿用)
  5. 支持两种 API 类型:
  6. 1. openai: OpenAI 兼容的 /v1/embeddings 和 /v1/rerank 接口
  7. - 适用于: OpenAI, Jina, Cohere, vLLM, Ollama 等
  8. 2. modal: Modal 自定义接口格式
  9. - 适用于: 自部署的 Modal 服务
  10. 配置示例 (config.py):
  11. embed_api_type = "openai"
  12. embed_base_url = "https://api.openai.com/v1"
  13. embed_model = "text-embedding-3-small"
  14. embed_api_key = "sk-xxx"
  15. rerank_api_type = "openai" # Jina/Cohere 也使用此类型
  16. rerank_base_url = "https://api.jina.ai/v1"
  17. rerank_model = "jina-reranker-v2-base-multilingual"
  18. rerank_api_key = "jina_xxx"
  19. """
  20. import asyncio
  21. import aiohttp
  22. import time
  23. from typing import List, Dict, Any, Optional
  24. from dataclasses import dataclass
  25. from .config import get_config
  26. @dataclass
  27. class APIStats:
  28. """API 调用统计"""
  29. total_calls: int = 0
  30. total_time: float = 0.0
  31. errors: int = 0
  32. class EmbeddingAPIClient:
  33. """
  34. 通用 Embedding API 客户端
  35. 支持 OpenAI 兼容接口 (/v1/embeddings) 和 Modal 自定义接口
  36. """
  37. def __init__(self, config=None):
  38. self.config = config or get_config()
  39. self.sem = asyncio.Semaphore(self.config.embed_concurrency)
  40. self.stats = APIStats()
  41. self._warmed_up = False
  42. self._session: Optional[aiohttp.ClientSession] = None
  43. self.last_error_status: Optional[int] = None
  44. self.last_error_message: str = ""
  45. async def _get_session(self) -> aiohttp.ClientSession:
  46. if self._session is None or self._session.closed:
  47. connector = aiohttp.TCPConnector(limit=200, limit_per_host=100)
  48. self._session = aiohttp.ClientSession(connector=connector)
  49. return self._session
  50. async def close(self):
  51. if self._session and not self._session.closed:
  52. await self._session.close()
  53. def _build_headers(self) -> Dict[str, str]:
  54. """构建请求头"""
  55. headers = {"Content-Type": "application/json"}
  56. if self.config.embed_api_key:
  57. headers["Authorization"] = f"Bearer {self.config.embed_api_key}"
  58. return headers
  59. def _build_url(self) -> str:
  60. """构建请求 URL"""
  61. base_url = self.config.embed_base_url.rstrip("/")
  62. if self.config.embed_api_type == "openai":
  63. # OpenAI 兼容: /v1/embeddings
  64. if not base_url.endswith("/embeddings"):
  65. if base_url.endswith("/v1"):
  66. return f"{base_url}/embeddings"
  67. return f"{base_url}/v1/embeddings"
  68. return base_url
  69. else:
  70. # Modal 自定义接口: 直接使用配置的 URL
  71. return base_url
  72. def _build_payload(self, texts: List[str]) -> Dict[str, Any]:
  73. """构建请求体"""
  74. if self.config.embed_api_type == "openai":
  75. return {
  76. "input": texts,
  77. "model": self.config.embed_model,
  78. "encoding_format": "float"
  79. }
  80. else:
  81. # Modal 格式
  82. return {
  83. "input": texts,
  84. "model": self.config.embed_model
  85. }
  86. def _parse_response(self, data: Dict[str, Any]) -> Optional[List[List[float]]]:
  87. """解析响应"""
  88. if self.config.embed_api_type == "openai":
  89. # OpenAI 格式: {"data": [{"embedding": [...], "index": 0}, ...]}
  90. if "data" in data:
  91. # 按 index 排序,确保顺序正确
  92. sorted_data = sorted(data["data"], key=lambda x: x.get("index", 0))
  93. return [item["embedding"] for item in sorted_data]
  94. return None
  95. else:
  96. # Modal 格式: {"data": [{"embedding": [...]}, ...]}
  97. if "data" in data:
  98. return [item["embedding"] for item in data["data"]]
  99. return None
  100. async def embed(self, texts: List[str]) -> Optional[List[List[float]]]:
  101. """调用 Embedding 服务(带重试机制)"""
  102. if not texts:
  103. return []
  104. # 某些 embedding 端点(如 Gemini)拒绝空字符串,用单空格占位保持索引对齐
  105. texts = [t if t else " " for t in texts]
  106. timeout = self.config.cold_start_timeout if not self._warmed_up else self.config.normal_timeout
  107. max_retries = getattr(self.config, 'api_max_retries', 3)
  108. base_delay = getattr(self.config, 'api_retry_delay', 1.0)
  109. async with self.sem:
  110. start = time.time()
  111. session = await self._get_session()
  112. for attempt in range(max_retries):
  113. try:
  114. url = self._build_url()
  115. headers = self._build_headers()
  116. payload = self._build_payload(texts)
  117. async with session.post(
  118. url,
  119. json=payload,
  120. headers=headers,
  121. timeout=aiohttp.ClientTimeout(total=timeout)
  122. ) as resp:
  123. if resp.status == 200:
  124. text = await resp.text()
  125. import json as json_module
  126. data = json_module.loads(text)
  127. embeddings = self._parse_response(data)
  128. if embeddings:
  129. self.stats.total_calls += 1
  130. self.stats.total_time += time.time() - start
  131. self._warmed_up = True
  132. self.last_error_status = None
  133. self.last_error_message = ""
  134. return embeddings
  135. # 可重试的状态码: 429 (限流), 500, 502, 503, 504
  136. if resp.status in (429, 500, 502, 503, 504) and attempt < max_retries - 1:
  137. delay = base_delay * (2 ** attempt) # 指数退避
  138. print(f"[WARN] Embed {resp.status}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
  139. await asyncio.sleep(delay)
  140. continue
  141. self.stats.errors += 1
  142. err_text = await resp.text()
  143. self.last_error_status = int(resp.status)
  144. self.last_error_message = str(err_text[:200])
  145. print(f"[ERR] Embed {resp.status}: {err_text[:200]}")
  146. return None
  147. except asyncio.TimeoutError:
  148. if attempt < max_retries - 1:
  149. delay = base_delay * (2 ** attempt)
  150. print(f"[WARN] Embed timeout, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
  151. await asyncio.sleep(delay)
  152. continue
  153. self.stats.errors += 1
  154. self.last_error_status = None
  155. self.last_error_message = f"Timeout after {max_retries} attempts"
  156. print(f"[ERR] Embed: Timeout after {max_retries} attempts")
  157. return None
  158. except Exception as e:
  159. if attempt < max_retries - 1:
  160. delay = base_delay * (2 ** attempt)
  161. print(f"[WARN] Embed error: {e}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
  162. await asyncio.sleep(delay)
  163. continue
  164. self.stats.errors += 1
  165. self.last_error_status = None
  166. self.last_error_message = str(e)
  167. print(f"[ERR] Embed: {e}")
  168. return None
  169. return None
  170. async def embed_batch(
  171. self, texts: List[str], *, skip_failures: bool = True
  172. ) -> List[Optional[List[float]]]:
  173. """
  174. 分批 Embedding
  175. Args:
  176. texts: 要嵌入的文本列表
  177. skip_failures: True 时失败的文本返回 None;False 时任一失败则整体返回空列表
  178. Returns:
  179. 与 texts 等长的列表,成功的位置是向量,失败的位置是 None
  180. """
  181. if not texts:
  182. return []
  183. all_embeddings: List[Optional[List[float]]] = []
  184. batch_size = self.config.embed_batch_size
  185. batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
  186. tasks = [self.embed(batch) for batch in batches]
  187. results = await asyncio.gather(*tasks)
  188. for batch_idx, result in enumerate(results):
  189. actual_batch_size = len(batches[batch_idx])
  190. if result and len(result) == actual_batch_size:
  191. all_embeddings.extend(result)
  192. else:
  193. if not skip_failures:
  194. print(f"[WARN] Embed batch {batch_idx} failed, aborting all")
  195. return []
  196. print(f"[WARN] Embed batch {batch_idx} failed, marking {actual_batch_size} items as None")
  197. all_embeddings.extend([None] * actual_batch_size)
  198. return all_embeddings[:len(texts)]
  199. async def warmup(self):
  200. """预热服务"""
  201. await self.embed(["test"])
  202. self._warmed_up = True
  203. class RerankAPIClient:
  204. """
  205. 通用 Rerank API 客户端
  206. 支持 OpenAI 兼容接口 (Jina/Cohere 格式) 和 Modal 自定义接口
  207. """
  208. def __init__(self, config=None):
  209. self.config = config or get_config()
  210. self.sem = asyncio.Semaphore(self.config.rerank_concurrency)
  211. self.stats = APIStats()
  212. self._warmed_up = False
  213. self._session: Optional[aiohttp.ClientSession] = None
  214. async def _get_session(self) -> aiohttp.ClientSession:
  215. if self._session is None or self._session.closed:
  216. connector = aiohttp.TCPConnector(limit=200, limit_per_host=100)
  217. self._session = aiohttp.ClientSession(connector=connector)
  218. return self._session
  219. async def close(self):
  220. if self._session and not self._session.closed:
  221. await self._session.close()
  222. def _build_headers(self) -> Dict[str, str]:
  223. """构建请求头"""
  224. headers = {"Content-Type": "application/json"}
  225. if self.config.rerank_api_key:
  226. headers["Authorization"] = f"Bearer {self.config.rerank_api_key}"
  227. return headers
  228. def _build_url(self) -> str:
  229. """构建请求 URL"""
  230. base_url = self.config.rerank_base_url.rstrip("/")
  231. if self.config.rerank_api_type == "openai":
  232. # Jina/Cohere 兼容: /v1/rerank
  233. if not base_url.endswith("/rerank"):
  234. if base_url.endswith("/v1"):
  235. return f"{base_url}/rerank"
  236. return f"{base_url}/v1/rerank"
  237. return base_url
  238. else:
  239. # Modal 自定义接口
  240. return base_url
  241. def _build_payload(self, query: str, documents: List[str], top_n: Optional[int]) -> Dict[str, Any]:
  242. """构建请求体"""
  243. if self.config.rerank_api_type == "openai":
  244. # Jina/Cohere 格式
  245. payload: Dict[str, Any] = {
  246. "query": query,
  247. "documents": documents,
  248. "model": self.config.rerank_model
  249. }
  250. if top_n:
  251. payload["top_n"] = top_n
  252. return payload
  253. else:
  254. # Modal 格式
  255. payload = {"query": query, "documents": documents}
  256. if top_n:
  257. payload["top_n"] = top_n
  258. return payload
  259. def _parse_response(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
  260. """解析响应"""
  261. if self.config.rerank_api_type == "openai":
  262. # Jina/Cohere 格式: {"results": [{"index": 0, "relevance_score": 0.9}, ...]}
  263. return data.get("results", [])
  264. else:
  265. # Modal 格式: {"results": [...]}
  266. return data.get("results", [])
  267. async def rerank(
  268. self,
  269. query: str,
  270. documents: List[str],
  271. top_n: Optional[int] = None
  272. ) -> Optional[List[Dict[str, Any]]]:
  273. """调用 Rerank 服务(带重试机制)"""
  274. if not documents:
  275. return []
  276. timeout = self.config.cold_start_timeout if not self._warmed_up else self.config.normal_timeout
  277. max_retries = getattr(self.config, 'api_max_retries', 3)
  278. base_delay = getattr(self.config, 'api_retry_delay', 1.0)
  279. async with self.sem:
  280. start = time.time()
  281. session = await self._get_session()
  282. for attempt in range(max_retries):
  283. try:
  284. url = self._build_url()
  285. headers = self._build_headers()
  286. payload = self._build_payload(query, documents, top_n)
  287. async with session.post(
  288. url,
  289. json=payload,
  290. headers=headers,
  291. timeout=aiohttp.ClientTimeout(total=timeout)
  292. ) as resp:
  293. if resp.status == 200:
  294. data = await resp.json()
  295. self.stats.total_calls += 1
  296. self.stats.total_time += time.time() - start
  297. self._warmed_up = True
  298. return self._parse_response(data)
  299. # 可重试的状态码
  300. if resp.status in (429, 500, 502, 503, 504) and attempt < max_retries - 1:
  301. delay = base_delay * (2 ** attempt)
  302. print(f"[WARN] Rerank {resp.status}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
  303. await asyncio.sleep(delay)
  304. continue
  305. self.stats.errors += 1
  306. err_text = await resp.text()
  307. print(f"[ERR] Rerank {resp.status}: {err_text[:200]}")
  308. return None
  309. except asyncio.TimeoutError:
  310. if attempt < max_retries - 1:
  311. delay = base_delay * (2 ** attempt)
  312. print(f"[WARN] Rerank timeout, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
  313. await asyncio.sleep(delay)
  314. continue
  315. self.stats.errors += 1
  316. print(f"[ERR] Rerank: Timeout after {max_retries} attempts")
  317. return None
  318. except Exception as e:
  319. if attempt < max_retries - 1:
  320. delay = base_delay * (2 ** attempt)
  321. print(f"[WARN] Rerank error: {e}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
  322. await asyncio.sleep(delay)
  323. continue
  324. self.stats.errors += 1
  325. print(f"[ERR] Rerank: {e}")
  326. return None
  327. return None
  328. async def warmup(self):
  329. """预热服务"""
  330. await self.rerank("test", ["doc1", "doc2"])
  331. self._warmed_up = True
  332. class ModalAPIClient:
  333. """
  334. 统一 API 客户端 (兼容旧接口)
  335. 整合 Embedding + Rerank 客户端,保持向后兼容
  336. """
  337. def __init__(self, config=None):
  338. self.config = config or get_config()
  339. self._embed_client = EmbeddingAPIClient(self.config)
  340. self._rerank_client = RerankAPIClient(self.config)
  341. # 兼容旧代码的信号量
  342. self.sem_embed = self._embed_client.sem
  343. self.sem_rerank = self._rerank_client.sem
  344. self._warmed_up = {"embed": False, "rerank": False}
  345. self._session: Optional[aiohttp.ClientSession] = None
  346. @property
  347. def stats(self) -> Dict[str, APIStats]:
  348. return {
  349. "embed": self._embed_client.stats,
  350. "rerank": self._rerank_client.stats
  351. }
  352. async def _get_session(self) -> aiohttp.ClientSession:
  353. # 复用 embed client 的 session
  354. return await self._embed_client._get_session()
  355. async def close(self):
  356. await self._embed_client.close()
  357. await self._rerank_client.close()
  358. # ==================== 预热 ====================
  359. async def warmup(self):
  360. """预热 Embedding 和 Rerank 服务"""
  361. print("[WARMUP] Warming up Embed + Rerank...")
  362. start = time.time()
  363. tasks = [self._warmup_embed(), self._warmup_rerank()]
  364. results = await asyncio.gather(*tasks, return_exceptions=True)
  365. for name, result in zip(["Embed", "Rerank"], results):
  366. if isinstance(result, Exception):
  367. print(f" [FAIL] {name}: {result}")
  368. else:
  369. print(f" [OK] {name} ready")
  370. print(f"[WARMUP] Done in {time.time() - start:.1f}s")
  371. async def _warmup_embed(self):
  372. await self._embed_client.warmup()
  373. self._warmed_up["embed"] = True
  374. async def _warmup_rerank(self):
  375. await self._rerank_client.warmup()
  376. self._warmed_up["rerank"] = True
  377. # ==================== Embedding API ====================
  378. async def embed(self, texts: List[str]) -> Optional[List[List[float]]]:
  379. """调用 Embedding 服务"""
  380. return await self._embed_client.embed(texts)
  381. async def embed_batch(
  382. self, texts: List[str], *, skip_failures: bool = True
  383. ) -> List[Optional[List[float]]]:
  384. """分批 Embedding"""
  385. return await self._embed_client.embed_batch(texts, skip_failures=skip_failures)
  386. # ==================== Rerank API ====================
  387. async def rerank(
  388. self,
  389. query: str,
  390. documents: List[str],
  391. top_n: Optional[int] = None
  392. ) -> Optional[List[Dict[str, Any]]]:
  393. """调用 Rerank 服务"""
  394. return await self._rerank_client.rerank(query, documents, top_n)
  395. # ==================== 统计 ====================
  396. def print_stats(self):
  397. print("\n[API STATS]")
  398. for name, stats in self.stats.items():
  399. if stats.total_calls > 0:
  400. avg_time = stats.total_time / stats.total_calls
  401. print(f" {name.upper()}: {stats.total_calls} calls, "
  402. f"{stats.total_time:.1f}s total, "
  403. f"{avg_time:.2f}s avg, "
  404. f"{stats.errors} errors")
  405. # 全局客户端
  406. _client: Optional[ModalAPIClient] = None
  407. def get_client(config=None) -> ModalAPIClient:
  408. global _client
  409. if _client is None or config is not None:
  410. _client = ModalAPIClient(config)
  411. return _client