|
|
@@ -114,47 +114,75 @@ class EmbeddingAPIClient:
|
|
|
return None
|
|
|
|
|
|
async def embed(self, texts: List[str]) -> Optional[List[List[float]]]:
|
|
|
- """调用 Embedding 服务"""
|
|
|
+ """调用 Embedding 服务(带重试机制)"""
|
|
|
if not texts:
|
|
|
return []
|
|
|
|
|
|
timeout = self.config.cold_start_timeout if not self._warmed_up else self.config.normal_timeout
|
|
|
+ max_retries = getattr(self.config, 'api_max_retries', 3)
|
|
|
+ base_delay = getattr(self.config, 'api_retry_delay', 1.0)
|
|
|
|
|
|
async with self.sem:
|
|
|
start = time.time()
|
|
|
session = await self._get_session()
|
|
|
|
|
|
- try:
|
|
|
- url = self._build_url()
|
|
|
- headers = self._build_headers()
|
|
|
- payload = self._build_payload(texts)
|
|
|
-
|
|
|
- async with session.post(
|
|
|
- url,
|
|
|
- json=payload,
|
|
|
- headers=headers,
|
|
|
- timeout=aiohttp.ClientTimeout(total=timeout)
|
|
|
- ) as resp:
|
|
|
- if resp.status == 200:
|
|
|
- text = await resp.text()
|
|
|
- import json as json_module
|
|
|
- data = json_module.loads(text)
|
|
|
- embeddings = self._parse_response(data)
|
|
|
-
|
|
|
- if embeddings:
|
|
|
- self.stats.total_calls += 1
|
|
|
- self.stats.total_time += time.time() - start
|
|
|
- return embeddings
|
|
|
+ for attempt in range(max_retries):
|
|
|
+ try:
|
|
|
+ url = self._build_url()
|
|
|
+ headers = self._build_headers()
|
|
|
+ payload = self._build_payload(texts)
|
|
|
+
|
|
|
+ async with session.post(
|
|
|
+ url,
|
|
|
+ json=payload,
|
|
|
+ headers=headers,
|
|
|
+ timeout=aiohttp.ClientTimeout(total=timeout)
|
|
|
+ ) as resp:
|
|
|
+ if resp.status == 200:
|
|
|
+ text = await resp.text()
|
|
|
+ import json as json_module
|
|
|
+ data = json_module.loads(text)
|
|
|
+ embeddings = self._parse_response(data)
|
|
|
+
|
|
|
+ if embeddings:
|
|
|
+ self.stats.total_calls += 1
|
|
|
+ self.stats.total_time += time.time() - start
|
|
|
+ self._warmed_up = True
|
|
|
+ return embeddings
|
|
|
+
|
|
|
+ # 可重试的状态码: 429 (限流), 500, 502, 503, 504
|
|
|
+ if resp.status in (429, 500, 502, 503, 504) and attempt < max_retries - 1:
|
|
|
+ delay = base_delay * (2 ** attempt) # 指数退避
|
|
|
+ print(f"[WARN] Embed {resp.status}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
|
|
|
+ await asyncio.sleep(delay)
|
|
|
+ continue
|
|
|
+
|
|
|
+ self.stats.errors += 1
|
|
|
+ err_text = await resp.text()
|
|
|
+ print(f"[ERR] Embed {resp.status}: {err_text[:200]}")
|
|
|
+ return None
|
|
|
|
|
|
+ except asyncio.TimeoutError:
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ delay = base_delay * (2 ** attempt)
|
|
|
+ print(f"[WARN] Embed timeout, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
|
|
|
+ await asyncio.sleep(delay)
|
|
|
+ continue
|
|
|
self.stats.errors += 1
|
|
|
- err_text = await resp.text()
|
|
|
- print(f"[ERR] Embed {resp.status}: {err_text[:200]}")
|
|
|
+ print(f"[ERR] Embed: Timeout after {max_retries} attempts")
|
|
|
return None
|
|
|
|
|
|
- except Exception as e:
|
|
|
- self.stats.errors += 1
|
|
|
- print(f"[ERR] Embed: {e}")
|
|
|
- return None
|
|
|
+ except Exception as e:
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ delay = base_delay * (2 ** attempt)
|
|
|
+ print(f"[WARN] Embed error: {e}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
|
|
|
+ await asyncio.sleep(delay)
|
|
|
+ continue
|
|
|
+ self.stats.errors += 1
|
|
|
+ print(f"[ERR] Embed: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ return None
|
|
|
|
|
|
async def embed_batch(
|
|
|
self, texts: List[str], *, skip_failures: bool = True
|
|
|
@@ -277,44 +305,72 @@ class RerankAPIClient:
|
|
|
documents: List[str],
|
|
|
top_n: Optional[int] = None
|
|
|
) -> Optional[List[Dict[str, Any]]]:
|
|
|
- """调用 Rerank 服务"""
|
|
|
+ """调用 Rerank 服务(带重试机制)"""
|
|
|
if not documents:
|
|
|
return []
|
|
|
|
|
|
timeout = self.config.cold_start_timeout if not self._warmed_up else self.config.normal_timeout
|
|
|
+ max_retries = getattr(self.config, 'api_max_retries', 3)
|
|
|
+ base_delay = getattr(self.config, 'api_retry_delay', 1.0)
|
|
|
|
|
|
async with self.sem:
|
|
|
start = time.time()
|
|
|
session = await self._get_session()
|
|
|
|
|
|
- try:
|
|
|
- url = self._build_url()
|
|
|
- headers = self._build_headers()
|
|
|
- payload = self._build_payload(query, documents, top_n)
|
|
|
-
|
|
|
- async with session.post(
|
|
|
- url,
|
|
|
- json=payload,
|
|
|
- headers=headers,
|
|
|
- timeout=aiohttp.ClientTimeout(total=timeout)
|
|
|
- ) as resp:
|
|
|
- if resp.status == 200:
|
|
|
- data = await resp.json()
|
|
|
-
|
|
|
- self.stats.total_calls += 1
|
|
|
- self.stats.total_time += time.time() - start
|
|
|
-
|
|
|
- return self._parse_response(data)
|
|
|
- else:
|
|
|
+ for attempt in range(max_retries):
|
|
|
+ try:
|
|
|
+ url = self._build_url()
|
|
|
+ headers = self._build_headers()
|
|
|
+ payload = self._build_payload(query, documents, top_n)
|
|
|
+
|
|
|
+ async with session.post(
|
|
|
+ url,
|
|
|
+ json=payload,
|
|
|
+ headers=headers,
|
|
|
+ timeout=aiohttp.ClientTimeout(total=timeout)
|
|
|
+ ) as resp:
|
|
|
+ if resp.status == 200:
|
|
|
+ data = await resp.json()
|
|
|
+
|
|
|
+ self.stats.total_calls += 1
|
|
|
+ self.stats.total_time += time.time() - start
|
|
|
+ self._warmed_up = True
|
|
|
+
|
|
|
+ return self._parse_response(data)
|
|
|
+
|
|
|
+ # 可重试的状态码
|
|
|
+ if resp.status in (429, 500, 502, 503, 504) and attempt < max_retries - 1:
|
|
|
+ delay = base_delay * (2 ** attempt)
|
|
|
+ print(f"[WARN] Rerank {resp.status}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
|
|
|
+ await asyncio.sleep(delay)
|
|
|
+ continue
|
|
|
+
|
|
|
self.stats.errors += 1
|
|
|
err_text = await resp.text()
|
|
|
print(f"[ERR] Rerank {resp.status}: {err_text[:200]}")
|
|
|
return None
|
|
|
|
|
|
- except Exception as e:
|
|
|
- self.stats.errors += 1
|
|
|
- print(f"[ERR] Rerank: {e}")
|
|
|
- return None
|
|
|
+ except asyncio.TimeoutError:
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ delay = base_delay * (2 ** attempt)
|
|
|
+ print(f"[WARN] Rerank timeout, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
|
|
|
+ await asyncio.sleep(delay)
|
|
|
+ continue
|
|
|
+ self.stats.errors += 1
|
|
|
+ print(f"[ERR] Rerank: Timeout after {max_retries} attempts")
|
|
|
+ return None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ delay = base_delay * (2 ** attempt)
|
|
|
+ print(f"[WARN] Rerank error: {e}, retrying in {delay:.1f}s ({attempt + 1}/{max_retries})")
|
|
|
+ await asyncio.sleep(delay)
|
|
|
+ continue
|
|
|
+ self.stats.errors += 1
|
|
|
+ print(f"[ERR] Rerank: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ return None
|
|
|
|
|
|
async def warmup(self):
|
|
|
"""预热服务"""
|