Просмотр исходного кода

Merge pull request #102 from summerfallenleaves/my-changes

fix: support Alibaba DashScope rerank endpoint /v1/reranks
lingfengQAQ 3 недель назад
Родитель
Сommit
bbc0053540

+ 75 - 0
webnovel-writer/scripts/data_modules/api_client.py

@@ -270,10 +270,58 @@ class RerankAPIClient:
             headers["Authorization"] = f"Bearer {self.config.rerank_api_key}"
         return headers
 
+    def _is_dashscope_url(self, base_url: Optional[str] = None) -> bool:
+        """判断当前 Rerank 配置是否指向 DashScope。"""
+        url = base_url if base_url is not None else self.config.rerank_base_url
+        return "dashscope" in str(url or "").lower()
+
+    def _is_dashscope_native_rerank_model(self) -> bool:
+        """DashScope 原生 rerank 接口使用 input/parameters 请求结构。"""
+        model = str(self.config.rerank_model or "").lower()
+        return model in {"qwen3-vl-rerank", "gte-rerank-v2"}
+
+    def _is_dashscope_native_rerank(self) -> bool:
+        return (
+            self.config.rerank_api_type == "openai"
+            and self._is_dashscope_url()
+            and self._is_dashscope_native_rerank_model()
+        )
+
+    def _dashscope_endpoint_root(self, base_url: str) -> str:
+        """从已配置的 DashScope base URL 中剥离已知 rerank 路径。"""
+        base_url_lower = base_url.lower()
+        known_suffixes = [
+            "/api/v1/services/rerank/text-rerank/text-rerank",
+            "/api/v1/services/rerank/text-rerank",
+            "/api/v1/services/rerank",
+            "/api/v1/services",
+            "/compatible-api/v1/reranks",
+            "/compatible-api/v1/rerank",
+            "/compatible-api/v1",
+            "/compatible-api",
+            "/v1/reranks",
+            "/v1/rerank",
+            "/api/v1",
+            "/api",
+            "/v1",
+        ]
+        for suffix in known_suffixes:
+            if base_url_lower.endswith(suffix):
+                return base_url[:-len(suffix)]
+        return base_url
+
+    def _build_dashscope_url(self, base_url: str) -> str:
+        root_url = self._dashscope_endpoint_root(base_url)
+        if self._is_dashscope_native_rerank_model():
+            return f"{root_url}/api/v1/services/rerank/text-rerank/text-rerank"
+        return f"{root_url}/compatible-api/v1/reranks"
+
     def _build_url(self) -> str:
         """构建请求 URL"""
         base_url = self.config.rerank_base_url.rstrip("/")
         if self.config.rerank_api_type == "openai":
+            if self._is_dashscope_url(base_url):
+                return self._build_dashscope_url(base_url)
             # Jina/Cohere 兼容: /v1/rerank
             if not base_url.endswith("/rerank"):
                 if base_url.endswith("/v1"):
@@ -287,6 +335,30 @@ class RerankAPIClient:
     def _build_payload(self, query: str, documents: List[str], top_n: Optional[int]) -> Dict[str, Any]:
         """构建请求体"""
         if self.config.rerank_api_type == "openai":
+            if self._is_dashscope_native_rerank():
+                parameters: Dict[str, Any] = {"return_documents": True}
+                if top_n:
+                    parameters["top_n"] = top_n
+
+                if str(self.config.rerank_model or "").lower() == "qwen3-vl-rerank":
+                    native_query: Any = {"text": query}
+                    native_documents = [
+                        document if isinstance(document, dict) else {"text": document}
+                        for document in documents
+                    ]
+                else:
+                    native_query = query
+                    native_documents = documents
+
+                return {
+                    "model": self.config.rerank_model,
+                    "input": {
+                        "query": native_query,
+                        "documents": native_documents,
+                    },
+                    "parameters": parameters,
+                }
+
             # Jina/Cohere 格式
             payload: Dict[str, Any] = {
                 "query": query,
@@ -306,6 +378,9 @@ class RerankAPIClient:
     def _parse_response(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """解析响应"""
         if self.config.rerank_api_type == "openai":
+            if self._is_dashscope_native_rerank():
+                # DashScope 原生格式: {"output": {"results": [...]}}
+                return data.get("output", {}).get("results", [])
             # Jina/Cohere 格式: {"results": [{"index": 0, "relevance_score": 0.9}, ...]}
             return data.get("results", [])
         else:

+ 99 - 0
webnovel-writer/scripts/data_modules/tests/test_api_client.py

@@ -292,6 +292,105 @@ async def test_embedding_exception_and_close(tmp_path, monkeypatch):
     assert session.closed is True
 
 
+def test_rerank_build_url(tmp_path):
+    config = DataModulesConfig.from_project_root(tmp_path)
+    config.rerank_api_type = "openai"
+    config.rerank_model = "qwen3-rerank"
+
+    # Jina/Cohere: bare URL
+    config.rerank_base_url = "https://api.jina.ai"
+    client = RerankAPIClient(config)
+    assert client._build_url() == "https://api.jina.ai/v1/rerank"
+
+    # Jina/Cohere: URL ending with /v1
+    config.rerank_base_url = "https://api.jina.ai/v1"
+    assert client._build_url() == "https://api.jina.ai/v1/rerank"
+
+    # Jina/Cohere: URL already ending with /rerank
+    config.rerank_base_url = "https://api.jina.ai/v1/rerank"
+    assert client._build_url() == "https://api.jina.ai/v1/rerank"
+
+    # DashScope: bare URL
+    config.rerank_base_url = "https://dashscope.aliyuncs.com"
+    assert client._build_url() == "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
+
+    # DashScope: URL ending with /v1 should be normalized to the official compatible API path
+    config.rerank_base_url = "https://dashscope.aliyuncs.com/v1"
+    assert client._build_url() == "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
+
+    # DashScope: URL ending with /compatible-api
+    config.rerank_base_url = "https://dashscope.aliyuncs.com/compatible-api"
+    assert client._build_url() == "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
+
+    # DashScope: URL ending with /compatible-api/v1
+    config.rerank_base_url = "https://dashscope.aliyuncs.com/compatible-api/v1"
+    assert client._build_url() == "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
+
+    # DashScope: URL already ending with /reranks
+    config.rerank_base_url = "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
+    assert client._build_url() == "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
+
+    # DashScope: URL accidentally ending with singular /rerank
+    config.rerank_base_url = "https://dashscope.aliyuncs.com/compatible-api/v1/rerank"
+    assert client._build_url() == "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
+
+    # DashScope: case-insensitive host detection
+    config.rerank_base_url = "https://DashScope.aliyuncs.com/compatible-api/v1"
+    assert client._build_url() == "https://DashScope.aliyuncs.com/compatible-api/v1/reranks"
+
+    # Modal: passthrough
+    config.rerank_api_type = "modal"
+    config.rerank_base_url = "https://modal.example.com/rerank"
+    assert client._build_url() == "https://modal.example.com/rerank"
+
+
+def test_rerank_dashscope_native_url_payload_and_response(tmp_path):
+    config = DataModulesConfig.from_project_root(tmp_path)
+    config.rerank_api_type = "openai"
+    config.rerank_base_url = "https://dashscope.aliyuncs.com"
+    client = RerankAPIClient(config)
+
+    native_url = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
+
+    config.rerank_model = "gte-rerank-v2"
+    assert client._build_url() == native_url
+
+    config.rerank_base_url = "https://dashscope.aliyuncs.com/api/v1"
+    assert client._build_url() == native_url
+
+    config.rerank_base_url = native_url
+    assert client._build_url() == native_url
+
+    payload = client._build_payload("q", ["doc1", "doc2"], top_n=1)
+    assert payload == {
+        "model": "gte-rerank-v2",
+        "input": {
+            "query": "q",
+            "documents": ["doc1", "doc2"],
+        },
+        "parameters": {
+            "return_documents": True,
+            "top_n": 1,
+        },
+    }
+
+    parsed = client._parse_response({"output": {"results": [{"index": 1, "relevance_score": 0.9}]}})
+    assert parsed == [{"index": 1, "relevance_score": 0.9}]
+
+    config.rerank_model = "qwen3-vl-rerank"
+    config.rerank_base_url = "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
+    assert client._build_url() == native_url
+
+    payload = client._build_payload("q", ["doc", {"image": "https://example.com/a.png"}], top_n=2)
+    assert payload["model"] == "qwen3-vl-rerank"
+    assert payload["input"]["query"] == {"text": "q"}
+    assert payload["input"]["documents"] == [
+        {"text": "doc"},
+        {"image": "https://example.com/a.png"},
+    ]
+    assert payload["parameters"] == {"return_documents": True, "top_n": 2}
+
+
 def test_rerank_headers_payload_and_stats(tmp_path, capsys):
     config = DataModulesConfig.from_project_root(tmp_path)
     config.rerank_api_key = "rk-test"