Przeglądaj źródła

Expose degraded search warning on embedding auth failure

lingfengQAQ 4 miesięcy temu
rodzic
commit
6eaef2b427

+ 10 - 0
.claude/scripts/data_modules/api_client.py

@@ -51,6 +51,8 @@ class EmbeddingAPIClient:
         self.stats = APIStats()
         self._warmed_up = False
         self._session: Optional[aiohttp.ClientSession] = None
+        self.last_error_status: Optional[int] = None
+        self.last_error_message: str = ""
 
     async def _get_session(self) -> aiohttp.ClientSession:
         if self._session is None or self._session.closed:
@@ -148,6 +150,8 @@ class EmbeddingAPIClient:
                                 self.stats.total_calls += 1
                                 self.stats.total_time += time.time() - start
                                 self._warmed_up = True
+                                self.last_error_status = None
+                                self.last_error_message = ""
                                 return embeddings
 
                         # 可重试的状态码: 429 (限流), 500, 502, 503, 504
@@ -159,6 +163,8 @@ class EmbeddingAPIClient:
 
                         self.stats.errors += 1
                         err_text = await resp.text()
+                        self.last_error_status = int(resp.status)
+                        self.last_error_message = str(err_text[:200])
                         print(f"[ERR] Embed {resp.status}: {err_text[:200]}")
                         return None
 
@@ -169,6 +175,8 @@ class EmbeddingAPIClient:
                         await asyncio.sleep(delay)
                         continue
                     self.stats.errors += 1
+                    self.last_error_status = None
+                    self.last_error_message = f"Timeout after {max_retries} attempts"
                     print(f"[ERR] Embed: Timeout after {max_retries} attempts")
                     return None
 
@@ -179,6 +187,8 @@ class EmbeddingAPIClient:
                         await asyncio.sleep(delay)
                         continue
                     self.stats.errors += 1
+                    self.last_error_status = None
+                    self.last_error_message = str(e)
                     print(f"[ERR] Embed: {e}")
                     return None
 

+ 24 - 1
.claude/scripts/data_modules/rag_adapter.py

@@ -56,8 +56,20 @@ class RAGAdapter:
         self.config = config or get_config()
         self.api_client = get_client(config)
         self.index_manager = IndexManager(self.config)
+        self._degraded_mode_reason: Optional[str] = None
         self._init_db()
 
+    @property
+    def degraded_mode_reason(self) -> Optional[str]:  # Reason code recorded by the most recent search, or None when no degradation is recorded.
+        return self._degraded_mode_reason  # Plain pass-through of the private field; no setter is defined in this hunk.
+
+    def _update_degraded_mode(self) -> None:  # Recompute the degraded-mode reason from the embedding client's last recorded error status.
+        self._degraded_mode_reason = None  # Start from "not degraded"; it is only set again below for a recognized failure.
+        embed_client = getattr(self.api_client, "_embed_client", None)  # Tolerates API clients that do not expose an _embed_client attribute.
+        status = getattr(embed_client, "last_error_status", None)  # getattr on a None embed_client also falls back to None.
+        if status == 401:  # Only status 401 is classified; a None status (set on timeouts/exceptions in api_client.py) leaves the reason as None.
+            self._degraded_mode_reason = "embedding_auth_failed"  # Reason string that main() forwards in the DEGRADED_MODE warning.
+
     def _init_db(self):
         """初始化向量数据库"""
         self.config.ensure_dirs()
@@ -425,8 +437,11 @@ class RAGAdapter:
         # 获取查询向量
         query_embeddings = await self.api_client.embed([query])
         if not query_embeddings:
+            self._update_degraded_mode()
             return []
 
+        self._degraded_mode_reason = None
+
         query_embedding = query_embeddings[0]
 
         # 从数据库读取所有向量并计算相似度
@@ -649,7 +664,9 @@ class RAGAdapter:
             )
 
             if not query_embeddings:
+                self._update_degraded_mode()
                 return []
+            self._degraded_mode_reason = None
             query_embedding = query_embeddings[0]
 
             candidate_ids = {r.chunk_id for r in bm25_candidates_results}
@@ -927,7 +944,13 @@ def main():
             results = asyncio.run(adapter.hybrid_search(args.query, args.top_k, args.top_k, args.top_k, chunk_type=args.chunk_type))
 
         payload = [r.__dict__ for r in results]
-        emit_success(payload, message="search_results")
+        degraded_reason = adapter.degraded_mode_reason
+        if degraded_reason:
+            warnings = [{"code": "DEGRADED_MODE", "reason": degraded_reason}]
+            print_success(payload, message="search_results", warnings=warnings)
+            safe_log_tool_call(adapter.index_manager, tool_name=tool_name, success=True)
+        else:
+            emit_success(payload, message="search_results")
 
     else:
         emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")

+ 33 - 0
.claude/scripts/data_modules/tests/test_rag_adapter.py

@@ -35,6 +35,20 @@ class StubClientWithFailures(StubClient):
         return [None, [1.0, 0.0]]
 
 
+class StubEmbedClient401:  # Minimal stand-in for an embedding client whose last request failed with status 401.
+    def __init__(self):  # Only sets the two error fields that EmbeddingAPIClient gains in this commit.
+        self.last_error_status = 401  # The value RAGAdapter._update_degraded_mode compares against.
+        self.last_error_message = "auth failed"
+
+
+class StubClientAuthFailure(StubClient):  # Stub API client simulating an embedding authentication failure.
+    def __init__(self):  # Does not call StubClient.__init__; it only attaches the 401 embed-client stub.
+        self._embed_client = StubEmbedClient401()  # Attribute name matches the getattr lookup in _update_degraded_mode.
+
+    async def embed(self, texts):  # Ignores texts and always signals failure.
+        return None  # A falsy result makes the adapter take its "no query embeddings" branch.
+
+
 @pytest.fixture
 def temp_project(tmp_path, monkeypatch):
     cfg = DataModulesConfig.from_project_root(tmp_path)
@@ -201,3 +215,22 @@ def test_rag_adapter_log_query_failure_is_reported(temp_project, monkeypatch, ca
 
     message_text = "\n".join(record.getMessage() for record in caplog.records)
     assert "failed to log rag query" in message_text
+
+
+def test_rag_adapter_cli_search_shows_degraded_warning(temp_project, monkeypatch, capsys):  # CLI search must still report success but carry a DEGRADED_MODE warning when embedding fails with 401.
+    monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientAuthFailure())  # Every adapter built by main() now gets the failing stub client.
+
+    def run_cli(args):  # Run the module's main() with a patched argv.
+        monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
+        rag_module.main()
+
+    root = str(temp_project.project_root)
+    run_cli(["--project-root", root, "search", "--query", "测试", "--mode", "vector", "--top-k", "3"])  # Vector mode goes through the embed call that the stub fails.
+
+    captured = capsys.readouterr()
+    payload = json.loads(captured.out.strip().splitlines()[-1])  # Only the last stdout line is parsed as the JSON result.
+    assert payload.get("status") == "success"  # Degradation is reported as a warning, not as an error status.
+    warnings = payload.get("warnings") or []  # Treat a missing or null warnings field as an empty list.
+    assert warnings
+    assert warnings[0].get("code") == "DEGRADED_MODE"
+    assert warnings[0].get("reason") == "embedding_auth_failed"