api_client.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
"""
Data Modules - API client (v5.0, OpenAI-compatible interface)

Two API types are supported:
  1. openai: OpenAI-compatible /v1/embeddings and /v1/rerank endpoints
     - Suitable for: OpenAI, Jina, Cohere, vLLM, Ollama, etc.
  2. modal: custom Modal endpoint format
     - Suitable for: self-hosted Modal services

Example configuration (config.py):
    embed_api_type = "openai"
    embed_base_url = "https://api.openai.com/v1"
    embed_model = "text-embedding-3-small"
    embed_api_key = "sk-xxx"
    rerank_api_type = "openai"  # Jina/Cohere also use this type
    rerank_base_url = "https://api.jina.ai/v1"
    rerank_model = "jina-reranker-v2-base-multilingual"
    rerank_api_key = "jina_xxx"
"""
  20. import asyncio
  21. import aiohttp
  22. import time
  23. from typing import List, Dict, Any, Optional
  24. from dataclasses import dataclass
  25. from .config import get_config
  26. @dataclass
  27. class APIStats:
  28. """API 调用统计"""
  29. total_calls: int = 0
  30. total_time: float = 0.0
  31. errors: int = 0
  32. class EmbeddingAPIClient:
  33. """
  34. 通用 Embedding API 客户端
  35. 支持 OpenAI 兼容接口 (/v1/embeddings) 和 Modal 自定义接口
  36. """
  37. def __init__(self, config=None):
  38. self.config = config or get_config()
  39. self.sem = asyncio.Semaphore(self.config.embed_concurrency)
  40. self.stats = APIStats()
  41. self._warmed_up = False
  42. self._session: Optional[aiohttp.ClientSession] = None
  43. async def _get_session(self) -> aiohttp.ClientSession:
  44. if self._session is None or self._session.closed:
  45. connector = aiohttp.TCPConnector(limit=200, limit_per_host=100)
  46. self._session = aiohttp.ClientSession(connector=connector)
  47. return self._session
  48. async def close(self):
  49. if self._session and not self._session.closed:
  50. await self._session.close()
  51. def _build_headers(self) -> Dict[str, str]:
  52. """构建请求头"""
  53. headers = {"Content-Type": "application/json"}
  54. if self.config.embed_api_key:
  55. headers["Authorization"] = f"Bearer {self.config.embed_api_key}"
  56. return headers
  57. def _build_url(self) -> str:
  58. """构建请求 URL"""
  59. base_url = self.config.embed_base_url.rstrip("/")
  60. if self.config.embed_api_type == "openai":
  61. # OpenAI 兼容: /v1/embeddings
  62. if not base_url.endswith("/embeddings"):
  63. if base_url.endswith("/v1"):
  64. return f"{base_url}/embeddings"
  65. return f"{base_url}/v1/embeddings"
  66. return base_url
  67. else:
  68. # Modal 自定义接口: 直接使用配置的 URL
  69. return base_url
  70. def _build_payload(self, texts: List[str]) -> Dict[str, Any]:
  71. """构建请求体"""
  72. if self.config.embed_api_type == "openai":
  73. return {
  74. "input": texts,
  75. "model": self.config.embed_model
  76. }
  77. else:
  78. # Modal 格式
  79. return {
  80. "input": texts,
  81. "model": self.config.embed_model
  82. }
  83. def _parse_response(self, data: Dict[str, Any]) -> Optional[List[List[float]]]:
  84. """解析响应"""
  85. if self.config.embed_api_type == "openai":
  86. # OpenAI 格式: {"data": [{"embedding": [...], "index": 0}, ...]}
  87. if "data" in data:
  88. # 按 index 排序,确保顺序正确
  89. sorted_data = sorted(data["data"], key=lambda x: x.get("index", 0))
  90. return [item["embedding"] for item in sorted_data]
  91. return None
  92. else:
  93. # Modal 格式: {"data": [{"embedding": [...]}, ...]}
  94. if "data" in data:
  95. return [item["embedding"] for item in data["data"]]
  96. return None
  97. async def embed(self, texts: List[str]) -> Optional[List[List[float]]]:
  98. """调用 Embedding 服务"""
  99. if not texts:
  100. return []
  101. timeout = self.config.cold_start_timeout if not self._warmed_up else self.config.normal_timeout
  102. async with self.sem:
  103. start = time.time()
  104. session = await self._get_session()
  105. try:
  106. url = self._build_url()
  107. headers = self._build_headers()
  108. payload = self._build_payload(texts)
  109. async with session.post(
  110. url,
  111. json=payload,
  112. headers=headers,
  113. timeout=aiohttp.ClientTimeout(total=timeout)
  114. ) as resp:
  115. if resp.status == 200:
  116. data = await resp.json()
  117. embeddings = self._parse_response(data)
  118. if embeddings:
  119. self.stats.total_calls += 1
  120. self.stats.total_time += time.time() - start
  121. return embeddings
  122. self.stats.errors += 1
  123. print(f"[ERR] Embed {resp.status}: {await resp.text()[:200]}")
  124. return None
  125. except Exception as e:
  126. self.stats.errors += 1
  127. print(f"[ERR] Embed: {e}")
  128. return None
  129. async def embed_batch(
  130. self, texts: List[str], *, skip_failures: bool = True
  131. ) -> List[Optional[List[float]]]:
  132. """
  133. 分批 Embedding
  134. Args:
  135. texts: 要嵌入的文本列表
  136. skip_failures: True 时失败的文本返回 None;False 时任一失败则整体返回空列表
  137. Returns:
  138. 与 texts 等长的列表,成功的位置是向量,失败的位置是 None
  139. """
  140. if not texts:
  141. return []
  142. all_embeddings: List[Optional[List[float]]] = []
  143. batch_size = self.config.embed_batch_size
  144. batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
  145. tasks = [self.embed(batch) for batch in batches]
  146. results = await asyncio.gather(*tasks)
  147. for batch_idx, result in enumerate(results):
  148. actual_batch_size = len(batches[batch_idx])
  149. if result and len(result) == actual_batch_size:
  150. all_embeddings.extend(result)
  151. else:
  152. if not skip_failures:
  153. print(f"[WARN] Embed batch {batch_idx} failed, aborting all")
  154. return []
  155. print(f"[WARN] Embed batch {batch_idx} failed, marking {actual_batch_size} items as None")
  156. all_embeddings.extend([None] * actual_batch_size)
  157. return all_embeddings[:len(texts)]
  158. async def warmup(self):
  159. """预热服务"""
  160. await self.embed(["test"])
  161. self._warmed_up = True
  162. class RerankAPIClient:
  163. """
  164. 通用 Rerank API 客户端
  165. 支持 OpenAI 兼容接口 (Jina/Cohere 格式) 和 Modal 自定义接口
  166. """
  167. def __init__(self, config=None):
  168. self.config = config or get_config()
  169. self.sem = asyncio.Semaphore(self.config.rerank_concurrency)
  170. self.stats = APIStats()
  171. self._warmed_up = False
  172. self._session: Optional[aiohttp.ClientSession] = None
  173. async def _get_session(self) -> aiohttp.ClientSession:
  174. if self._session is None or self._session.closed:
  175. connector = aiohttp.TCPConnector(limit=200, limit_per_host=100)
  176. self._session = aiohttp.ClientSession(connector=connector)
  177. return self._session
  178. async def close(self):
  179. if self._session and not self._session.closed:
  180. await self._session.close()
  181. def _build_headers(self) -> Dict[str, str]:
  182. """构建请求头"""
  183. headers = {"Content-Type": "application/json"}
  184. if self.config.rerank_api_key:
  185. headers["Authorization"] = f"Bearer {self.config.rerank_api_key}"
  186. return headers
  187. def _build_url(self) -> str:
  188. """构建请求 URL"""
  189. base_url = self.config.rerank_base_url.rstrip("/")
  190. if self.config.rerank_api_type == "openai":
  191. # Jina/Cohere 兼容: /v1/rerank
  192. if not base_url.endswith("/rerank"):
  193. if base_url.endswith("/v1"):
  194. return f"{base_url}/rerank"
  195. return f"{base_url}/v1/rerank"
  196. return base_url
  197. else:
  198. # Modal 自定义接口
  199. return base_url
  200. def _build_payload(self, query: str, documents: List[str], top_n: Optional[int]) -> Dict[str, Any]:
  201. """构建请求体"""
  202. if self.config.rerank_api_type == "openai":
  203. # Jina/Cohere 格式
  204. payload: Dict[str, Any] = {
  205. "query": query,
  206. "documents": documents,
  207. "model": self.config.rerank_model
  208. }
  209. if top_n:
  210. payload["top_n"] = top_n
  211. return payload
  212. else:
  213. # Modal 格式
  214. payload = {"query": query, "documents": documents}
  215. if top_n:
  216. payload["top_n"] = top_n
  217. return payload
  218. def _parse_response(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
  219. """解析响应"""
  220. if self.config.rerank_api_type == "openai":
  221. # Jina/Cohere 格式: {"results": [{"index": 0, "relevance_score": 0.9}, ...]}
  222. return data.get("results", [])
  223. else:
  224. # Modal 格式: {"results": [...]}
  225. return data.get("results", [])
  226. async def rerank(
  227. self,
  228. query: str,
  229. documents: List[str],
  230. top_n: Optional[int] = None
  231. ) -> Optional[List[Dict[str, Any]]]:
  232. """调用 Rerank 服务"""
  233. if not documents:
  234. return []
  235. timeout = self.config.cold_start_timeout if not self._warmed_up else self.config.normal_timeout
  236. async with self.sem:
  237. start = time.time()
  238. session = await self._get_session()
  239. try:
  240. url = self._build_url()
  241. headers = self._build_headers()
  242. payload = self._build_payload(query, documents, top_n)
  243. async with session.post(
  244. url,
  245. json=payload,
  246. headers=headers,
  247. timeout=aiohttp.ClientTimeout(total=timeout)
  248. ) as resp:
  249. if resp.status == 200:
  250. data = await resp.json()
  251. self.stats.total_calls += 1
  252. self.stats.total_time += time.time() - start
  253. return self._parse_response(data)
  254. else:
  255. self.stats.errors += 1
  256. print(f"[ERR] Rerank {resp.status}: {await resp.text()[:200]}")
  257. return None
  258. except Exception as e:
  259. self.stats.errors += 1
  260. print(f"[ERR] Rerank: {e}")
  261. return None
  262. async def warmup(self):
  263. """预热服务"""
  264. await self.rerank("test", ["doc1", "doc2"])
  265. self._warmed_up = True
  266. class ModalAPIClient:
  267. """
  268. 统一 API 客户端 (兼容旧接口)
  269. 整合 Embedding + Rerank 客户端,保持向后兼容
  270. """
  271. def __init__(self, config=None):
  272. self.config = config or get_config()
  273. self._embed_client = EmbeddingAPIClient(self.config)
  274. self._rerank_client = RerankAPIClient(self.config)
  275. # 兼容旧代码的信号量
  276. self.sem_embed = self._embed_client.sem
  277. self.sem_rerank = self._rerank_client.sem
  278. self._warmed_up = {"embed": False, "rerank": False}
  279. self._session: Optional[aiohttp.ClientSession] = None
  280. @property
  281. def stats(self) -> Dict[str, APIStats]:
  282. return {
  283. "embed": self._embed_client.stats,
  284. "rerank": self._rerank_client.stats
  285. }
  286. async def _get_session(self) -> aiohttp.ClientSession:
  287. # 复用 embed client 的 session
  288. return await self._embed_client._get_session()
  289. async def close(self):
  290. await self._embed_client.close()
  291. await self._rerank_client.close()
  292. # ==================== 预热 ====================
  293. async def warmup(self):
  294. """预热 Embedding 和 Rerank 服务"""
  295. print("[WARMUP] Warming up Embed + Rerank...")
  296. start = time.time()
  297. tasks = [self._warmup_embed(), self._warmup_rerank()]
  298. results = await asyncio.gather(*tasks, return_exceptions=True)
  299. for name, result in zip(["Embed", "Rerank"], results):
  300. if isinstance(result, Exception):
  301. print(f" [FAIL] {name}: {result}")
  302. else:
  303. print(f" [OK] {name} ready")
  304. print(f"[WARMUP] Done in {time.time() - start:.1f}s")
  305. async def _warmup_embed(self):
  306. await self._embed_client.warmup()
  307. self._warmed_up["embed"] = True
  308. async def _warmup_rerank(self):
  309. await self._rerank_client.warmup()
  310. self._warmed_up["rerank"] = True
  311. # ==================== Embedding API ====================
  312. async def embed(self, texts: List[str]) -> Optional[List[List[float]]]:
  313. """调用 Embedding 服务"""
  314. return await self._embed_client.embed(texts)
  315. async def embed_batch(
  316. self, texts: List[str], *, skip_failures: bool = True
  317. ) -> List[Optional[List[float]]]:
  318. """分批 Embedding"""
  319. return await self._embed_client.embed_batch(texts, skip_failures=skip_failures)
  320. # ==================== Rerank API ====================
  321. async def rerank(
  322. self,
  323. query: str,
  324. documents: List[str],
  325. top_n: Optional[int] = None
  326. ) -> Optional[List[Dict[str, Any]]]:
  327. """调用 Rerank 服务"""
  328. return await self._rerank_client.rerank(query, documents, top_n)
  329. # ==================== 统计 ====================
  330. def print_stats(self):
  331. print("\n[API STATS]")
  332. for name, stats in self.stats.items():
  333. if stats.total_calls > 0:
  334. avg_time = stats.total_time / stats.total_calls
  335. print(f" {name.upper()}: {stats.total_calls} calls, "
  336. f"{stats.total_time:.1f}s total, "
  337. f"{avg_time:.2f}s avg, "
  338. f"{stats.errors} errors")
  339. # 全局客户端
  340. _client: Optional[ModalAPIClient] = None
  341. def get_client(config=None) -> ModalAPIClient:
  342. global _client
  343. if _client is None or config is not None:
  344. _client = ModalAPIClient(config)
  345. return _client