| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Data Modules - API 客户端 (v5.4,v5.0 OpenAI 兼容接口沿用)
- 支持两种 API 类型:
- 1. openai: OpenAI 兼容的 /v1/embeddings 和 /v1/rerank 接口
- - 适用于: OpenAI, Jina, Cohere, vLLM, Ollama 等
- 2. modal: Modal 自定义接口格式
- - 适用于: 自部署的 Modal 服务
- 配置示例 (config.py):
- embed_api_type = "openai"
- embed_base_url = "https://api.openai.com/v1"
- embed_model = "text-embedding-3-small"
- embed_api_key = "sk-xxx"
- rerank_api_type = "openai" # Jina/Cohere 也使用此类型
- rerank_base_url = "https://api.jina.ai/v1"
- rerank_model = "jina-reranker-v2-base-multilingual"
- rerank_api_key = "jina_xxx"
- """
- import asyncio
- import aiohttp
- import time
- from typing import List, Dict, Any, Optional
- from dataclasses import dataclass
- from .config import get_config
- @dataclass
- class APIStats:
- """API 调用统计"""
- total_calls: int = 0
- total_time: float = 0.0
- errors: int = 0
- class EmbeddingAPIClient:
- """
- 通用 Embedding API 客户端
- 支持 OpenAI 兼容接口 (/v1/embeddings) 和 Modal 自定义接口
- """
- def __init__(self, config=None):
- self.config = config or get_config()
- self.sem = asyncio.Semaphore(self.config.embed_concurrency)
- self.stats = APIStats()
- self._warmed_up = False
- self._session: Optional[aiohttp.ClientSession] = None
- self.last_error_status: Optional[int] = None
- self.last_error_message: str = ""
- async def _get_session(self) -> aiohttp.ClientSession:
- if self._session is None or self._session.closed:
- connector = aiohttp.TCPConnector(limit=200, limit_per_host=100)
- self._session = aiohttp.ClientSession(connector=connector)
- return self._session
- async def close(self):
- if self._session and not self._session.closed:
- await self._session.close()
- def _build_headers(self) -> Dict[str, str]:
- """构建请求头"""
- headers = {"Content-Type": "application/json"}
- if self.config.embed_api_key:
- headers["Authorization"] = f"Bearer {self.config.embed_api_key}"
- return headers
- def _build_url(self) -> str:
- """构建请求 URL"""
- base_url = self.config.embed_base_url.rstrip("/")
- if self.config.embed_api_type == "openai":
- # OpenAI 兼容: /v1/embeddings
- if not base_url.endswith("/embeddings"):
- if base_url.endswith("/v1"):
- return f"{base_url}/embeddings"
- return f"{base_url}/v1/embeddings"
- return base_url
- else:
- # Modal 自定义接口: 直接使用配置的 URL
- return base_url
- def _build_payload(self, texts: List[str]) -> Dict[str, Any]:
- """构建请求体"""
- if self.config.embed_api_type == "openai":
- return {
- "input": texts,
- "model": self.config.embed_model,
- "encoding_format": "float"
- }
- else:
- # Modal 格式
- return {
- "input": texts,
- "model": self.config.embed_model
- }
- def _parse_response(self, data: Dict[str, Any]) -> Optional[List[List[float]]]:
- """解析响应"""
- if self.config.embed_api_type == "openai":
- # OpenAI 格式: {"data": [{"embedding": [...], "index": 0}, ...]}
- if "data" in data:
- # 按 index 排序,确保顺序正确
- sorted_data = sorted(data["data"], key=lambda x: x.get("index", 0))
- return [item["embedding"] for item in sorted_data]
- return None
- else:
- # Modal 格式: {"data": [{"embedding": [...]}, ...]}
- if "data" in data:
- return [item["embedding"] for item in data["data"]]
- return None
- async def embed(self, texts: List[str]) -> Optional[List[List[float]]]:
- """调用 Embedding 服务(带重试机制)"""
- if not texts:
- return []
- timeout = self.config.cold_start_timeout if not self._warmed_up else self.config.normal_timeout
- max_retries = getattr(self.config, 'api_max_retries', 3)
- base_delay = getattr(self.config, 'api_retry_delay', 1.0)
- async with self.sem:
- start = time.time()
- session = await self._get_session()
- for attempt in range(max_retries):
- try:
- url = self._build_url()
- headers = self._build_headers()
- payload = self._build_payload(texts)
- async with session.post(
- url,
- json=payload,
- headers=headers,
- timeout=aiohttp.ClientTimeout(total=timeout)
- ) as resp:
- if resp.status == 200:
- text = await resp.text()
- import json as json_module
- data = json_module.loads(text)
- embeddings = self._parse_response(data)
- if embeddings:
- self.stats.total_calls += 1
- self.stats.total_time += time.time() - start
- self._warmed_up = True
- self.last_error_status = None
- self.last_error_message = ""
- return embeddings
- # 可重试的状态码: 429 (限流), 500, 502, 503, 504
- if resp.status in (429, 500, 502, 503, 504) and attempt < max_retries - 1:
- delay = base_delay * (2 ** attempt) # 指数退避
- print(f"[WARN] Embed {resp.status}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
- await asyncio.sleep(delay)
- continue
- self.stats.errors += 1
- err_text = await resp.text()
- self.last_error_status = int(resp.status)
- self.last_error_message = str(err_text[:200])
- print(f"[ERR] Embed {resp.status}: {err_text[:200]}")
- return None
- except asyncio.TimeoutError:
- if attempt < max_retries - 1:
- delay = base_delay * (2 ** attempt)
- print(f"[WARN] Embed timeout, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
- await asyncio.sleep(delay)
- continue
- self.stats.errors += 1
- self.last_error_status = None
- self.last_error_message = f"Timeout after {max_retries} attempts"
- print(f"[ERR] Embed: Timeout after {max_retries} attempts")
- return None
- except Exception as e:
- if attempt < max_retries - 1:
- delay = base_delay * (2 ** attempt)
- print(f"[WARN] Embed error: {e}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
- await asyncio.sleep(delay)
- continue
- self.stats.errors += 1
- self.last_error_status = None
- self.last_error_message = str(e)
- print(f"[ERR] Embed: {e}")
- return None
- return None
- async def embed_batch(
- self, texts: List[str], *, skip_failures: bool = True
- ) -> List[Optional[List[float]]]:
- """
- 分批 Embedding
- Args:
- texts: 要嵌入的文本列表
- skip_failures: True 时失败的文本返回 None;False 时任一失败则整体返回空列表
- Returns:
- 与 texts 等长的列表,成功的位置是向量,失败的位置是 None
- """
- if not texts:
- return []
- all_embeddings: List[Optional[List[float]]] = []
- batch_size = self.config.embed_batch_size
- batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
- tasks = [self.embed(batch) for batch in batches]
- results = await asyncio.gather(*tasks)
- for batch_idx, result in enumerate(results):
- actual_batch_size = len(batches[batch_idx])
- if result and len(result) == actual_batch_size:
- all_embeddings.extend(result)
- else:
- if not skip_failures:
- print(f"[WARN] Embed batch {batch_idx} failed, aborting all")
- return []
- print(f"[WARN] Embed batch {batch_idx} failed, marking {actual_batch_size} items as None")
- all_embeddings.extend([None] * actual_batch_size)
- return all_embeddings[:len(texts)]
- async def warmup(self):
- """预热服务"""
- await self.embed(["test"])
- self._warmed_up = True
- class RerankAPIClient:
- """
- 通用 Rerank API 客户端
- 支持 OpenAI 兼容接口 (Jina/Cohere 格式) 和 Modal 自定义接口
- """
- def __init__(self, config=None):
- self.config = config or get_config()
- self.sem = asyncio.Semaphore(self.config.rerank_concurrency)
- self.stats = APIStats()
- self._warmed_up = False
- self._session: Optional[aiohttp.ClientSession] = None
- async def _get_session(self) -> aiohttp.ClientSession:
- if self._session is None or self._session.closed:
- connector = aiohttp.TCPConnector(limit=200, limit_per_host=100)
- self._session = aiohttp.ClientSession(connector=connector)
- return self._session
- async def close(self):
- if self._session and not self._session.closed:
- await self._session.close()
- def _build_headers(self) -> Dict[str, str]:
- """构建请求头"""
- headers = {"Content-Type": "application/json"}
- if self.config.rerank_api_key:
- headers["Authorization"] = f"Bearer {self.config.rerank_api_key}"
- return headers
- def _build_url(self) -> str:
- """构建请求 URL"""
- base_url = self.config.rerank_base_url.rstrip("/")
- if self.config.rerank_api_type == "openai":
- # Jina/Cohere 兼容: /v1/rerank
- if not base_url.endswith("/rerank"):
- if base_url.endswith("/v1"):
- return f"{base_url}/rerank"
- return f"{base_url}/v1/rerank"
- return base_url
- else:
- # Modal 自定义接口
- return base_url
- def _build_payload(self, query: str, documents: List[str], top_n: Optional[int]) -> Dict[str, Any]:
- """构建请求体"""
- if self.config.rerank_api_type == "openai":
- # Jina/Cohere 格式
- payload: Dict[str, Any] = {
- "query": query,
- "documents": documents,
- "model": self.config.rerank_model
- }
- if top_n:
- payload["top_n"] = top_n
- return payload
- else:
- # Modal 格式
- payload = {"query": query, "documents": documents}
- if top_n:
- payload["top_n"] = top_n
- return payload
- def _parse_response(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
- """解析响应"""
- if self.config.rerank_api_type == "openai":
- # Jina/Cohere 格式: {"results": [{"index": 0, "relevance_score": 0.9}, ...]}
- return data.get("results", [])
- else:
- # Modal 格式: {"results": [...]}
- return data.get("results", [])
- async def rerank(
- self,
- query: str,
- documents: List[str],
- top_n: Optional[int] = None
- ) -> Optional[List[Dict[str, Any]]]:
- """调用 Rerank 服务(带重试机制)"""
- if not documents:
- return []
- timeout = self.config.cold_start_timeout if not self._warmed_up else self.config.normal_timeout
- max_retries = getattr(self.config, 'api_max_retries', 3)
- base_delay = getattr(self.config, 'api_retry_delay', 1.0)
- async with self.sem:
- start = time.time()
- session = await self._get_session()
- for attempt in range(max_retries):
- try:
- url = self._build_url()
- headers = self._build_headers()
- payload = self._build_payload(query, documents, top_n)
- async with session.post(
- url,
- json=payload,
- headers=headers,
- timeout=aiohttp.ClientTimeout(total=timeout)
- ) as resp:
- if resp.status == 200:
- data = await resp.json()
- self.stats.total_calls += 1
- self.stats.total_time += time.time() - start
- self._warmed_up = True
- return self._parse_response(data)
- # 可重试的状态码
- if resp.status in (429, 500, 502, 503, 504) and attempt < max_retries - 1:
- delay = base_delay * (2 ** attempt)
- print(f"[WARN] Rerank {resp.status}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
- await asyncio.sleep(delay)
- continue
- self.stats.errors += 1
- err_text = await resp.text()
- print(f"[ERR] Rerank {resp.status}: {err_text[:200]}")
- return None
- except asyncio.TimeoutError:
- if attempt < max_retries - 1:
- delay = base_delay * (2 ** attempt)
- print(f"[WARN] Rerank timeout, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
- await asyncio.sleep(delay)
- continue
- self.stats.errors += 1
- print(f"[ERR] Rerank: Timeout after {max_retries} attempts")
- return None
- except Exception as e:
- if attempt < max_retries - 1:
- delay = base_delay * (2 ** attempt)
- print(f"[WARN] Rerank error: {e}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
- await asyncio.sleep(delay)
- continue
- self.stats.errors += 1
- print(f"[ERR] Rerank: {e}")
- return None
- return None
- async def warmup(self):
- """预热服务"""
- await self.rerank("test", ["doc1", "doc2"])
- self._warmed_up = True
- class ModalAPIClient:
- """
- 统一 API 客户端 (兼容旧接口)
- 整合 Embedding + Rerank 客户端,保持向后兼容
- """
- def __init__(self, config=None):
- self.config = config or get_config()
- self._embed_client = EmbeddingAPIClient(self.config)
- self._rerank_client = RerankAPIClient(self.config)
- # 兼容旧代码的信号量
- self.sem_embed = self._embed_client.sem
- self.sem_rerank = self._rerank_client.sem
- self._warmed_up = {"embed": False, "rerank": False}
- self._session: Optional[aiohttp.ClientSession] = None
- @property
- def stats(self) -> Dict[str, APIStats]:
- return {
- "embed": self._embed_client.stats,
- "rerank": self._rerank_client.stats
- }
- async def _get_session(self) -> aiohttp.ClientSession:
- # 复用 embed client 的 session
- return await self._embed_client._get_session()
- async def close(self):
- await self._embed_client.close()
- await self._rerank_client.close()
- # ==================== 预热 ====================
- async def warmup(self):
- """预热 Embedding 和 Rerank 服务"""
- print("[WARMUP] Warming up Embed + Rerank...")
- start = time.time()
- tasks = [self._warmup_embed(), self._warmup_rerank()]
- results = await asyncio.gather(*tasks, return_exceptions=True)
- for name, result in zip(["Embed", "Rerank"], results):
- if isinstance(result, Exception):
- print(f" [FAIL] {name}: {result}")
- else:
- print(f" [OK] {name} ready")
- print(f"[WARMUP] Done in {time.time() - start:.1f}s")
- async def _warmup_embed(self):
- await self._embed_client.warmup()
- self._warmed_up["embed"] = True
- async def _warmup_rerank(self):
- await self._rerank_client.warmup()
- self._warmed_up["rerank"] = True
- # ==================== Embedding API ====================
- async def embed(self, texts: List[str]) -> Optional[List[List[float]]]:
- """调用 Embedding 服务"""
- return await self._embed_client.embed(texts)
- async def embed_batch(
- self, texts: List[str], *, skip_failures: bool = True
- ) -> List[Optional[List[float]]]:
- """分批 Embedding"""
- return await self._embed_client.embed_batch(texts, skip_failures=skip_failures)
- # ==================== Rerank API ====================
- async def rerank(
- self,
- query: str,
- documents: List[str],
- top_n: Optional[int] = None
- ) -> Optional[List[Dict[str, Any]]]:
- """调用 Rerank 服务"""
- return await self._rerank_client.rerank(query, documents, top_n)
- # ==================== 统计 ====================
- def print_stats(self):
- print("\n[API STATS]")
- for name, stats in self.stats.items():
- if stats.total_calls > 0:
- avg_time = stats.total_time / stats.total_calls
- print(f" {name.upper()}: {stats.total_calls} calls, "
- f"{stats.total_time:.1f}s total, "
- f"{avg_time:.2f}s avg, "
- f"{stats.errors} errors")
- # 全局客户端
- _client: Optional[ModalAPIClient] = None
- def get_client(config=None) -> ModalAPIClient:
- global _client
- if _client is None or config is not None:
- _client = ModalAPIClient(config)
- return _client
|