rag_adapter.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. RAG Adapter - RAG 检索适配模块
  5. 封装向量检索功能:
  6. - 向量嵌入 (调用 Modal API)
  7. - 语义搜索
  8. - 重排序
  9. - 混合检索 (向量 + BM25)
  10. """
  11. import asyncio
  12. import sqlite3
  13. import json
  14. import math
  15. import logging
  16. from pathlib import Path
  17. from runtime_compat import enable_windows_utf8_stdio
  18. from typing import Dict, List, Optional, Any, Tuple
  19. from dataclasses import dataclass
  20. from collections import Counter
  21. import re
  22. from contextlib import contextmanager
  23. import itertools
  24. import time
  25. from .config import get_config
  26. from .api_client import get_client
  27. from .index_manager import IndexManager
  28. from .observability import safe_log_tool_call
  29. logger = logging.getLogger(__name__)
  30. @dataclass
  31. class SearchResult:
  32. """搜索结果"""
  33. chunk_id: str
  34. chapter: int
  35. scene_index: int
  36. content: str
  37. score: float
  38. source: str # "vector" | "bm25" | "hybrid"
  39. parent_chunk_id: str | None = None
  40. chunk_type: str | None = None
  41. source_file: str | None = None
  42. class RAGAdapter:
  43. """RAG 检索适配器"""
    def __init__(self, config=None):
        """Create the adapter and make sure the SQLite schema exists.

        Args:
            config: Project configuration object. When falsy, the value
                returned by ``get_config()`` is used instead.
        """
        self.config = config or get_config()
        # NOTE(review): the API client receives the raw ``config`` argument
        # (possibly None), not the resolved ``self.config`` — confirm that
        # ``get_client`` applies the same fallback.
        self.api_client = get_client(config)
        self.index_manager = IndexManager(self.config)
        # Maintained by _update_degraded_mode(); None means no known degradation.
        self._degraded_mode_reason: Optional[str] = None
        self._init_db()
    @property
    def degraded_mode_reason(self) -> Optional[str]:
        """Return the current degraded-mode reason, or None when there is none."""
        return self._degraded_mode_reason
  53. def _update_degraded_mode(self) -> None:
  54. self._degraded_mode_reason = None
  55. embed_client = getattr(self.api_client, "_embed_client", None)
  56. status = getattr(embed_client, "last_error_status", None)
  57. if status == 401:
  58. self._degraded_mode_reason = "embedding_auth_failed"
  59. def _init_db(self):
  60. """初始化向量数据库"""
  61. self.config.ensure_dirs()
  62. with self._get_conn() as conn:
  63. cursor = conn.cursor()
  64. def _table_columns(table_name: str) -> set[str]:
  65. cursor.execute(f"PRAGMA table_info({table_name})")
  66. return {row[1] for row in cursor.fetchall()}
  67. required_cols = {
  68. "chunk_id",
  69. "chapter",
  70. "scene_index",
  71. "content",
  72. "embedding",
  73. "parent_chunk_id",
  74. "chunk_type",
  75. "source_file",
  76. "created_at",
  77. }
  78. if "vectors" in {r[0] for r in cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")}: # type: ignore
  79. cols = _table_columns("vectors")
  80. if not required_cols.issubset(cols):
  81. cursor.execute("DROP TABLE IF EXISTS vectors")
  82. cursor.execute("DROP TABLE IF EXISTS bm25_index")
  83. cursor.execute("DROP TABLE IF EXISTS doc_stats")
  84. # 向量存储表
  85. cursor.execute("""
  86. CREATE TABLE IF NOT EXISTS vectors (
  87. chunk_id TEXT PRIMARY KEY,
  88. chapter INTEGER,
  89. scene_index INTEGER,
  90. content TEXT,
  91. embedding BLOB,
  92. parent_chunk_id TEXT,
  93. chunk_type TEXT DEFAULT 'scene',
  94. source_file TEXT,
  95. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  96. )
  97. """)
  98. # BM25 倒排索引表
  99. cursor.execute("""
  100. CREATE TABLE IF NOT EXISTS bm25_index (
  101. term TEXT,
  102. chunk_id TEXT,
  103. tf REAL,
  104. PRIMARY KEY (term, chunk_id)
  105. )
  106. """)
  107. # 文档统计表
  108. cursor.execute("""
  109. CREATE TABLE IF NOT EXISTS doc_stats (
  110. chunk_id TEXT PRIMARY KEY,
  111. doc_length INTEGER
  112. )
  113. """)
  114. # 创建索引
  115. cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_chapter ON vectors(chapter)")
  116. cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_parent ON vectors(parent_chunk_id)")
  117. cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_type ON vectors(chunk_type)")
  118. cursor.execute("CREATE INDEX IF NOT EXISTS idx_bm25_term ON bm25_index(term)")
  119. conn.commit()
    @contextmanager
    def _get_conn(self):
        """Yield a SQLite connection to ``config.vector_db`` and always close it.

        The connection is closed explicitly to avoid leaking file handles on
        Windows. Nothing is committed here; callers commit their own writes.
        """
        conn = sqlite3.connect(str(self.config.vector_db))
        try:
            yield conn
        finally:
            conn.close()
  128. def _get_vectors_count(self) -> int:
  129. with self._get_conn() as conn:
  130. cursor = conn.cursor()
  131. cursor.execute("SELECT COUNT(*) FROM vectors")
  132. row = cursor.fetchone()
  133. return int(row[0] or 0) if row else 0
  134. def _get_recent_chunk_ids(self, limit: int, chunk_type: str | None = None) -> List[str]:
  135. if limit <= 0:
  136. return []
  137. with self._get_conn() as conn:
  138. cursor = conn.cursor()
  139. if chunk_type:
  140. cursor.execute(
  141. "SELECT chunk_id FROM vectors WHERE chunk_type = ? ORDER BY chapter DESC, scene_index DESC LIMIT ?",
  142. (chunk_type, int(limit)),
  143. )
  144. else:
  145. cursor.execute(
  146. "SELECT chunk_id FROM vectors ORDER BY chapter DESC, scene_index DESC LIMIT ?",
  147. (int(limit),),
  148. )
  149. return [str(r[0]) for r in cursor.fetchall() if r and r[0]]
  150. def _fetch_vectors_by_chunk_ids(self, chunk_ids: List[str]) -> List[Tuple]:
  151. if not chunk_ids:
  152. return []
  153. # SQLite 参数数量限制(默认 999),这里做分片查询
  154. def _chunks(xs: List[str], size: int = 500):
  155. it = iter(xs)
  156. while True:
  157. batch = list(itertools.islice(it, size))
  158. if not batch:
  159. break
  160. yield batch
  161. rows: List[Tuple] = []
  162. with self._get_conn() as conn:
  163. cursor = conn.cursor()
  164. for batch in _chunks(chunk_ids):
  165. placeholders = ",".join(["?"] * len(batch))
  166. cursor.execute(
  167. f"SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors WHERE chunk_id IN ({placeholders})",
  168. tuple(batch),
  169. )
  170. rows.extend(cursor.fetchall())
  171. return rows
  172. def _vector_search_rows(
  173. self,
  174. query_embedding: List[float],
  175. rows: List[Tuple],
  176. *,
  177. top_k: int,
  178. ) -> List[SearchResult]:
  179. results: List[SearchResult] = []
  180. for row in rows:
  181. (
  182. chunk_id,
  183. chapter,
  184. scene_index,
  185. content,
  186. embedding_bytes,
  187. parent_chunk_id,
  188. chunk_type,
  189. source_file,
  190. ) = row
  191. if not embedding_bytes:
  192. continue
  193. embedding = self._deserialize_embedding(embedding_bytes)
  194. score = self._cosine_similarity(query_embedding, embedding)
  195. results.append(
  196. SearchResult(
  197. chunk_id=chunk_id,
  198. chapter=chapter,
  199. scene_index=scene_index,
  200. content=content,
  201. score=score,
  202. source="vector",
  203. parent_chunk_id=parent_chunk_id,
  204. chunk_type=chunk_type,
  205. source_file=source_file,
  206. )
  207. )
  208. results.sort(key=lambda x: x.score, reverse=True)
  209. return results[:top_k]
  210. # ==================== 向量存储 ====================
  211. async def store_chunks(self, chunks: List[Dict]) -> int:
  212. """
  213. 存储场景切片的向量
  214. chunks 格式:
  215. [
  216. {
  217. "chapter": 100,
  218. "scene_index": 1,
  219. "content": "场景内容...",
  220. "chunk_type": "scene",
  221. "parent_chunk_id": "ch0100_summary",
  222. "source_file": "正文/第0100章.md#scene_1"
  223. }
  224. ]
  225. 返回存储数量
  226. """
  227. if not chunks:
  228. return 0
  229. # 提取内容用于嵌入
  230. contents = [c.get("content", "") for c in chunks]
  231. # 调用 API 获取嵌入向量(可能包含 None 表示失败)
  232. embeddings = await self.api_client.embed_batch(contents)
  233. if not embeddings:
  234. return 0
  235. # 存储到数据库(跳过嵌入失败的 chunk)
  236. stored = 0
  237. skipped = 0
  238. errors = []
  239. with self._get_conn() as conn:
  240. cursor = conn.cursor()
  241. for chunk, embedding in zip(chunks, embeddings):
  242. if embedding is None:
  243. # 嵌入失败,跳过该 chunk(仅存储 BM25 索引供关键词检索)
  244. skipped += 1
  245. chunk_id = chunk.get("chunk_id")
  246. if not chunk_id:
  247. if chunk.get("chunk_type") == "summary":
  248. chunk_id = f"ch{int(chunk['chapter']):04d}_summary"
  249. else:
  250. chunk_id = f"ch{int(chunk['chapter']):04d}_s{int(chunk['scene_index'])}"
  251. try:
  252. self._update_bm25_index(cursor, chunk_id, chunk.get("content", ""))
  253. except Exception as e:
  254. errors.append(f"BM25 index failed for {chunk_id}: {e}")
  255. continue
  256. chunk_type = chunk.get("chunk_type") or "scene"
  257. chunk_id = chunk.get("chunk_id")
  258. if not chunk_id:
  259. if chunk_type == "summary":
  260. chunk_id = f"ch{int(chunk['chapter']):04d}_summary"
  261. else:
  262. chunk_id = f"ch{int(chunk['chapter']):04d}_s{int(chunk['scene_index'])}"
  263. # 将向量序列化为 bytes
  264. embedding_bytes = self._serialize_embedding(embedding)
  265. cursor.execute("""
  266. INSERT OR REPLACE INTO vectors
  267. (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file)
  268. VALUES (?, ?, ?, ?, ?, ?, ?, ?)
  269. """, (
  270. chunk_id,
  271. chunk["chapter"],
  272. chunk.get("scene_index", 0) if chunk_type == "scene" else 0,
  273. chunk.get("content", ""),
  274. embedding_bytes,
  275. chunk.get("parent_chunk_id"),
  276. chunk_type,
  277. chunk.get("source_file"),
  278. ))
  279. # 同时更新 BM25 索引
  280. try:
  281. self._update_bm25_index(cursor, chunk_id, chunk.get("content", ""))
  282. except Exception as e:
  283. errors.append(f"BM25 index failed for {chunk_id}: {e}")
  284. stored += 1
  285. try:
  286. conn.commit()
  287. except Exception as e:
  288. logger.error("SQLite commit failed: %s", e)
  289. errors.append(f"SQLite commit failed: {e}")
  290. # 输出警告日志
  291. if skipped > 0:
  292. logger.warning(
  293. "Vector embedding: %s stored, %s skipped (embedding failed)",
  294. stored,
  295. skipped,
  296. )
  297. if errors:
  298. for err in errors[:5]: # 最多显示5条
  299. logger.warning("%s", err)
  300. return stored
  301. def _serialize_embedding(self, embedding: List[float]) -> bytes:
  302. """序列化向量"""
  303. import struct
  304. return struct.pack(f"{len(embedding)}f", *embedding)
  305. def _deserialize_embedding(self, data: bytes) -> List[float]:
  306. """反序列化向量"""
  307. import struct
  308. count = len(data) // 4
  309. return list(struct.unpack(f"{count}f", data))
  310. def _log_query(
  311. self,
  312. query: str,
  313. query_type: str,
  314. results: List[SearchResult],
  315. latency_ms: int,
  316. chapter: int | None = None,
  317. ) -> None:
  318. try:
  319. hit_sources = Counter([r.chunk_type or "unknown" for r in results])
  320. self.index_manager.log_rag_query(
  321. query=query,
  322. query_type=query_type,
  323. results_count=len(results),
  324. hit_sources=json.dumps(hit_sources, ensure_ascii=False),
  325. latency_ms=latency_ms,
  326. chapter=chapter,
  327. )
  328. except Exception as exc:
  329. logger.warning("failed to log rag query: %s", exc)
  330. # ==================== BM25 索引 ====================
  331. def _tokenize(self, text: str) -> List[str]:
  332. """简单分词(中文按字符,英文按单词)"""
  333. # 中文字符
  334. chinese = re.findall(r'[\u4e00-\u9fff]+', text)
  335. chinese_chars = list("".join(chinese))
  336. # 英文单词
  337. english = re.findall(r'[a-zA-Z]+', text.lower())
  338. return chinese_chars + english
  339. def _update_bm25_index(self, cursor, chunk_id: str, content: str):
  340. """更新 BM25 索引"""
  341. # 删除旧索引
  342. cursor.execute("DELETE FROM bm25_index WHERE chunk_id = ?", (chunk_id,))
  343. cursor.execute("DELETE FROM doc_stats WHERE chunk_id = ?", (chunk_id,))
  344. # 分词
  345. tokens = self._tokenize(content)
  346. doc_length = len(tokens)
  347. # 计算词频
  348. tf_counter = Counter(tokens)
  349. # 插入倒排索引
  350. for term, count in tf_counter.items():
  351. tf = count / doc_length if doc_length > 0 else 0
  352. cursor.execute("""
  353. INSERT INTO bm25_index (term, chunk_id, tf)
  354. VALUES (?, ?, ?)
  355. """, (term, chunk_id, tf))
  356. # 更新文档统计
  357. cursor.execute("""
  358. INSERT INTO doc_stats (chunk_id, doc_length)
  359. VALUES (?, ?)
  360. """, (chunk_id, doc_length))
  361. # ==================== 向量检索 ====================
  362. async def vector_search(
  363. self,
  364. query: str,
  365. top_k: int = None,
  366. chunk_type: str | None = None,
  367. log_query: bool = True,
  368. chapter: int | None = None,
  369. ) -> List[SearchResult]:
  370. """向量相似度搜索"""
  371. top_k = top_k or self.config.vector_top_k
  372. start_time = time.perf_counter()
  373. # 获取查询向量
  374. query_embeddings = await self.api_client.embed([query])
  375. if not query_embeddings:
  376. self._update_degraded_mode()
  377. return []
  378. self._degraded_mode_reason = None
  379. query_embedding = query_embeddings[0]
  380. # 从数据库读取所有向量并计算相似度
  381. with self._get_conn() as conn:
  382. cursor = conn.cursor()
  383. if chunk_type:
  384. cursor.execute(
  385. "SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors WHERE chunk_type = ?",
  386. (chunk_type,),
  387. )
  388. else:
  389. cursor.execute(
  390. "SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors"
  391. )
  392. results = []
  393. for row in cursor.fetchall():
  394. (
  395. chunk_id,
  396. chapter,
  397. scene_index,
  398. content,
  399. embedding_bytes,
  400. parent_chunk_id,
  401. chunk_type_value,
  402. source_file,
  403. ) = row
  404. if not embedding_bytes:
  405. continue
  406. embedding = self._deserialize_embedding(embedding_bytes)
  407. # 计算余弦相似度
  408. score = self._cosine_similarity(query_embedding, embedding)
  409. results.append(SearchResult(
  410. chunk_id=chunk_id,
  411. chapter=chapter,
  412. scene_index=scene_index,
  413. content=content,
  414. score=score,
  415. source="vector",
  416. parent_chunk_id=parent_chunk_id,
  417. chunk_type=chunk_type_value,
  418. source_file=source_file,
  419. ))
  420. # 排序并返回 top_k
  421. results.sort(key=lambda x: x.score, reverse=True)
  422. results = results[:top_k]
  423. if log_query:
  424. latency_ms = int((time.perf_counter() - start_time) * 1000)
  425. self._log_query(query, "vector", results, latency_ms, chapter=chapter)
  426. return results
  427. def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
  428. """计算余弦相似度"""
  429. dot_product = sum(x * y for x, y in zip(a, b))
  430. norm_a = math.sqrt(sum(x * x for x in a))
  431. norm_b = math.sqrt(sum(x * x for x in b))
  432. if norm_a == 0 or norm_b == 0:
  433. return 0.0
  434. return dot_product / (norm_a * norm_b)
  435. # ==================== BM25 检索 ====================
  436. def bm25_search(
  437. self,
  438. query: str,
  439. top_k: int = None,
  440. k1: float = 1.5,
  441. b: float = 0.75,
  442. chunk_type: str | None = None,
  443. log_query: bool = True,
  444. chapter: int | None = None,
  445. ) -> List[SearchResult]:
  446. """BM25 关键词搜索"""
  447. top_k = top_k or self.config.bm25_top_k
  448. start_time = time.perf_counter()
  449. query_terms = self._tokenize(query)
  450. if not query_terms:
  451. return []
  452. with self._get_conn() as conn:
  453. cursor = conn.cursor()
  454. # 获取文档总数和平均长度
  455. cursor.execute("SELECT COUNT(*), AVG(doc_length) FROM doc_stats")
  456. row = cursor.fetchone()
  457. total_docs = row[0] or 1
  458. avg_doc_length = row[1] or 1
  459. # 计算每个文档的 BM25 分数
  460. doc_scores = {}
  461. for term in set(query_terms):
  462. # 获取包含该词的文档
  463. cursor.execute("""
  464. SELECT b.chunk_id, b.tf, d.doc_length
  465. FROM bm25_index b
  466. JOIN doc_stats d ON b.chunk_id = d.chunk_id
  467. WHERE b.term = ?
  468. """, (term,))
  469. docs_with_term = cursor.fetchall()
  470. df = len(docs_with_term)
  471. if df == 0:
  472. continue
  473. # IDF
  474. idf = math.log((total_docs - df + 0.5) / (df + 0.5) + 1)
  475. for chunk_id, tf, doc_length in docs_with_term:
  476. # BM25 公式
  477. score = idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * doc_length / avg_doc_length))
  478. if chunk_id not in doc_scores:
  479. doc_scores[chunk_id] = 0
  480. doc_scores[chunk_id] += score
  481. # 获取文档内容
  482. results = []
  483. for chunk_id, score in doc_scores.items():
  484. if chunk_type:
  485. cursor.execute(
  486. """
  487. SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
  488. FROM vectors
  489. WHERE chunk_id = ? AND chunk_type = ?
  490. """,
  491. (chunk_id, chunk_type),
  492. )
  493. else:
  494. cursor.execute(
  495. """
  496. SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
  497. FROM vectors
  498. WHERE chunk_id = ?
  499. """,
  500. (chunk_id,),
  501. )
  502. row = cursor.fetchone()
  503. if row:
  504. results.append(SearchResult(
  505. chunk_id=chunk_id,
  506. chapter=row[0],
  507. scene_index=row[1],
  508. content=row[2],
  509. score=score,
  510. source="bm25",
  511. parent_chunk_id=row[3],
  512. chunk_type=row[4],
  513. source_file=row[5],
  514. ))
  515. results.sort(key=lambda x: x.score, reverse=True)
  516. results = results[:top_k]
  517. if log_query:
  518. latency_ms = int((time.perf_counter() - start_time) * 1000)
  519. self._log_query(query, "bm25", results, latency_ms, chapter=chapter)
  520. return results
  521. # ==================== 混合检索 ====================
  522. async def hybrid_search(
  523. self,
  524. query: str,
  525. vector_top_k: int = None,
  526. bm25_top_k: int = None,
  527. rerank_top_n: int = None,
  528. chunk_type: str | None = None,
  529. log_query: bool = True,
  530. ) -> List[SearchResult]:
  531. """
  532. 混合检索:向量 + BM25 + RRF 融合 + Rerank
  533. 步骤:
  534. 1. 向量检索 top_k
  535. 2. BM25 检索 top_k
  536. 3. RRF 融合
  537. 4. Rerank 精排
  538. """
  539. vector_top_k = vector_top_k or self.config.vector_top_k
  540. bm25_top_k = bm25_top_k or self.config.bm25_top_k
  541. rerank_top_n = rerank_top_n or self.config.rerank_top_n
  542. start_time = time.perf_counter()
  543. # 小规模:全表向量扫描(召回更稳);大规模:预筛选避免 O(n) 扫描拖慢
  544. vectors_count = await asyncio.to_thread(self._get_vectors_count)
  545. use_full_scan = vectors_count <= int(self.config.vector_full_scan_max_vectors)
  546. if use_full_scan:
  547. # 并行执行向量和 BM25 检索
  548. vector_results, bm25_results = await asyncio.gather(
  549. self.vector_search(query, vector_top_k, chunk_type=chunk_type, log_query=False),
  550. asyncio.to_thread(self.bm25_search, query, bm25_top_k, 1.5, 0.75, chunk_type, False),
  551. )
  552. else:
  553. bm25_candidates = max(
  554. int(self.config.vector_prefilter_bm25_candidates),
  555. int(bm25_top_k),
  556. int(vector_top_k) * 5,
  557. int(rerank_top_n) * 10,
  558. )
  559. recent_candidates = max(
  560. int(self.config.vector_prefilter_recent_candidates),
  561. int(vector_top_k) * 5,
  562. int(rerank_top_n) * 10,
  563. )
  564. bm25_task = asyncio.to_thread(self.bm25_search, query, bm25_candidates, 1.5, 0.75, chunk_type, False)
  565. recent_task = asyncio.to_thread(self._get_recent_chunk_ids, recent_candidates, chunk_type)
  566. embed_task = self.api_client.embed([query])
  567. bm25_candidates_results, recent_ids, query_embeddings = await asyncio.gather(
  568. bm25_task,
  569. recent_task,
  570. embed_task,
  571. )
  572. if not query_embeddings:
  573. self._update_degraded_mode()
  574. return []
  575. self._degraded_mode_reason = None
  576. query_embedding = query_embeddings[0]
  577. candidate_ids = {r.chunk_id for r in bm25_candidates_results}
  578. candidate_ids.update(recent_ids)
  579. rows = await asyncio.to_thread(self._fetch_vectors_by_chunk_ids, list(candidate_ids))
  580. if chunk_type:
  581. rows = [r for r in rows if len(r) > 6 and r[6] == chunk_type]
  582. vector_results = await asyncio.to_thread(
  583. self._vector_search_rows,
  584. query_embedding,
  585. rows,
  586. top_k=int(vector_top_k),
  587. )
  588. # BM25 结果用于融合时只取 top_k
  589. bm25_results = list(bm25_candidates_results)[: int(bm25_top_k)]
  590. # RRF 融合
  591. rrf_scores = {}
  592. k = self.config.rrf_k
  593. for rank, result in enumerate(vector_results):
  594. if result.chunk_id not in rrf_scores:
  595. rrf_scores[result.chunk_id] = {"result": result, "score": 0}
  596. rrf_scores[result.chunk_id]["score"] += 1 / (k + rank + 1)
  597. for rank, result in enumerate(bm25_results):
  598. if result.chunk_id not in rrf_scores:
  599. rrf_scores[result.chunk_id] = {"result": result, "score": 0}
  600. rrf_scores[result.chunk_id]["score"] += 1 / (k + rank + 1)
  601. # 按 RRF 分数排序
  602. sorted_results = sorted(
  603. rrf_scores.values(),
  604. key=lambda x: x["score"],
  605. reverse=True
  606. )
  607. # 取 top candidates 进行 rerank
  608. candidates = [item["result"] for item in sorted_results[:rerank_top_n * 2]]
  609. if not candidates:
  610. final_results: List[SearchResult] = []
  611. latency_ms = int((time.perf_counter() - start_time) * 1000)
  612. if log_query:
  613. self._log_query(query, "hybrid", final_results, latency_ms)
  614. return final_results
  615. # 调用 Rerank API
  616. documents = [c.content for c in candidates]
  617. rerank_results = await self.api_client.rerank(query, documents, top_n=rerank_top_n)
  618. if not rerank_results:
  619. # Rerank 失败,返回 RRF 结果
  620. final_results = [item["result"] for item in sorted_results[:rerank_top_n]]
  621. latency_ms = int((time.perf_counter() - start_time) * 1000)
  622. if log_query:
  623. self._log_query(query, "hybrid", final_results, latency_ms)
  624. return final_results
  625. # 组装最终结果
  626. final_results = []
  627. for r in rerank_results:
  628. idx = r.get("index", 0)
  629. if idx < len(candidates):
  630. result = candidates[idx]
  631. result.score = r.get("relevance_score", 0)
  632. result.source = "hybrid"
  633. final_results.append(result)
  634. latency_ms = int((time.perf_counter() - start_time) * 1000)
  635. if log_query:
  636. self._log_query(query, "hybrid", final_results, latency_ms)
  637. return final_results
  638. def _get_chunks_by_ids(self, chunk_ids: List[str]) -> List[SearchResult]:
  639. rows = self._fetch_vectors_by_chunk_ids(chunk_ids)
  640. results: List[SearchResult] = []
  641. for row in rows:
  642. (
  643. chunk_id,
  644. chapter,
  645. scene_index,
  646. content,
  647. _embedding_bytes,
  648. parent_chunk_id,
  649. chunk_type,
  650. source_file,
  651. ) = row
  652. results.append(
  653. SearchResult(
  654. chunk_id=chunk_id,
  655. chapter=chapter,
  656. scene_index=scene_index,
  657. content=content,
  658. score=0.0,
  659. source="parent",
  660. parent_chunk_id=parent_chunk_id,
  661. chunk_type=chunk_type,
  662. source_file=source_file,
  663. )
  664. )
  665. return results
  666. def _merge_results(
  667. self,
  668. parents: List[SearchResult],
  669. children: List[SearchResult],
  670. ) -> List[SearchResult]:
  671. parent_map = {p.chunk_id: p for p in parents}
  672. merged: List[SearchResult] = []
  673. seen = set()
  674. for child in children:
  675. parent_id = child.parent_chunk_id
  676. if parent_id and parent_id in parent_map and parent_id not in seen:
  677. merged.append(parent_map[parent_id])
  678. seen.add(parent_id)
  679. merged.append(child)
  680. return merged
  681. async def search_with_backtrack(self, query: str, top_k: int = 5) -> List[SearchResult]:
  682. start_time = time.perf_counter()
  683. child_results = await self.hybrid_search(
  684. query,
  685. vector_top_k=top_k * 2,
  686. bm25_top_k=top_k * 2,
  687. rerank_top_n=top_k,
  688. chunk_type="scene",
  689. log_query=False,
  690. )
  691. parent_ids = sorted({r.parent_chunk_id for r in child_results if r.parent_chunk_id})
  692. parents = self._get_chunks_by_ids(parent_ids) if parent_ids else []
  693. merged = self._merge_results(parents, child_results[:top_k])
  694. latency_ms = int((time.perf_counter() - start_time) * 1000)
  695. self._log_query(query, "backtrack", merged, latency_ms)
  696. return merged
  697. # ==================== 统计 ====================
  698. def get_stats(self) -> Dict[str, int]:
  699. """获取 RAG 统计"""
  700. with self._get_conn() as conn:
  701. cursor = conn.cursor()
  702. cursor.execute("SELECT COUNT(*) FROM vectors")
  703. vectors = cursor.fetchone()[0]
  704. cursor.execute("SELECT COUNT(DISTINCT term) FROM bm25_index")
  705. terms = cursor.fetchone()[0]
  706. cursor.execute("SELECT MAX(chapter) FROM vectors")
  707. max_chapter = cursor.fetchone()[0] or 0
  708. return {
  709. "vectors": vectors,
  710. "terms": terms,
  711. "max_chapter": max_chapter
  712. }
  713. # ==================== CLI 接口 ====================
  714. def main():
  715. import argparse
  716. from .cli_output import print_success, print_error
  717. parser = argparse.ArgumentParser(description="RAG Adapter CLI")
  718. parser.add_argument("--project-root", type=str, help="项目根目录")
  719. subparsers = parser.add_subparsers(dest="command")
  720. # 获取统计
  721. subparsers.add_parser("stats")
  722. # 写入索引
  723. index_parser = subparsers.add_parser("index-chapter")
  724. index_parser.add_argument("--chapter", type=int, required=True)
  725. index_parser.add_argument("--scenes", required=True, help="JSON 格式的场景列表")
  726. index_parser.add_argument("--summary", required=False, help="章节摘要文本")
  727. # 搜索
  728. search_parser = subparsers.add_parser("search")
  729. search_parser.add_argument("--query", required=True)
  730. search_parser.add_argument("--mode", choices=["vector", "bm25", "hybrid", "backtrack"], default="hybrid")
  731. search_parser.add_argument("--top-k", type=int, default=5)
  732. search_parser.add_argument("--chunk-type", choices=["scene", "summary"], default=None)
  733. args = parser.parse_args()
  734. # 初始化
  735. config = None
  736. if args.project_root:
  737. from .config import DataModulesConfig
  738. config = DataModulesConfig.from_project_root(args.project_root)
  739. adapter = RAGAdapter(config)
  740. tool_name = f"rag_adapter:{args.command or 'unknown'}"
  741. def emit_success(data=None, message: str = "ok"):
  742. print_success(data, message=message)
  743. safe_log_tool_call(adapter.index_manager, tool_name=tool_name, success=True)
  744. def emit_error(code: str, message: str, suggestion: str | None = None):
  745. print_error(code, message, suggestion=suggestion)
  746. safe_log_tool_call(
  747. adapter.index_manager,
  748. tool_name=tool_name,
  749. success=False,
  750. error_code=code,
  751. error_message=message,
  752. )
  753. if args.command == "stats":
  754. stats = adapter.get_stats()
  755. emit_success(stats, message="stats")
  756. elif args.command == "index-chapter":
  757. scenes = json.loads(args.scenes)
  758. chunks = []
  759. # summary chunk
  760. summary_text = args.summary
  761. if not summary_text and config:
  762. summary_path = config.webnovel_dir / "summaries" / f"ch{args.chapter:04d}.md"
  763. if summary_path.exists():
  764. summary_text = summary_path.read_text(encoding="utf-8")
  765. parent_chunk_id = None
  766. if summary_text:
  767. parent_chunk_id = f"ch{args.chapter:04d}_summary"
  768. chunks.append(
  769. {
  770. "chapter": args.chapter,
  771. "scene_index": 0,
  772. "content": summary_text,
  773. "chunk_type": "summary",
  774. "chunk_id": parent_chunk_id,
  775. "source_file": f"summaries/ch{args.chapter:04d}.md",
  776. }
  777. )
  778. for s in scenes:
  779. scene_index = s.get("index", 0)
  780. chunk_id = f"ch{args.chapter:04d}_s{int(scene_index)}"
  781. chunks.append(
  782. {
  783. "chapter": args.chapter,
  784. "scene_index": scene_index,
  785. "content": s.get("content", ""),
  786. "chunk_type": "scene",
  787. "parent_chunk_id": parent_chunk_id,
  788. "chunk_id": chunk_id,
  789. "source_file": f"正文/第{args.chapter:04d}章.md#scene_{int(scene_index)}",
  790. }
  791. )
  792. stored = asyncio.run(adapter.store_chunks(chunks))
  793. skipped = len(chunks) - stored
  794. result = {"stored": stored, "skipped": skipped, "total": len(chunks)}
  795. if skipped > 0:
  796. emit_success(result, message="indexed_with_warnings")
  797. else:
  798. emit_success(result, message="indexed")
  799. elif args.command == "search":
  800. if args.mode == "vector":
  801. results = asyncio.run(adapter.vector_search(args.query, args.top_k, chunk_type=args.chunk_type))
  802. elif args.mode == "bm25":
  803. results = adapter.bm25_search(args.query, args.top_k, chunk_type=args.chunk_type)
  804. elif args.mode == "backtrack":
  805. results = asyncio.run(adapter.search_with_backtrack(args.query, args.top_k))
  806. else:
  807. results = asyncio.run(adapter.hybrid_search(args.query, args.top_k, args.top_k, args.top_k, chunk_type=args.chunk_type))
  808. payload = [r.__dict__ for r in results]
  809. degraded_reason = adapter.degraded_mode_reason
  810. if degraded_reason:
  811. warnings = [{"code": "DEGRADED_MODE", "reason": degraded_reason}]
  812. print_success(payload, message="search_results", warnings=warnings)
  813. safe_log_tool_call(adapter.index_manager, tool_name=tool_name, success=True)
  814. else:
  815. emit_success(payload, message="search_results")
  816. else:
  817. emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")
# Script entry point.
if __name__ == "__main__":
    import sys
    # Windows only: call the runtime_compat helper before any output is
    # printed (presumably it switches stdio to UTF-8 — confirm in runtime_compat).
    if sys.platform == "win32":
        enable_windows_utf8_stdio()
    main()
  822. main()