|
|
@@ -31,6 +31,7 @@ from datetime import datetime
|
|
|
from .config import get_config
|
|
|
from .api_client import get_client
|
|
|
from .index_manager import IndexManager
|
|
|
+from .query_router import QueryRouter
|
|
|
from .observability import safe_log_tool_call
|
|
|
|
|
|
|
|
|
@@ -71,6 +72,7 @@ class RAGAdapter:
|
|
|
self.config = config or get_config()
|
|
|
self.api_client = get_client(config)
|
|
|
self.index_manager = IndexManager(self.config)
|
|
|
+ self.query_router = QueryRouter()
|
|
|
self._degraded_mode_reason: Optional[str] = None
|
|
|
self._init_db()
|
|
|
|
|
|
@@ -257,16 +259,49 @@ class RAGAdapter:
|
|
|
row = cursor.fetchone()
|
|
|
return int(row[0] or 0) if row else 0
|
|
|
|
|
|
- def _get_recent_chunk_ids(self, limit: int, chunk_type: str | None = None) -> List[str]:
|
|
|
+ def _get_recent_chunk_ids(
|
|
|
+ self,
|
|
|
+ limit: int,
|
|
|
+ chunk_type: str | None = None,
|
|
|
+ chapter: int | None = None,
|
|
|
+ ) -> List[str]:
|
|
|
if limit <= 0:
|
|
|
return []
|
|
|
with self._get_conn() as conn:
|
|
|
cursor = conn.cursor()
|
|
|
- if chunk_type:
|
|
|
+ if chunk_type and chapter is not None:
|
|
|
cursor.execute(
|
|
|
- "SELECT chunk_id FROM vectors WHERE chunk_type = ? ORDER BY chapter DESC, scene_index DESC LIMIT ?",
|
|
|
+ """
|
|
|
+ SELECT chunk_id
|
|
|
+ FROM vectors
|
|
|
+ WHERE chunk_type = ? AND chapter <= ?
|
|
|
+ ORDER BY chapter DESC, scene_index DESC
|
|
|
+ LIMIT ?
|
|
|
+ """,
|
|
|
+ (chunk_type, int(chapter), int(limit)),
|
|
|
+ )
|
|
|
+ elif chunk_type:
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ SELECT chunk_id
|
|
|
+ FROM vectors
|
|
|
+ WHERE chunk_type = ?
|
|
|
+ ORDER BY chapter DESC, scene_index DESC
|
|
|
+ LIMIT ?
|
|
|
+ """,
|
|
|
(chunk_type, int(limit)),
|
|
|
)
|
|
|
+ elif chapter is not None:
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ SELECT chunk_id
|
|
|
+ FROM vectors
|
|
|
+ WHERE chapter <= ?
|
|
|
+ ORDER BY chapter DESC, scene_index DESC
|
|
|
+ LIMIT ?
|
|
|
+ """,
|
|
|
+ (int(chapter), int(limit)),
|
|
|
+ )
|
|
|
else:
|
|
|
cursor.execute(
|
|
|
"SELECT chunk_id FROM vectors ORDER BY chapter DESC, scene_index DESC LIMIT ?",
|
|
|
@@ -547,11 +582,29 @@ class RAGAdapter:
|
|
|
# 从数据库读取所有向量并计算相似度
|
|
|
with self._get_conn() as conn:
|
|
|
cursor = conn.cursor()
|
|
|
- if chunk_type:
|
|
|
+ if chunk_type and chapter is not None:
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file
|
|
|
+ FROM vectors
|
|
|
+ WHERE chunk_type = ? AND chapter <= ?
|
|
|
+ """,
|
|
|
+ (chunk_type, int(chapter)),
|
|
|
+ )
|
|
|
+ elif chunk_type:
|
|
|
cursor.execute(
|
|
|
"SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors WHERE chunk_type = ?",
|
|
|
(chunk_type,),
|
|
|
)
|
|
|
+ elif chapter is not None:
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file
|
|
|
+ FROM vectors
|
|
|
+ WHERE chapter <= ?
|
|
|
+ """,
|
|
|
+ (int(chapter),),
|
|
|
+ )
|
|
|
else:
|
|
|
cursor.execute(
|
|
|
"SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors"
|
|
|
@@ -666,7 +719,16 @@ class RAGAdapter:
|
|
|
# 获取文档内容
|
|
|
results = []
|
|
|
for chunk_id, score in doc_scores.items():
|
|
|
- if chunk_type:
|
|
|
+ if chunk_type and chapter is not None:
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
|
|
|
+ FROM vectors
|
|
|
+ WHERE chunk_id = ? AND chunk_type = ? AND chapter <= ?
|
|
|
+ """,
|
|
|
+ (chunk_id, chunk_type, int(chapter)),
|
|
|
+ )
|
|
|
+ elif chunk_type:
|
|
|
cursor.execute(
|
|
|
"""
|
|
|
SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
|
|
|
@@ -675,6 +737,15 @@ class RAGAdapter:
|
|
|
""",
|
|
|
(chunk_id, chunk_type),
|
|
|
)
|
|
|
+ elif chapter is not None:
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
|
|
|
+ FROM vectors
|
|
|
+ WHERE chunk_id = ? AND chapter <= ?
|
|
|
+ """,
|
|
|
+ (chunk_id, int(chapter)),
|
|
|
+ )
|
|
|
else:
|
|
|
cursor.execute(
|
|
|
"""
|
|
|
@@ -705,6 +776,377 @@ class RAGAdapter:
|
|
|
self._log_query(query, "bm25", results, latency_ms, chapter=chapter)
|
|
|
return results
|
|
|
|
|
|
+ def _extract_query_seed_entities(self, query: str) -> List[str]:
|
|
|
+ """从查询中提取种子实体(通过别名和实体 ID 匹配)。"""
|
|
|
+ tokens = set(re.findall(r"[\u4e00-\u9fff]{2,8}|[A-Za-z][A-Za-z0-9_]{1,24}", query))
|
|
|
+ entity_ids: List[str] = []
|
|
|
+ for token in tokens:
|
|
|
+ if len(entity_ids) >= int(self.config.graph_rag_max_expanded_entities):
|
|
|
+ break
|
|
|
+
|
|
|
+ # 1) 通过别名匹配
|
|
|
+ alias_hits = self.index_manager.get_entities_by_alias(token)
|
|
|
+ for hit in alias_hits:
|
|
|
+ entity_id = str(hit.get("id") or "").strip()
|
|
|
+ if entity_id and entity_id not in entity_ids:
|
|
|
+ entity_ids.append(entity_id)
|
|
|
+
|
|
|
+ if len(entity_ids) >= int(self.config.graph_rag_max_expanded_entities):
|
|
|
+ break
|
|
|
+
|
|
|
+ # 2) 通过实体 ID 直匹配
|
|
|
+ entity = self.index_manager.get_entity(token)
|
|
|
+ if entity:
|
|
|
+ entity_id = str(entity.get("id") or "").strip()
|
|
|
+ if entity_id and entity_id not in entity_ids:
|
|
|
+ entity_ids.append(entity_id)
|
|
|
+
|
|
|
+ return entity_ids[: int(self.config.graph_rag_max_expanded_entities)]
|
|
|
+
|
|
|
+ def _normalize_entity_ids(self, candidates: List[str]) -> List[str]:
|
|
|
+ """将输入实体候选(名称/别名/ID)规范化为实体 ID 列表。"""
|
|
|
+ ids: List[str] = []
|
|
|
+ for token in candidates:
|
|
|
+ candidate = str(token or "").strip()
|
|
|
+ if not candidate:
|
|
|
+ continue
|
|
|
+ direct = self.index_manager.get_entity(candidate)
|
|
|
+ if direct and direct.get("id"):
|
|
|
+ entity_id = str(direct.get("id"))
|
|
|
+ if entity_id not in ids:
|
|
|
+ ids.append(entity_id)
|
|
|
+ continue
|
|
|
+
|
|
|
+ for hit in self.index_manager.get_entities_by_alias(candidate):
|
|
|
+ entity_id = str(hit.get("id") or "").strip()
|
|
|
+ if entity_id and entity_id not in ids:
|
|
|
+ ids.append(entity_id)
|
|
|
+ return ids[: int(self.config.graph_rag_max_expanded_entities)]
|
|
|
+
|
|
|
+ def _expand_related_entities(self, seed_entities: List[str], hops: int | None = None) -> List[str]:
|
|
|
+ """基于关系图扩展相关实体。"""
|
|
|
+ max_entities = int(self.config.graph_rag_max_expanded_entities)
|
|
|
+ hops = max(1, int(hops or self.config.graph_rag_expand_hops))
|
|
|
+ expanded: List[str] = []
|
|
|
+ for seed in seed_entities:
|
|
|
+ if seed not in expanded:
|
|
|
+ expanded.append(seed)
|
|
|
+ if len(expanded) >= max_entities:
|
|
|
+ break
|
|
|
+ graph = self.index_manager.build_relationship_subgraph(
|
|
|
+ center_entity=seed,
|
|
|
+ depth=hops,
|
|
|
+ top_edges=max(20, int(self.config.graph_rag_candidate_limit)),
|
|
|
+ )
|
|
|
+ for node in graph.get("nodes", []):
|
|
|
+ entity_id = str(node.get("id") or "").strip()
|
|
|
+ if entity_id and entity_id not in expanded:
|
|
|
+ expanded.append(entity_id)
|
|
|
+ if len(expanded) >= max_entities:
|
|
|
+ break
|
|
|
+ if len(expanded) >= max_entities:
|
|
|
+ break
|
|
|
+ return expanded[:max_entities]
|
|
|
+
|
|
|
+ def _collect_graph_candidate_chunk_ids(
|
|
|
+ self,
|
|
|
+ entity_ids: List[str],
|
|
|
+ *,
|
|
|
+ chapter: int | None = None,
|
|
|
+ limit: int | None = None,
|
|
|
+ ) -> List[str]:
|
|
|
+ """根据实体名称/别名在向量库正文中筛选候选 chunk。"""
|
|
|
+ if not entity_ids:
|
|
|
+ return []
|
|
|
+
|
|
|
+ limit = int(limit or self.config.graph_rag_candidate_limit)
|
|
|
+ entity_terms: Dict[str, set[str]] = {}
|
|
|
+ for entity_id in entity_ids:
|
|
|
+ terms: set[str] = set()
|
|
|
+ entity = self.index_manager.get_entity(entity_id)
|
|
|
+ if entity:
|
|
|
+ canonical_name = str(entity.get("canonical_name") or "").strip()
|
|
|
+ if canonical_name:
|
|
|
+ terms.add(canonical_name)
|
|
|
+ for alias in self.index_manager.get_entity_aliases(entity_id):
|
|
|
+ alias_text = str(alias or "").strip()
|
|
|
+ if alias_text:
|
|
|
+ terms.add(alias_text)
|
|
|
+ if terms:
|
|
|
+ entity_terms[entity_id] = terms
|
|
|
+
|
|
|
+ if not entity_terms:
|
|
|
+ return []
|
|
|
+
|
|
|
+ with self._get_conn() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ if chapter is None:
|
|
|
+ cursor.execute(
|
|
|
+ "SELECT chunk_id, chapter, content FROM vectors ORDER BY chapter DESC, scene_index DESC"
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ SELECT chunk_id, chapter, content
|
|
|
+ FROM vectors
|
|
|
+ WHERE chapter <= ?
|
|
|
+ ORDER BY chapter DESC, scene_index DESC
|
|
|
+ """,
|
|
|
+ (int(chapter),),
|
|
|
+ )
|
|
|
+ rows = cursor.fetchall()
|
|
|
+
|
|
|
+ scored: List[Tuple[str, int, int]] = []
|
|
|
+ for chunk_id, chapter_no, content in rows:
|
|
|
+ text = str(content or "")
|
|
|
+ if not text:
|
|
|
+ continue
|
|
|
+ hit_score = 0
|
|
|
+ for terms in entity_terms.values():
|
|
|
+ hit_score += sum(1 for term in terms if term and term in text)
|
|
|
+ if hit_score > 0:
|
|
|
+ scored.append((str(chunk_id), int(chapter_no or 0), hit_score))
|
|
|
+
|
|
|
+ scored.sort(key=lambda x: (x[2], x[1]), reverse=True)
|
|
|
+ return [chunk_id for chunk_id, _chapter, _score in scored[:limit]]
|
|
|
+
|
|
|
+ async def _vector_search_by_chunk_ids(
|
|
|
+ self,
|
|
|
+ query: str,
|
|
|
+ chunk_ids: List[str],
|
|
|
+ *,
|
|
|
+ top_k: int,
|
|
|
+ chunk_type: str | None = None,
|
|
|
+ ) -> List[SearchResult]:
|
|
|
+ """在指定候选 chunk 范围内执行向量检索。"""
|
|
|
+ if not chunk_ids:
|
|
|
+ return []
|
|
|
+
|
|
|
+ query_embeddings = await self.api_client.embed([query])
|
|
|
+ if not query_embeddings:
|
|
|
+ self._update_degraded_mode()
|
|
|
+ return []
|
|
|
+ self._degraded_mode_reason = None
|
|
|
+
|
|
|
+ query_embedding = query_embeddings[0]
|
|
|
+ rows = await asyncio.to_thread(self._fetch_vectors_by_chunk_ids, chunk_ids)
|
|
|
+ if chunk_type:
|
|
|
+ rows = [r for r in rows if len(r) > 6 and r[6] == chunk_type]
|
|
|
+ return await asyncio.to_thread(
|
|
|
+ self._vector_search_rows,
|
|
|
+ query_embedding,
|
|
|
+ rows,
|
|
|
+ top_k=top_k,
|
|
|
+ )
|
|
|
+
|
|
|
+ def _apply_graph_priors(
|
|
|
+ self,
|
|
|
+ result: SearchResult,
|
|
|
+ *,
|
|
|
+ seed_terms: set[str],
|
|
|
+ related_terms: set[str],
|
|
|
+ max_chapter: int,
|
|
|
+ ) -> float:
|
|
|
+ """为图谱候选增加先验分。"""
|
|
|
+ score = float(result.score)
|
|
|
+ content = str(result.content or "")
|
|
|
+
|
|
|
+ if any(term and term in content for term in seed_terms):
|
|
|
+ score += float(self.config.graph_rag_boost_same_entity)
|
|
|
+ elif any(term and term in content for term in related_terms):
|
|
|
+ score += float(self.config.graph_rag_boost_related_entity)
|
|
|
+
|
|
|
+ if max_chapter > 0 and result.chapter is not None:
|
|
|
+ gap = max(0, max_chapter - int(result.chapter))
|
|
|
+ recency = max(0.0, 1.0 - min(gap, 100) / 100.0)
|
|
|
+ score += recency * float(self.config.graph_rag_boost_recency)
|
|
|
+
|
|
|
+ return score
|
|
|
+
|
|
|
+ async def graph_hybrid_search(
|
|
|
+ self,
|
|
|
+ query: str,
|
|
|
+ top_k: int = 5,
|
|
|
+ *,
|
|
|
+ chunk_type: str | None = None,
|
|
|
+ chapter: int | None = None,
|
|
|
+ center_entities: Optional[List[str]] = None,
|
|
|
+ log_query: bool = True,
|
|
|
+ ) -> List[SearchResult]:
|
|
|
+ """
|
|
|
+ 图谱增强混合检索:
|
|
|
+ 1) 先走现有 hybrid 作为基础召回;
|
|
|
+ 2) 基于实体关系图扩展候选;
|
|
|
+ 3) 向量重算 + 图谱先验融合;
|
|
|
+ 4) rerank 产出最终结果。
|
|
|
+ """
|
|
|
+ start_time = time.perf_counter()
|
|
|
+
|
|
|
+ base_results = await self.hybrid_search(
|
|
|
+ query=query,
|
|
|
+ vector_top_k=max(top_k * 3, int(self.config.vector_top_k)),
|
|
|
+ bm25_top_k=max(top_k * 3, int(self.config.bm25_top_k)),
|
|
|
+ rerank_top_n=max(top_k * 2, int(self.config.rerank_top_n)),
|
|
|
+ chunk_type=chunk_type,
|
|
|
+ chapter=chapter,
|
|
|
+ log_query=False,
|
|
|
+ )
|
|
|
+ if not bool(self.config.graph_rag_enabled):
|
|
|
+ final = list(base_results)[:top_k]
|
|
|
+ if log_query:
|
|
|
+ latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
+ self._log_query(query, "graph_hybrid_fallback", final, latency_ms, chapter=chapter)
|
|
|
+ return final
|
|
|
+
|
|
|
+ seeds = self._normalize_entity_ids([s for s in (center_entities or []) if str(s).strip()])
|
|
|
+ if not seeds:
|
|
|
+ seeds = self._extract_query_seed_entities(query)
|
|
|
+
|
|
|
+ if not seeds:
|
|
|
+ final = list(base_results)[:top_k]
|
|
|
+ if log_query:
|
|
|
+ latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
+ self._log_query(query, "graph_hybrid_no_seed", final, latency_ms, chapter=chapter)
|
|
|
+ return final
|
|
|
+
|
|
|
+ expanded_entities = self._expand_related_entities(seeds)
|
|
|
+ candidate_chunk_ids = self._collect_graph_candidate_chunk_ids(
|
|
|
+ expanded_entities,
|
|
|
+ chapter=chapter,
|
|
|
+ limit=max(top_k * 8, int(self.config.graph_rag_candidate_limit)),
|
|
|
+ )
|
|
|
+
|
|
|
+ graph_vector_results = await self._vector_search_by_chunk_ids(
|
|
|
+ query,
|
|
|
+ candidate_chunk_ids,
|
|
|
+ top_k=max(top_k * 4, int(self.config.rerank_top_n) * 2),
|
|
|
+ chunk_type=chunk_type,
|
|
|
+ )
|
|
|
+
|
|
|
+ # 构建实体术语集用于先验分
|
|
|
+ seed_terms: set[str] = set()
|
|
|
+ related_terms: set[str] = set()
|
|
|
+ for idx, entity_id in enumerate(expanded_entities):
|
|
|
+ entity = self.index_manager.get_entity(entity_id)
|
|
|
+ canonical_name = str((entity or {}).get("canonical_name") or "").strip()
|
|
|
+ aliases = [str(a).strip() for a in self.index_manager.get_entity_aliases(entity_id)]
|
|
|
+ terms = {t for t in [canonical_name, *aliases] if t}
|
|
|
+ if idx < len(seeds):
|
|
|
+ seed_terms.update(terms)
|
|
|
+ else:
|
|
|
+ related_terms.update(terms)
|
|
|
+
|
|
|
+ max_chapter = 0
|
|
|
+ try:
|
|
|
+ max_chapter = int(self.get_stats().get("max_chapter") or 0)
|
|
|
+ except Exception:
|
|
|
+ max_chapter = 0
|
|
|
+ if chapter is not None:
|
|
|
+ try:
|
|
|
+ max_chapter = int(chapter)
|
|
|
+ except (TypeError, ValueError):
|
|
|
+ pass
|
|
|
+
|
|
|
+ merged: Dict[str, SearchResult] = {}
|
|
|
+ for result in base_results:
|
|
|
+ result.source = "graph_hybrid"
|
|
|
+ merged[result.chunk_id] = result
|
|
|
+
|
|
|
+ for result in graph_vector_results:
|
|
|
+ adjusted = self._apply_graph_priors(
|
|
|
+ result,
|
|
|
+ seed_terms=seed_terms,
|
|
|
+ related_terms=related_terms,
|
|
|
+ max_chapter=max_chapter,
|
|
|
+ )
|
|
|
+ result.score = adjusted
|
|
|
+ result.source = "graph_hybrid"
|
|
|
+ existing = merged.get(result.chunk_id)
|
|
|
+ if existing is None or result.score > existing.score:
|
|
|
+ merged[result.chunk_id] = result
|
|
|
+
|
|
|
+ sorted_candidates = sorted(merged.values(), key=lambda r: r.score, reverse=True)
|
|
|
+ candidates = sorted_candidates[: max(top_k * 3, int(self.config.rerank_top_n) * 2)]
|
|
|
+ if not candidates:
|
|
|
+ if log_query:
|
|
|
+ latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
+ self._log_query(query, "graph_hybrid", [], latency_ms, chapter=chapter)
|
|
|
+ return []
|
|
|
+
|
|
|
+ rerank_top_n = max(top_k, int(self.config.rerank_top_n))
|
|
|
+ rerank_input = [c.content for c in candidates]
|
|
|
+ rerank_results = await self.api_client.rerank(query, rerank_input, top_n=rerank_top_n)
|
|
|
+
|
|
|
+ final_results: List[SearchResult] = []
|
|
|
+ if rerank_results:
|
|
|
+ for item in rerank_results:
|
|
|
+ idx = int(item.get("index", 0))
|
|
|
+ if idx < 0 or idx >= len(candidates):
|
|
|
+ continue
|
|
|
+ picked = candidates[idx]
|
|
|
+ picked.score = float(item.get("relevance_score", picked.score))
|
|
|
+ picked.source = "graph_hybrid"
|
|
|
+ final_results.append(picked)
|
|
|
+ else:
|
|
|
+ final_results = candidates[:rerank_top_n]
|
|
|
+
|
|
|
+ final_results = final_results[:top_k]
|
|
|
+ if log_query:
|
|
|
+ latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
+ self._log_query(query, "graph_hybrid", final_results, latency_ms, chapter=chapter)
|
|
|
+ return final_results
|
|
|
+
|
|
|
+ async def search(
|
|
|
+ self,
|
|
|
+ query: str,
|
|
|
+ top_k: int = 5,
|
|
|
+ *,
|
|
|
+ strategy: str = "auto",
|
|
|
+ chunk_type: str | None = None,
|
|
|
+ chapter: int | None = None,
|
|
|
+ center_entities: Optional[List[str]] = None,
|
|
|
+ filters: Optional[Dict[str, Any]] = None,
|
|
|
+ ) -> List[SearchResult]:
|
|
|
+ """统一检索入口。"""
|
|
|
+ strategy = str(strategy or "auto").lower()
|
|
|
+ if filters and chapter is None:
|
|
|
+ try:
|
|
|
+ chapter = int((filters or {}).get("to_chapter") or 0) or None
|
|
|
+ except (TypeError, ValueError):
|
|
|
+ chapter = None
|
|
|
+
|
|
|
+ if strategy == "auto":
|
|
|
+ intent_payload = self.query_router.route_intent(query)
|
|
|
+ if bool(self.config.graph_rag_enabled) and bool(intent_payload.get("needs_graph")):
|
|
|
+ strategy = "graph_hybrid"
|
|
|
+ if not center_entities:
|
|
|
+ center_entities = list(intent_payload.get("entities") or [])
|
|
|
+ else:
|
|
|
+ strategy = "hybrid"
|
|
|
+
|
|
|
+ if strategy == "vector":
|
|
|
+ return await self.vector_search(query, top_k=top_k, chunk_type=chunk_type, chapter=chapter)
|
|
|
+ if strategy == "bm25":
|
|
|
+ return self.bm25_search(query, top_k=top_k, chunk_type=chunk_type, chapter=chapter)
|
|
|
+ if strategy == "backtrack":
|
|
|
+ return await self.search_with_backtrack(query, top_k=top_k)
|
|
|
+ if strategy == "graph_hybrid":
|
|
|
+ return await self.graph_hybrid_search(
|
|
|
+ query=query,
|
|
|
+ top_k=top_k,
|
|
|
+ chunk_type=chunk_type,
|
|
|
+ chapter=chapter,
|
|
|
+ center_entities=center_entities,
|
|
|
+ )
|
|
|
+ return await self.hybrid_search(
|
|
|
+ query=query,
|
|
|
+ vector_top_k=top_k,
|
|
|
+ bm25_top_k=top_k,
|
|
|
+ rerank_top_n=top_k,
|
|
|
+ chunk_type=chunk_type,
|
|
|
+ chapter=chapter,
|
|
|
+ )
|
|
|
+
|
|
|
# ==================== 混合检索 ====================
|
|
|
|
|
|
async def hybrid_search(
|
|
|
@@ -714,6 +1156,7 @@ class RAGAdapter:
|
|
|
bm25_top_k: int = None,
|
|
|
rerank_top_n: int = None,
|
|
|
chunk_type: str | None = None,
|
|
|
+ chapter: int | None = None,
|
|
|
log_query: bool = True,
|
|
|
) -> List[SearchResult]:
|
|
|
"""
|
|
|
@@ -737,8 +1180,8 @@ class RAGAdapter:
|
|
|
if use_full_scan:
|
|
|
# 并行执行向量和 BM25 检索
|
|
|
vector_results, bm25_results = await asyncio.gather(
|
|
|
- self.vector_search(query, vector_top_k, chunk_type=chunk_type, log_query=False),
|
|
|
- asyncio.to_thread(self.bm25_search, query, bm25_top_k, 1.5, 0.75, chunk_type, False),
|
|
|
+ self.vector_search(query, vector_top_k, chunk_type=chunk_type, log_query=False, chapter=chapter),
|
|
|
+ asyncio.to_thread(self.bm25_search, query, bm25_top_k, 1.5, 0.75, chunk_type, False, chapter),
|
|
|
)
|
|
|
else:
|
|
|
bm25_candidates = max(
|
|
|
@@ -753,8 +1196,17 @@ class RAGAdapter:
|
|
|
int(rerank_top_n) * 10,
|
|
|
)
|
|
|
|
|
|
- bm25_task = asyncio.to_thread(self.bm25_search, query, bm25_candidates, 1.5, 0.75, chunk_type, False)
|
|
|
- recent_task = asyncio.to_thread(self._get_recent_chunk_ids, recent_candidates, chunk_type)
|
|
|
+ bm25_task = asyncio.to_thread(
|
|
|
+ self.bm25_search,
|
|
|
+ query,
|
|
|
+ bm25_candidates,
|
|
|
+ 1.5,
|
|
|
+ 0.75,
|
|
|
+ chunk_type,
|
|
|
+ False,
|
|
|
+ chapter,
|
|
|
+ )
|
|
|
+ recent_task = asyncio.to_thread(self._get_recent_chunk_ids, recent_candidates, chunk_type, chapter)
|
|
|
embed_task = self.api_client.embed([query])
|
|
|
|
|
|
bm25_candidates_results, recent_ids, query_embeddings = await asyncio.gather(
|
|
|
@@ -775,6 +1227,8 @@ class RAGAdapter:
|
|
|
rows = await asyncio.to_thread(self._fetch_vectors_by_chunk_ids, list(candidate_ids))
|
|
|
if chunk_type:
|
|
|
rows = [r for r in rows if len(r) > 6 and r[6] == chunk_type]
|
|
|
+ if chapter is not None:
|
|
|
+ rows = [r for r in rows if len(r) > 1 and int(r[1] or 0) <= int(chapter)]
|
|
|
vector_results = await asyncio.to_thread(
|
|
|
self._vector_search_rows,
|
|
|
query_embedding,
|
|
|
@@ -813,7 +1267,7 @@ class RAGAdapter:
|
|
|
final_results: List[SearchResult] = []
|
|
|
latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
if log_query:
|
|
|
- self._log_query(query, "hybrid", final_results, latency_ms)
|
|
|
+ self._log_query(query, "hybrid", final_results, latency_ms, chapter=chapter)
|
|
|
return final_results
|
|
|
|
|
|
# 调用 Rerank API
|
|
|
@@ -825,7 +1279,7 @@ class RAGAdapter:
|
|
|
final_results = [item["result"] for item in sorted_results[:rerank_top_n]]
|
|
|
latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
if log_query:
|
|
|
- self._log_query(query, "hybrid", final_results, latency_ms)
|
|
|
+ self._log_query(query, "hybrid", final_results, latency_ms, chapter=chapter)
|
|
|
return final_results
|
|
|
|
|
|
# 组装最终结果
|
|
|
@@ -840,7 +1294,7 @@ class RAGAdapter:
|
|
|
|
|
|
latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
if log_query:
|
|
|
- self._log_query(query, "hybrid", final_results, latency_ms)
|
|
|
+ self._log_query(query, "hybrid", final_results, latency_ms, chapter=chapter)
|
|
|
return final_results
|
|
|
|
|
|
def _get_chunks_by_ids(self, chunk_ids: List[str]) -> List[SearchResult]:
|
|
|
@@ -951,9 +1405,18 @@ def main():
|
|
|
# 搜索
|
|
|
search_parser = subparsers.add_parser("search")
|
|
|
search_parser.add_argument("--query", required=True)
|
|
|
- search_parser.add_argument("--mode", choices=["vector", "bm25", "hybrid", "backtrack"], default="hybrid")
|
|
|
+ search_parser.add_argument(
|
|
|
+ "--mode",
|
|
|
+ choices=["auto", "vector", "bm25", "hybrid", "graph_hybrid", "backtrack"],
|
|
|
+ default="hybrid",
|
|
|
+ )
|
|
|
search_parser.add_argument("--top-k", type=int, default=5)
|
|
|
search_parser.add_argument("--chunk-type", choices=["scene", "summary"], default=None)
|
|
|
+ search_parser.add_argument(
|
|
|
+ "--center-entities",
|
|
|
+ required=False,
|
|
|
+ help="中心实体列表(JSON 数组或逗号分隔)",
|
|
|
+ )
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
@@ -1034,12 +1497,42 @@ def main():
|
|
|
emit_success(result, message="indexed")
|
|
|
|
|
|
elif args.command == "search":
|
|
|
+ center_entities: List[str] | None = None
|
|
|
+ if getattr(args, "center_entities", None):
|
|
|
+ raw = str(args.center_entities).strip()
|
|
|
+ if raw:
|
|
|
+ try:
|
|
|
+ parsed = json.loads(raw)
|
|
|
+ if isinstance(parsed, list):
|
|
|
+ center_entities = [str(x).strip() for x in parsed if str(x).strip()]
|
|
|
+ except Exception:
|
|
|
+ center_entities = [x.strip() for x in re.split(r"[,,;;\s]+", raw) if x.strip()]
|
|
|
+
|
|
|
if args.mode == "vector":
|
|
|
results = asyncio.run(adapter.vector_search(args.query, args.top_k, chunk_type=args.chunk_type))
|
|
|
elif args.mode == "bm25":
|
|
|
results = adapter.bm25_search(args.query, args.top_k, chunk_type=args.chunk_type)
|
|
|
elif args.mode == "backtrack":
|
|
|
results = asyncio.run(adapter.search_with_backtrack(args.query, args.top_k))
|
|
|
+ elif args.mode == "graph_hybrid":
|
|
|
+ results = asyncio.run(
|
|
|
+ adapter.graph_hybrid_search(
|
|
|
+ args.query,
|
|
|
+ args.top_k,
|
|
|
+ chunk_type=args.chunk_type,
|
|
|
+ center_entities=center_entities,
|
|
|
+ )
|
|
|
+ )
|
|
|
+ elif args.mode == "auto":
|
|
|
+ results = asyncio.run(
|
|
|
+ adapter.search(
|
|
|
+ args.query,
|
|
|
+ args.top_k,
|
|
|
+ strategy="auto",
|
|
|
+ chunk_type=args.chunk_type,
|
|
|
+ center_entities=center_entities,
|
|
|
+ )
|
|
|
+ )
|
|
|
else:
|
|
|
results = asyncio.run(adapter.hybrid_search(args.query, args.top_k, args.top_k, args.top_k, chunk_type=args.chunk_type))
|
|
|
|