rag_adapter.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. RAG Adapter - RAG 检索适配模块
  5. 封装向量检索功能:
  6. - 向量嵌入 (调用 Modal API)
  7. - 语义搜索
  8. - 重排序
  9. - 混合检索 (向量 + BM25)
  10. """
  11. import asyncio
  12. import sqlite3
  13. import json
  14. import math
  15. import logging
  16. import shutil
  17. from pathlib import Path
  18. from runtime_compat import enable_windows_utf8_stdio
  19. from typing import Dict, List, Optional, Any, Tuple
  20. from dataclasses import dataclass
  21. from collections import Counter
  22. import re
  23. from contextlib import contextmanager
  24. import itertools
  25. import time
  26. from datetime import datetime
  27. from .config import get_config
  28. from .api_client import get_client
  29. from .index_manager import IndexManager
  30. from .query_router import QueryRouter
  31. from .observability import safe_append_perf_timing, safe_log_tool_call
  32. logger = logging.getLogger(__name__)
  33. RAG_SCHEMA_VERSION = "2"
  34. VECTOR_REQUIRED_COLUMNS = (
  35. "chunk_id",
  36. "chapter",
  37. "scene_index",
  38. "content",
  39. "embedding",
  40. "parent_chunk_id",
  41. "chunk_type",
  42. "source_file",
  43. "created_at",
  44. )
  45. @dataclass
  46. class SearchResult:
  47. """搜索结果"""
  48. chunk_id: str
  49. chapter: int
  50. scene_index: int
  51. content: str
  52. score: float
  53. source: str # "vector" | "bm25" | "hybrid"
  54. parent_chunk_id: str | None = None
  55. chunk_type: str | None = None
  56. source_file: str | None = None
  57. class RAGAdapter:
  58. """RAG 检索适配器"""
    def __init__(self, config=None):
        """Wire up the collaborators and make sure the vector database is ready.

        Args:
            config: Configuration object; when falsy, ``get_config()`` supplies one.
        """
        self.config = config or get_config()
        # NOTE(review): the raw ``config`` argument (possibly None) is passed here,
        # not ``self.config`` — presumably get_client falls back to a default; confirm.
        self.api_client = get_client(config)
        self.index_manager = IndexManager(self.config)
        self.query_router = QueryRouter()
        # Reason string set when embedding calls fail in a recognised way; None otherwise.
        self._degraded_mode_reason: Optional[str] = None
        # Creates or migrates the SQLite schema; runs last so self.config is available.
        self._init_db()
    @property
    def degraded_mode_reason(self) -> Optional[str]:
        """Return the current degraded-mode reason, or None when none is recorded."""
        return self._degraded_mode_reason
  69. def _update_degraded_mode(self) -> None:
  70. self._degraded_mode_reason = None
  71. embed_client = getattr(self.api_client, "_embed_client", None)
  72. status = getattr(embed_client, "last_error_status", None)
  73. if status == 401:
  74. self._degraded_mode_reason = "embedding_auth_failed"
    def _init_db(self):
        """Initialise the vector database, migrating an outdated ``vectors`` table first.

        When the existing ``vectors`` table lacks required columns, the database
        file is backed up, the table is rebuilt, and on any failure the backup is
        copied back before the original exception is re-raised. Afterwards the
        schema metadata and all tables/indexes are created if missing.
        """
        self.config.ensure_dirs()
        needs_migration, existing_cols = self._inspect_vectors_schema()
        if needs_migration:
            # Snapshot the database file before touching the schema.
            backup_path = self._backup_vector_db(reason="schema_migration")
            try:
                with self._get_conn() as conn:
                    cursor = conn.cursor()
                    self._rebuild_vectors_table(cursor, existing_cols)
                    conn.commit()
                logger.warning(
                    "vectors 表结构已迁移(备份: %s)",
                    str(backup_path),
                )
            except Exception:
                # Best-effort rollback to the snapshot; a failed restore is logged
                # but must not mask the migration error, which is re-raised below.
                try:
                    self._restore_vector_db_from_backup(backup_path)
                    logger.error("vectors 表迁移失败,已从备份恢复: %s", str(backup_path))
                except Exception as restore_exc:
                    logger.exception("vectors 表迁移失败,且恢复备份失败: %s", restore_exc)
                raise
        # Always (re)assert the schema version row and the full table/index set.
        with self._get_conn() as conn:
            cursor = conn.cursor()
            self._ensure_schema_meta(cursor)
            self._ensure_tables(cursor)
            conn.commit()
  102. def _table_exists(self, cursor, table_name: str) -> bool:
  103. cursor.execute(
  104. "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
  105. (table_name,),
  106. )
  107. return cursor.fetchone() is not None
  108. def _table_columns(self, cursor, table_name: str) -> set[str]:
  109. cursor.execute(f"PRAGMA table_info({table_name})")
  110. return {row[1] for row in cursor.fetchall()}
  111. def _inspect_vectors_schema(self) -> tuple[bool, set[str]]:
  112. with self._get_conn() as conn:
  113. cursor = conn.cursor()
  114. if not self._table_exists(cursor, "vectors"):
  115. return False, set()
  116. cols = self._table_columns(cursor, "vectors")
  117. required_cols = set(VECTOR_REQUIRED_COLUMNS)
  118. return (not required_cols.issubset(cols), cols)
  119. def _backup_vector_db(self, reason: str) -> Path:
  120. db_path = Path(self.config.vector_db)
  121. if not db_path.exists():
  122. raise FileNotFoundError(f"vectors.db 不存在: {db_path}")
  123. backup_dir = self.config.webnovel_dir / "backups"
  124. backup_dir.mkdir(parents=True, exist_ok=True)
  125. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  126. backup_path = backup_dir / f"vectors.db.{reason}.v{RAG_SCHEMA_VERSION}.{timestamp}.bak"
  127. shutil.copy2(db_path, backup_path)
  128. return backup_path
  129. def _restore_vector_db_from_backup(self, backup_path: Path) -> None:
  130. db_path = Path(self.config.vector_db)
  131. shutil.copy2(backup_path, db_path)
  132. def _rebuild_vectors_table(self, cursor, existing_cols: set[str]) -> None:
  133. if not self._table_exists(cursor, "vectors"):
  134. return
  135. cursor.execute("DROP TABLE IF EXISTS vectors_migrating")
  136. cursor.execute("""
  137. CREATE TABLE vectors_migrating (
  138. chunk_id TEXT PRIMARY KEY,
  139. chapter INTEGER,
  140. scene_index INTEGER,
  141. content TEXT,
  142. embedding BLOB,
  143. parent_chunk_id TEXT,
  144. chunk_type TEXT DEFAULT 'scene',
  145. source_file TEXT,
  146. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  147. )
  148. """)
  149. copy_columns = [
  150. col
  151. for col in VECTOR_REQUIRED_COLUMNS
  152. if col in existing_cols
  153. ]
  154. if copy_columns:
  155. cols_sql = ", ".join(copy_columns)
  156. cursor.execute(
  157. f"INSERT OR REPLACE INTO vectors_migrating ({cols_sql}) SELECT {cols_sql} FROM vectors"
  158. )
  159. cursor.execute("DROP TABLE vectors")
  160. cursor.execute("ALTER TABLE vectors_migrating RENAME TO vectors")
  161. def _ensure_schema_meta(self, cursor) -> None:
  162. cursor.execute("""
  163. CREATE TABLE IF NOT EXISTS rag_schema_meta (
  164. key TEXT PRIMARY KEY,
  165. value TEXT NOT NULL,
  166. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  167. )
  168. """)
  169. cursor.execute(
  170. """
  171. INSERT INTO rag_schema_meta (key, value, updated_at)
  172. VALUES ('schema_version', ?, CURRENT_TIMESTAMP)
  173. ON CONFLICT(key) DO UPDATE SET
  174. value = excluded.value,
  175. updated_at = CURRENT_TIMESTAMP
  176. """,
  177. (RAG_SCHEMA_VERSION,),
  178. )
  179. def _ensure_tables(self, cursor) -> None:
  180. # 向量存储表
  181. cursor.execute("""
  182. CREATE TABLE IF NOT EXISTS vectors (
  183. chunk_id TEXT PRIMARY KEY,
  184. chapter INTEGER,
  185. scene_index INTEGER,
  186. content TEXT,
  187. embedding BLOB,
  188. parent_chunk_id TEXT,
  189. chunk_type TEXT DEFAULT 'scene',
  190. source_file TEXT,
  191. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  192. )
  193. """)
  194. # BM25 倒排索引表
  195. cursor.execute("""
  196. CREATE TABLE IF NOT EXISTS bm25_index (
  197. term TEXT,
  198. chunk_id TEXT,
  199. tf REAL,
  200. PRIMARY KEY (term, chunk_id)
  201. )
  202. """)
  203. # 文档统计表
  204. cursor.execute("""
  205. CREATE TABLE IF NOT EXISTS doc_stats (
  206. chunk_id TEXT PRIMARY KEY,
  207. doc_length INTEGER
  208. )
  209. """)
  210. # 创建索引
  211. cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_chapter ON vectors(chapter)")
  212. cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_parent ON vectors(parent_chunk_id)")
  213. cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_type ON vectors(chunk_type)")
  214. cursor.execute("CREATE INDEX IF NOT EXISTS idx_bm25_term ON bm25_index(term)")
    @contextmanager
    def _get_conn(self):
        """Yield a SQLite connection and always close it afterwards.

        Closing explicitly avoids leaking file handles on Windows. The
        connection is not committed here; callers commit their own work.
        """
        conn = sqlite3.connect(str(self.config.vector_db))
        try:
            yield conn
        finally:
            # Runs on normal exit and when the with-body raises.
            conn.close()
  223. def _get_vectors_count(self) -> int:
  224. with self._get_conn() as conn:
  225. cursor = conn.cursor()
  226. cursor.execute("SELECT COUNT(*) FROM vectors")
  227. row = cursor.fetchone()
  228. return int(row[0] or 0) if row else 0
  229. def _get_recent_chunk_ids(
  230. self,
  231. limit: int,
  232. chunk_type: str | None = None,
  233. chapter: int | None = None,
  234. ) -> List[str]:
  235. if limit <= 0:
  236. return []
  237. with self._get_conn() as conn:
  238. cursor = conn.cursor()
  239. if chunk_type and chapter is not None:
  240. cursor.execute(
  241. """
  242. SELECT chunk_id
  243. FROM vectors
  244. WHERE chunk_type = ? AND chapter <= ?
  245. ORDER BY chapter DESC, scene_index DESC
  246. LIMIT ?
  247. """,
  248. (chunk_type, int(chapter), int(limit)),
  249. )
  250. elif chunk_type:
  251. cursor.execute(
  252. """
  253. SELECT chunk_id
  254. FROM vectors
  255. WHERE chunk_type = ?
  256. ORDER BY chapter DESC, scene_index DESC
  257. LIMIT ?
  258. """,
  259. (chunk_type, int(limit)),
  260. )
  261. elif chapter is not None:
  262. cursor.execute(
  263. """
  264. SELECT chunk_id
  265. FROM vectors
  266. WHERE chapter <= ?
  267. ORDER BY chapter DESC, scene_index DESC
  268. LIMIT ?
  269. """,
  270. (int(chapter), int(limit)),
  271. )
  272. else:
  273. cursor.execute(
  274. "SELECT chunk_id FROM vectors ORDER BY chapter DESC, scene_index DESC LIMIT ?",
  275. (int(limit),),
  276. )
  277. return [str(r[0]) for r in cursor.fetchall() if r and r[0]]
  278. def _fetch_vectors_by_chunk_ids(self, chunk_ids: List[str]) -> List[Tuple]:
  279. if not chunk_ids:
  280. return []
  281. # SQLite 参数数量限制(默认 999),这里做分片查询
  282. def _chunks(xs: List[str], size: int = 500):
  283. it = iter(xs)
  284. while True:
  285. batch = list(itertools.islice(it, size))
  286. if not batch:
  287. break
  288. yield batch
  289. rows: List[Tuple] = []
  290. with self._get_conn() as conn:
  291. cursor = conn.cursor()
  292. for batch in _chunks(chunk_ids):
  293. placeholders = ",".join(["?"] * len(batch))
  294. cursor.execute(
  295. f"SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors WHERE chunk_id IN ({placeholders})",
  296. tuple(batch),
  297. )
  298. rows.extend(cursor.fetchall())
  299. return rows
  300. def _vector_search_rows(
  301. self,
  302. query_embedding: List[float],
  303. rows: List[Tuple],
  304. *,
  305. top_k: int,
  306. ) -> List[SearchResult]:
  307. results: List[SearchResult] = []
  308. for row in rows:
  309. (
  310. chunk_id,
  311. chapter,
  312. scene_index,
  313. content,
  314. embedding_bytes,
  315. parent_chunk_id,
  316. chunk_type,
  317. source_file,
  318. ) = row
  319. if not embedding_bytes:
  320. continue
  321. embedding = self._deserialize_embedding(embedding_bytes)
  322. score = self._cosine_similarity(query_embedding, embedding)
  323. results.append(
  324. SearchResult(
  325. chunk_id=chunk_id,
  326. chapter=chapter,
  327. scene_index=scene_index,
  328. content=content,
  329. score=score,
  330. source="vector",
  331. parent_chunk_id=parent_chunk_id,
  332. chunk_type=chunk_type,
  333. source_file=source_file,
  334. )
  335. )
  336. results.sort(key=lambda x: x.score, reverse=True)
  337. return results[:top_k]
  338. # ==================== 向量存储 ====================
  339. async def store_chunks(self, chunks: List[Dict]) -> int:
  340. """
  341. 存储场景切片的向量
  342. chunks 格式:
  343. [
  344. {
  345. "chapter": 100,
  346. "scene_index": 1,
  347. "content": "场景内容...",
  348. "chunk_type": "scene",
  349. "parent_chunk_id": "ch0100_summary",
  350. "source_file": "正文/第0100章.md#scene_1"
  351. }
  352. ]
  353. 返回存储数量
  354. """
  355. if not chunks:
  356. return 0
  357. # 提取内容用于嵌入
  358. contents = [c.get("content", "") for c in chunks]
  359. # 调用 API 获取嵌入向量(可能包含 None 表示失败)
  360. embeddings = await self.api_client.embed_batch(contents)
  361. if not embeddings:
  362. return 0
  363. # 存储到数据库(跳过嵌入失败的 chunk)
  364. stored = 0
  365. skipped = 0
  366. errors = []
  367. with self._get_conn() as conn:
  368. cursor = conn.cursor()
  369. for chunk, embedding in zip(chunks, embeddings):
  370. if embedding is None:
  371. # 嵌入失败,跳过该 chunk(仅存储 BM25 索引供关键词检索)
  372. skipped += 1
  373. chunk_id = chunk.get("chunk_id")
  374. if not chunk_id:
  375. if chunk.get("chunk_type") == "summary":
  376. chunk_id = f"ch{int(chunk['chapter']):04d}_summary"
  377. else:
  378. chunk_id = f"ch{int(chunk['chapter']):04d}_s{int(chunk['scene_index'])}"
  379. try:
  380. self._update_bm25_index(cursor, chunk_id, chunk.get("content", ""))
  381. except Exception as e:
  382. errors.append(f"BM25 index failed for {chunk_id}: {e}")
  383. continue
  384. chunk_type = chunk.get("chunk_type") or "scene"
  385. chunk_id = chunk.get("chunk_id")
  386. if not chunk_id:
  387. if chunk_type == "summary":
  388. chunk_id = f"ch{int(chunk['chapter']):04d}_summary"
  389. else:
  390. chunk_id = f"ch{int(chunk['chapter']):04d}_s{int(chunk['scene_index'])}"
  391. # 将向量序列化为 bytes
  392. embedding_bytes = self._serialize_embedding(embedding)
  393. cursor.execute("""
  394. INSERT OR REPLACE INTO vectors
  395. (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file)
  396. VALUES (?, ?, ?, ?, ?, ?, ?, ?)
  397. """, (
  398. chunk_id,
  399. chunk["chapter"],
  400. chunk.get("scene_index", 0) if chunk_type == "scene" else 0,
  401. chunk.get("content", ""),
  402. embedding_bytes,
  403. chunk.get("parent_chunk_id"),
  404. chunk_type,
  405. chunk.get("source_file"),
  406. ))
  407. # 同时更新 BM25 索引
  408. try:
  409. self._update_bm25_index(cursor, chunk_id, chunk.get("content", ""))
  410. except Exception as e:
  411. errors.append(f"BM25 index failed for {chunk_id}: {e}")
  412. stored += 1
  413. try:
  414. conn.commit()
  415. except Exception as e:
  416. logger.error("SQLite commit failed: %s", e)
  417. errors.append(f"SQLite commit failed: {e}")
  418. # 输出警告日志
  419. if skipped > 0:
  420. logger.warning(
  421. "Vector embedding: %s stored, %s skipped (embedding failed)",
  422. stored,
  423. skipped,
  424. )
  425. if errors:
  426. for err in errors[:5]: # 最多显示5条
  427. logger.warning("%s", err)
  428. return stored
  429. def _serialize_embedding(self, embedding: List[float]) -> bytes:
  430. """序列化向量"""
  431. import struct
  432. return struct.pack(f"{len(embedding)}f", *embedding)
  433. def _deserialize_embedding(self, data: bytes) -> List[float]:
  434. """反序列化向量"""
  435. import struct
  436. count = len(data) // 4
  437. return list(struct.unpack(f"{count}f", data))
  438. def _log_query(
  439. self,
  440. query: str,
  441. query_type: str,
  442. results: List[SearchResult],
  443. latency_ms: int,
  444. chapter: int | None = None,
  445. ) -> None:
  446. try:
  447. hit_sources = Counter([r.chunk_type or "unknown" for r in results])
  448. self.index_manager.log_rag_query(
  449. query=query,
  450. query_type=query_type,
  451. results_count=len(results),
  452. hit_sources=json.dumps(hit_sources, ensure_ascii=False),
  453. latency_ms=latency_ms,
  454. chapter=chapter,
  455. )
  456. except Exception as exc:
  457. logger.warning("failed to log rag query: %s", exc)
  458. # ==================== BM25 索引 ====================
  459. def _tokenize(self, text: str) -> List[str]:
  460. """简单分词(中文按字符,英文按单词)"""
  461. # 中文字符
  462. chinese = re.findall(r'[\u4e00-\u9fff]+', text)
  463. chinese_chars = list("".join(chinese))
  464. # 英文单词
  465. english = re.findall(r'[a-zA-Z]+', text.lower())
  466. return chinese_chars + english
  467. def _update_bm25_index(self, cursor, chunk_id: str, content: str):
  468. """更新 BM25 索引"""
  469. # 删除旧索引
  470. cursor.execute("DELETE FROM bm25_index WHERE chunk_id = ?", (chunk_id,))
  471. cursor.execute("DELETE FROM doc_stats WHERE chunk_id = ?", (chunk_id,))
  472. # 分词
  473. tokens = self._tokenize(content)
  474. doc_length = len(tokens)
  475. # 计算词频
  476. tf_counter = Counter(tokens)
  477. # 插入倒排索引
  478. for term, count in tf_counter.items():
  479. tf = count / doc_length if doc_length > 0 else 0
  480. cursor.execute("""
  481. INSERT INTO bm25_index (term, chunk_id, tf)
  482. VALUES (?, ?, ?)
  483. """, (term, chunk_id, tf))
  484. # 更新文档统计
  485. cursor.execute("""
  486. INSERT INTO doc_stats (chunk_id, doc_length)
  487. VALUES (?, ?)
  488. """, (chunk_id, doc_length))
  489. # ==================== 向量检索 ====================
  490. async def vector_search(
  491. self,
  492. query: str,
  493. top_k: int = None,
  494. chunk_type: str | None = None,
  495. log_query: bool = True,
  496. chapter: int | None = None,
  497. ) -> List[SearchResult]:
  498. """向量相似度搜索"""
  499. top_k = top_k or self.config.vector_top_k
  500. start_time = time.perf_counter()
  501. # 获取查询向量
  502. query_embeddings = await self.api_client.embed([query])
  503. if not query_embeddings:
  504. self._update_degraded_mode()
  505. return []
  506. self._degraded_mode_reason = None
  507. query_embedding = query_embeddings[0]
  508. # 从数据库读取所有向量并计算相似度
  509. with self._get_conn() as conn:
  510. cursor = conn.cursor()
  511. if chunk_type and chapter is not None:
  512. cursor.execute(
  513. """
  514. SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file
  515. FROM vectors
  516. WHERE chunk_type = ? AND chapter <= ?
  517. """,
  518. (chunk_type, int(chapter)),
  519. )
  520. elif chunk_type:
  521. cursor.execute(
  522. "SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors WHERE chunk_type = ?",
  523. (chunk_type,),
  524. )
  525. elif chapter is not None:
  526. cursor.execute(
  527. """
  528. SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file
  529. FROM vectors
  530. WHERE chapter <= ?
  531. """,
  532. (int(chapter),),
  533. )
  534. else:
  535. cursor.execute(
  536. "SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors"
  537. )
  538. results = []
  539. for row in cursor.fetchall():
  540. (
  541. chunk_id,
  542. chapter,
  543. scene_index,
  544. content,
  545. embedding_bytes,
  546. parent_chunk_id,
  547. chunk_type_value,
  548. source_file,
  549. ) = row
  550. if not embedding_bytes:
  551. continue
  552. embedding = self._deserialize_embedding(embedding_bytes)
  553. # 计算余弦相似度
  554. score = self._cosine_similarity(query_embedding, embedding)
  555. results.append(SearchResult(
  556. chunk_id=chunk_id,
  557. chapter=chapter,
  558. scene_index=scene_index,
  559. content=content,
  560. score=score,
  561. source="vector",
  562. parent_chunk_id=parent_chunk_id,
  563. chunk_type=chunk_type_value,
  564. source_file=source_file,
  565. ))
  566. # 排序并返回 top_k
  567. results.sort(key=lambda x: x.score, reverse=True)
  568. results = results[:top_k]
  569. if log_query:
  570. latency_ms = int((time.perf_counter() - start_time) * 1000)
  571. self._log_query(query, "vector", results, latency_ms, chapter=chapter)
  572. return results
  573. def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
  574. """计算余弦相似度"""
  575. dot_product = sum(x * y for x, y in zip(a, b))
  576. norm_a = math.sqrt(sum(x * x for x in a))
  577. norm_b = math.sqrt(sum(x * x for x in b))
  578. if norm_a == 0 or norm_b == 0:
  579. return 0.0
  580. return dot_product / (norm_a * norm_b)
  581. # ==================== BM25 检索 ====================
  582. def bm25_search(
  583. self,
  584. query: str,
  585. top_k: int = None,
  586. k1: float = 1.5,
  587. b: float = 0.75,
  588. chunk_type: str | None = None,
  589. log_query: bool = True,
  590. chapter: int | None = None,
  591. ) -> List[SearchResult]:
  592. """BM25 关键词搜索"""
  593. top_k = top_k or self.config.bm25_top_k
  594. start_time = time.perf_counter()
  595. query_terms = self._tokenize(query)
  596. if not query_terms:
  597. return []
  598. with self._get_conn() as conn:
  599. cursor = conn.cursor()
  600. # 获取文档总数和平均长度
  601. cursor.execute("SELECT COUNT(*), AVG(doc_length) FROM doc_stats")
  602. row = cursor.fetchone()
  603. total_docs = row[0] or 1
  604. avg_doc_length = row[1] or 1
  605. # 计算每个文档的 BM25 分数
  606. doc_scores = {}
  607. for term in set(query_terms):
  608. # 获取包含该词的文档
  609. cursor.execute("""
  610. SELECT b.chunk_id, b.tf, d.doc_length
  611. FROM bm25_index b
  612. JOIN doc_stats d ON b.chunk_id = d.chunk_id
  613. WHERE b.term = ?
  614. """, (term,))
  615. docs_with_term = cursor.fetchall()
  616. df = len(docs_with_term)
  617. if df == 0:
  618. continue
  619. # IDF
  620. idf = math.log((total_docs - df + 0.5) / (df + 0.5) + 1)
  621. for chunk_id, tf, doc_length in docs_with_term:
  622. # BM25 公式
  623. score = idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * doc_length / avg_doc_length))
  624. if chunk_id not in doc_scores:
  625. doc_scores[chunk_id] = 0
  626. doc_scores[chunk_id] += score
  627. # 获取文档内容
  628. results = []
  629. for chunk_id, score in doc_scores.items():
  630. if chunk_type and chapter is not None:
  631. cursor.execute(
  632. """
  633. SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
  634. FROM vectors
  635. WHERE chunk_id = ? AND chunk_type = ? AND chapter <= ?
  636. """,
  637. (chunk_id, chunk_type, int(chapter)),
  638. )
  639. elif chunk_type:
  640. cursor.execute(
  641. """
  642. SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
  643. FROM vectors
  644. WHERE chunk_id = ? AND chunk_type = ?
  645. """,
  646. (chunk_id, chunk_type),
  647. )
  648. elif chapter is not None:
  649. cursor.execute(
  650. """
  651. SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
  652. FROM vectors
  653. WHERE chunk_id = ? AND chapter <= ?
  654. """,
  655. (chunk_id, int(chapter)),
  656. )
  657. else:
  658. cursor.execute(
  659. """
  660. SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
  661. FROM vectors
  662. WHERE chunk_id = ?
  663. """,
  664. (chunk_id,),
  665. )
  666. row = cursor.fetchone()
  667. if row:
  668. results.append(SearchResult(
  669. chunk_id=chunk_id,
  670. chapter=row[0],
  671. scene_index=row[1],
  672. content=row[2],
  673. score=score,
  674. source="bm25",
  675. parent_chunk_id=row[3],
  676. chunk_type=row[4],
  677. source_file=row[5],
  678. ))
  679. results.sort(key=lambda x: x.score, reverse=True)
  680. results = results[:top_k]
  681. if log_query:
  682. latency_ms = int((time.perf_counter() - start_time) * 1000)
  683. self._log_query(query, "bm25", results, latency_ms, chapter=chapter)
  684. return results
  685. def _extract_query_seed_entities(self, query: str) -> List[str]:
  686. """从查询中提取种子实体(通过别名和实体 ID 匹配)。"""
  687. tokens = set(re.findall(r"[\u4e00-\u9fff]{2,8}|[A-Za-z][A-Za-z0-9_]{1,24}", query))
  688. entity_ids: List[str] = []
  689. for token in tokens:
  690. if len(entity_ids) >= int(self.config.graph_rag_max_expanded_entities):
  691. break
  692. # 1) 通过别名匹配
  693. alias_hits = self.index_manager.get_entities_by_alias(token)
  694. for hit in alias_hits:
  695. entity_id = str(hit.get("id") or "").strip()
  696. if entity_id and entity_id not in entity_ids:
  697. entity_ids.append(entity_id)
  698. if len(entity_ids) >= int(self.config.graph_rag_max_expanded_entities):
  699. break
  700. # 2) 通过实体 ID 直匹配
  701. entity = self.index_manager.get_entity(token)
  702. if entity:
  703. entity_id = str(entity.get("id") or "").strip()
  704. if entity_id and entity_id not in entity_ids:
  705. entity_ids.append(entity_id)
  706. return entity_ids[: int(self.config.graph_rag_max_expanded_entities)]
  707. def _normalize_entity_ids(self, candidates: List[str]) -> List[str]:
  708. """将输入实体候选(名称/别名/ID)规范化为实体 ID 列表。"""
  709. ids: List[str] = []
  710. for token in candidates:
  711. candidate = str(token or "").strip()
  712. if not candidate:
  713. continue
  714. direct = self.index_manager.get_entity(candidate)
  715. if direct and direct.get("id"):
  716. entity_id = str(direct.get("id"))
  717. if entity_id not in ids:
  718. ids.append(entity_id)
  719. continue
  720. for hit in self.index_manager.get_entities_by_alias(candidate):
  721. entity_id = str(hit.get("id") or "").strip()
  722. if entity_id and entity_id not in ids:
  723. ids.append(entity_id)
  724. return ids[: int(self.config.graph_rag_max_expanded_entities)]
  725. def _expand_related_entities(self, seed_entities: List[str], hops: int | None = None) -> List[str]:
  726. """基于关系图扩展相关实体。"""
  727. max_entities = int(self.config.graph_rag_max_expanded_entities)
  728. hops = max(1, int(hops or self.config.graph_rag_expand_hops))
  729. expanded: List[str] = []
  730. for seed in seed_entities:
  731. if seed not in expanded:
  732. expanded.append(seed)
  733. if len(expanded) >= max_entities:
  734. break
  735. graph = self.index_manager.build_relationship_subgraph(
  736. center_entity=seed,
  737. depth=hops,
  738. top_edges=max(20, int(self.config.graph_rag_candidate_limit)),
  739. )
  740. for node in graph.get("nodes", []):
  741. entity_id = str(node.get("id") or "").strip()
  742. if entity_id and entity_id not in expanded:
  743. expanded.append(entity_id)
  744. if len(expanded) >= max_entities:
  745. break
  746. if len(expanded) >= max_entities:
  747. break
  748. return expanded[:max_entities]
  749. def _collect_graph_candidate_chunk_ids(
  750. self,
  751. entity_ids: List[str],
  752. *,
  753. chapter: int | None = None,
  754. limit: int | None = None,
  755. ) -> List[str]:
  756. """根据实体名称/别名在向量库正文中筛选候选 chunk。"""
  757. if not entity_ids:
  758. return []
  759. limit = int(limit or self.config.graph_rag_candidate_limit)
  760. entity_terms: Dict[str, set[str]] = {}
  761. for entity_id in entity_ids:
  762. terms: set[str] = set()
  763. entity = self.index_manager.get_entity(entity_id)
  764. if entity:
  765. canonical_name = str(entity.get("canonical_name") or "").strip()
  766. if canonical_name:
  767. terms.add(canonical_name)
  768. for alias in self.index_manager.get_entity_aliases(entity_id):
  769. alias_text = str(alias or "").strip()
  770. if alias_text:
  771. terms.add(alias_text)
  772. if terms:
  773. entity_terms[entity_id] = terms
  774. if not entity_terms:
  775. return []
  776. with self._get_conn() as conn:
  777. cursor = conn.cursor()
  778. if chapter is None:
  779. cursor.execute(
  780. "SELECT chunk_id, chapter, content FROM vectors ORDER BY chapter DESC, scene_index DESC"
  781. )
  782. else:
  783. cursor.execute(
  784. """
  785. SELECT chunk_id, chapter, content
  786. FROM vectors
  787. WHERE chapter <= ?
  788. ORDER BY chapter DESC, scene_index DESC
  789. """,
  790. (int(chapter),),
  791. )
  792. rows = cursor.fetchall()
  793. scored: List[Tuple[str, int, int]] = []
  794. for chunk_id, chapter_no, content in rows:
  795. text = str(content or "")
  796. if not text:
  797. continue
  798. hit_score = 0
  799. for terms in entity_terms.values():
  800. hit_score += sum(1 for term in terms if term and term in text)
  801. if hit_score > 0:
  802. scored.append((str(chunk_id), int(chapter_no or 0), hit_score))
  803. scored.sort(key=lambda x: (x[2], x[1]), reverse=True)
  804. return [chunk_id for chunk_id, _chapter, _score in scored[:limit]]
  805. async def _vector_search_by_chunk_ids(
  806. self,
  807. query: str,
  808. chunk_ids: List[str],
  809. *,
  810. top_k: int,
  811. chunk_type: str | None = None,
  812. ) -> List[SearchResult]:
  813. """在指定候选 chunk 范围内执行向量检索。"""
  814. if not chunk_ids:
  815. return []
  816. query_embeddings = await self.api_client.embed([query])
  817. if not query_embeddings:
  818. self._update_degraded_mode()
  819. return []
  820. self._degraded_mode_reason = None
  821. query_embedding = query_embeddings[0]
  822. rows = await asyncio.to_thread(self._fetch_vectors_by_chunk_ids, chunk_ids)
  823. if chunk_type:
  824. rows = [r for r in rows if len(r) > 6 and r[6] == chunk_type]
  825. return await asyncio.to_thread(
  826. self._vector_search_rows,
  827. query_embedding,
  828. rows,
  829. top_k=top_k,
  830. )
  831. def _apply_graph_priors(
  832. self,
  833. result: SearchResult,
  834. *,
  835. seed_terms: set[str],
  836. related_terms: set[str],
  837. max_chapter: int,
  838. ) -> float:
  839. """为图谱候选增加先验分。"""
  840. score = float(result.score)
  841. content = str(result.content or "")
  842. if any(term and term in content for term in seed_terms):
  843. score += float(self.config.graph_rag_boost_same_entity)
  844. elif any(term and term in content for term in related_terms):
  845. score += float(self.config.graph_rag_boost_related_entity)
  846. if max_chapter > 0 and result.chapter is not None:
  847. gap = max(0, max_chapter - int(result.chapter))
  848. recency = max(0.0, 1.0 - min(gap, 100) / 100.0)
  849. score += recency * float(self.config.graph_rag_boost_recency)
  850. return score
  851. async def graph_hybrid_search(
  852. self,
  853. query: str,
  854. top_k: int = 5,
  855. *,
  856. chunk_type: str | None = None,
  857. chapter: int | None = None,
  858. center_entities: Optional[List[str]] = None,
  859. log_query: bool = True,
  860. ) -> List[SearchResult]:
  861. """
  862. 图谱增强混合检索:
  863. 1) 先走现有 hybrid 作为基础召回;
  864. 2) 基于实体关系图扩展候选;
  865. 3) 向量重算 + 图谱先验融合;
  866. 4) rerank 产出最终结果。
  867. """
  868. start_time = time.perf_counter()
  869. base_results = await self.hybrid_search(
  870. query=query,
  871. vector_top_k=max(top_k * 3, int(self.config.vector_top_k)),
  872. bm25_top_k=max(top_k * 3, int(self.config.bm25_top_k)),
  873. rerank_top_n=max(top_k * 2, int(self.config.rerank_top_n)),
  874. chunk_type=chunk_type,
  875. chapter=chapter,
  876. log_query=False,
  877. )
  878. if not bool(self.config.graph_rag_enabled):
  879. final = list(base_results)[:top_k]
  880. if log_query:
  881. latency_ms = int((time.perf_counter() - start_time) * 1000)
  882. self._log_query(query, "graph_hybrid_fallback", final, latency_ms, chapter=chapter)
  883. return final
  884. seeds = self._normalize_entity_ids([s for s in (center_entities or []) if str(s).strip()])
  885. if not seeds:
  886. seeds = self._extract_query_seed_entities(query)
  887. if not seeds:
  888. final = list(base_results)[:top_k]
  889. if log_query:
  890. latency_ms = int((time.perf_counter() - start_time) * 1000)
  891. self._log_query(query, "graph_hybrid_no_seed", final, latency_ms, chapter=chapter)
  892. return final
  893. expanded_entities = self._expand_related_entities(seeds)
  894. candidate_chunk_ids = self._collect_graph_candidate_chunk_ids(
  895. expanded_entities,
  896. chapter=chapter,
  897. limit=max(top_k * 8, int(self.config.graph_rag_candidate_limit)),
  898. )
  899. graph_vector_results = await self._vector_search_by_chunk_ids(
  900. query,
  901. candidate_chunk_ids,
  902. top_k=max(top_k * 4, int(self.config.rerank_top_n) * 2),
  903. chunk_type=chunk_type,
  904. )
  905. # 构建实体术语集用于先验分
  906. seed_terms: set[str] = set()
  907. related_terms: set[str] = set()
  908. for idx, entity_id in enumerate(expanded_entities):
  909. entity = self.index_manager.get_entity(entity_id)
  910. canonical_name = str((entity or {}).get("canonical_name") or "").strip()
  911. aliases = [str(a).strip() for a in self.index_manager.get_entity_aliases(entity_id)]
  912. terms = {t for t in [canonical_name, *aliases] if t}
  913. if idx < len(seeds):
  914. seed_terms.update(terms)
  915. else:
  916. related_terms.update(terms)
  917. max_chapter = 0
  918. try:
  919. max_chapter = int(self.get_stats().get("max_chapter") or 0)
  920. except Exception:
  921. max_chapter = 0
  922. if chapter is not None:
  923. try:
  924. max_chapter = int(chapter)
  925. except (TypeError, ValueError):
  926. pass
  927. merged: Dict[str, SearchResult] = {}
  928. for result in base_results:
  929. result.source = "graph_hybrid"
  930. merged[result.chunk_id] = result
  931. for result in graph_vector_results:
  932. adjusted = self._apply_graph_priors(
  933. result,
  934. seed_terms=seed_terms,
  935. related_terms=related_terms,
  936. max_chapter=max_chapter,
  937. )
  938. result.score = adjusted
  939. result.source = "graph_hybrid"
  940. existing = merged.get(result.chunk_id)
  941. if existing is None or result.score > existing.score:
  942. merged[result.chunk_id] = result
  943. sorted_candidates = sorted(merged.values(), key=lambda r: r.score, reverse=True)
  944. candidates = sorted_candidates[: max(top_k * 3, int(self.config.rerank_top_n) * 2)]
  945. if not candidates:
  946. if log_query:
  947. latency_ms = int((time.perf_counter() - start_time) * 1000)
  948. self._log_query(query, "graph_hybrid", [], latency_ms, chapter=chapter)
  949. return []
  950. rerank_top_n = max(top_k, int(self.config.rerank_top_n))
  951. rerank_input = [c.content for c in candidates]
  952. rerank_results = await self.api_client.rerank(query, rerank_input, top_n=rerank_top_n)
  953. final_results: List[SearchResult] = []
  954. if rerank_results:
  955. for item in rerank_results:
  956. idx = int(item.get("index", 0))
  957. if idx < 0 or idx >= len(candidates):
  958. continue
  959. picked = candidates[idx]
  960. picked.score = float(item.get("relevance_score", picked.score))
  961. picked.source = "graph_hybrid"
  962. final_results.append(picked)
  963. else:
  964. final_results = candidates[:rerank_top_n]
  965. final_results = final_results[:top_k]
  966. if log_query:
  967. latency_ms = int((time.perf_counter() - start_time) * 1000)
  968. self._log_query(query, "graph_hybrid", final_results, latency_ms, chapter=chapter)
  969. return final_results
  970. async def search(
  971. self,
  972. query: str,
  973. top_k: int = 5,
  974. *,
  975. strategy: str = "auto",
  976. chunk_type: str | None = None,
  977. chapter: int | None = None,
  978. center_entities: Optional[List[str]] = None,
  979. filters: Optional[Dict[str, Any]] = None,
  980. ) -> List[SearchResult]:
  981. """统一检索入口。"""
  982. strategy = str(strategy or "auto").lower()
  983. if filters and chapter is None:
  984. try:
  985. chapter = int((filters or {}).get("to_chapter") or 0) or None
  986. except (TypeError, ValueError):
  987. chapter = None
  988. if strategy == "auto":
  989. intent_payload = self.query_router.route_intent(query)
  990. if bool(self.config.graph_rag_enabled) and bool(intent_payload.get("needs_graph")):
  991. strategy = "graph_hybrid"
  992. if not center_entities:
  993. center_entities = list(intent_payload.get("entities") or [])
  994. else:
  995. strategy = "hybrid"
  996. if strategy not in {"vector", "bm25", "backtrack", "graph_hybrid", "hybrid"}:
  997. # 未知策略统一降级 hybrid,避免调用方传错参数导致中断。
  998. strategy = "hybrid"
  999. if strategy == "vector":
  1000. return await self.vector_search(query, top_k=top_k, chunk_type=chunk_type, chapter=chapter)
  1001. if strategy == "bm25":
  1002. return self.bm25_search(query, top_k=top_k, chunk_type=chunk_type, chapter=chapter)
  1003. if strategy == "backtrack":
  1004. return await self.search_with_backtrack(query, top_k=top_k)
  1005. if strategy == "graph_hybrid":
  1006. return await self.graph_hybrid_search(
  1007. query=query,
  1008. top_k=top_k,
  1009. chunk_type=chunk_type,
  1010. chapter=chapter,
  1011. center_entities=center_entities,
  1012. )
  1013. return await self.hybrid_search(
  1014. query=query,
  1015. vector_top_k=top_k,
  1016. bm25_top_k=top_k,
  1017. rerank_top_n=top_k,
  1018. chunk_type=chunk_type,
  1019. chapter=chapter,
  1020. )
    # ==================== Hybrid search ====================
  1022. async def hybrid_search(
  1023. self,
  1024. query: str,
  1025. vector_top_k: int = None,
  1026. bm25_top_k: int = None,
  1027. rerank_top_n: int = None,
  1028. chunk_type: str | None = None,
  1029. chapter: int | None = None,
  1030. log_query: bool = True,
  1031. ) -> List[SearchResult]:
  1032. """
  1033. 混合检索:向量 + BM25 + RRF 融合 + Rerank
  1034. 步骤:
  1035. 1. 向量检索 top_k
  1036. 2. BM25 检索 top_k
  1037. 3. RRF 融合
  1038. 4. Rerank 精排
  1039. """
  1040. vector_top_k = vector_top_k or self.config.vector_top_k
  1041. bm25_top_k = bm25_top_k or self.config.bm25_top_k
  1042. rerank_top_n = rerank_top_n or self.config.rerank_top_n
  1043. start_time = time.perf_counter()
  1044. # 小规模:全表向量扫描(召回更稳);大规模:预筛选避免 O(n) 扫描拖慢
  1045. vectors_count = await asyncio.to_thread(self._get_vectors_count)
  1046. use_full_scan = vectors_count <= int(self.config.vector_full_scan_max_vectors)
  1047. if use_full_scan:
  1048. # 并行执行向量和 BM25 检索
  1049. vector_results, bm25_results = await asyncio.gather(
  1050. self.vector_search(query, vector_top_k, chunk_type=chunk_type, log_query=False, chapter=chapter),
  1051. asyncio.to_thread(self.bm25_search, query, bm25_top_k, 1.5, 0.75, chunk_type, False, chapter),
  1052. )
  1053. else:
  1054. bm25_candidates = max(
  1055. int(self.config.vector_prefilter_bm25_candidates),
  1056. int(bm25_top_k),
  1057. int(vector_top_k) * 5,
  1058. int(rerank_top_n) * 10,
  1059. )
  1060. recent_candidates = max(
  1061. int(self.config.vector_prefilter_recent_candidates),
  1062. int(vector_top_k) * 5,
  1063. int(rerank_top_n) * 10,
  1064. )
  1065. bm25_task = asyncio.to_thread(
  1066. self.bm25_search,
  1067. query,
  1068. bm25_candidates,
  1069. 1.5,
  1070. 0.75,
  1071. chunk_type,
  1072. False,
  1073. chapter,
  1074. )
  1075. recent_task = asyncio.to_thread(self._get_recent_chunk_ids, recent_candidates, chunk_type, chapter)
  1076. embed_task = self.api_client.embed([query])
  1077. bm25_candidates_results, recent_ids, query_embeddings = await asyncio.gather(
  1078. bm25_task,
  1079. recent_task,
  1080. embed_task,
  1081. )
  1082. if not query_embeddings:
  1083. self._update_degraded_mode()
  1084. return []
  1085. self._degraded_mode_reason = None
  1086. query_embedding = query_embeddings[0]
  1087. candidate_ids = {r.chunk_id for r in bm25_candidates_results}
  1088. candidate_ids.update(recent_ids)
  1089. rows = await asyncio.to_thread(self._fetch_vectors_by_chunk_ids, list(candidate_ids))
  1090. if chunk_type:
  1091. rows = [r for r in rows if len(r) > 6 and r[6] == chunk_type]
  1092. if chapter is not None:
  1093. rows = [r for r in rows if len(r) > 1 and int(r[1] or 0) <= int(chapter)]
  1094. vector_results = await asyncio.to_thread(
  1095. self._vector_search_rows,
  1096. query_embedding,
  1097. rows,
  1098. top_k=int(vector_top_k),
  1099. )
  1100. # BM25 结果用于融合时只取 top_k
  1101. bm25_results = list(bm25_candidates_results)[: int(bm25_top_k)]
  1102. # RRF 融合
  1103. rrf_scores = {}
  1104. k = self.config.rrf_k
  1105. for rank, result in enumerate(vector_results):
  1106. if result.chunk_id not in rrf_scores:
  1107. rrf_scores[result.chunk_id] = {"result": result, "score": 0}
  1108. rrf_scores[result.chunk_id]["score"] += 1 / (k + rank + 1)
  1109. for rank, result in enumerate(bm25_results):
  1110. if result.chunk_id not in rrf_scores:
  1111. rrf_scores[result.chunk_id] = {"result": result, "score": 0}
  1112. rrf_scores[result.chunk_id]["score"] += 1 / (k + rank + 1)
  1113. # 按 RRF 分数排序
  1114. sorted_results = sorted(
  1115. rrf_scores.values(),
  1116. key=lambda x: x["score"],
  1117. reverse=True
  1118. )
  1119. # 取 top candidates 进行 rerank
  1120. candidates = [item["result"] for item in sorted_results[:rerank_top_n * 2]]
  1121. if not candidates:
  1122. final_results: List[SearchResult] = []
  1123. latency_ms = int((time.perf_counter() - start_time) * 1000)
  1124. if log_query:
  1125. self._log_query(query, "hybrid", final_results, latency_ms, chapter=chapter)
  1126. return final_results
  1127. # 调用 Rerank API
  1128. documents = [c.content for c in candidates]
  1129. rerank_results = await self.api_client.rerank(query, documents, top_n=rerank_top_n)
  1130. if not rerank_results:
  1131. # Rerank 失败,返回 RRF 结果
  1132. final_results = [item["result"] for item in sorted_results[:rerank_top_n]]
  1133. latency_ms = int((time.perf_counter() - start_time) * 1000)
  1134. if log_query:
  1135. self._log_query(query, "hybrid", final_results, latency_ms, chapter=chapter)
  1136. return final_results
  1137. # 组装最终结果
  1138. final_results = []
  1139. for r in rerank_results:
  1140. idx = r.get("index", 0)
  1141. if idx < len(candidates):
  1142. result = candidates[idx]
  1143. result.score = r.get("relevance_score", 0)
  1144. result.source = "hybrid"
  1145. final_results.append(result)
  1146. latency_ms = int((time.perf_counter() - start_time) * 1000)
  1147. if log_query:
  1148. self._log_query(query, "hybrid", final_results, latency_ms, chapter=chapter)
  1149. return final_results
  1150. def _get_chunks_by_ids(self, chunk_ids: List[str]) -> List[SearchResult]:
  1151. rows = self._fetch_vectors_by_chunk_ids(chunk_ids)
  1152. results: List[SearchResult] = []
  1153. for row in rows:
  1154. (
  1155. chunk_id,
  1156. chapter,
  1157. scene_index,
  1158. content,
  1159. _embedding_bytes,
  1160. parent_chunk_id,
  1161. chunk_type,
  1162. source_file,
  1163. ) = row
  1164. results.append(
  1165. SearchResult(
  1166. chunk_id=chunk_id,
  1167. chapter=chapter,
  1168. scene_index=scene_index,
  1169. content=content,
  1170. score=0.0,
  1171. source="parent",
  1172. parent_chunk_id=parent_chunk_id,
  1173. chunk_type=chunk_type,
  1174. source_file=source_file,
  1175. )
  1176. )
  1177. return results
  1178. def _merge_results(
  1179. self,
  1180. parents: List[SearchResult],
  1181. children: List[SearchResult],
  1182. ) -> List[SearchResult]:
  1183. parent_map = {p.chunk_id: p for p in parents}
  1184. merged: List[SearchResult] = []
  1185. seen = set()
  1186. for child in children:
  1187. parent_id = child.parent_chunk_id
  1188. if parent_id and parent_id in parent_map and parent_id not in seen:
  1189. merged.append(parent_map[parent_id])
  1190. seen.add(parent_id)
  1191. merged.append(child)
  1192. return merged
  1193. async def search_with_backtrack(self, query: str, top_k: int = 5) -> List[SearchResult]:
  1194. start_time = time.perf_counter()
  1195. child_results = await self.hybrid_search(
  1196. query,
  1197. vector_top_k=top_k * 2,
  1198. bm25_top_k=top_k * 2,
  1199. rerank_top_n=top_k,
  1200. chunk_type="scene",
  1201. log_query=False,
  1202. )
  1203. parent_ids = sorted({r.parent_chunk_id for r in child_results if r.parent_chunk_id})
  1204. parents = self._get_chunks_by_ids(parent_ids) if parent_ids else []
  1205. merged = self._merge_results(parents, child_results[:top_k])
  1206. latency_ms = int((time.perf_counter() - start_time) * 1000)
  1207. self._log_query(query, "backtrack", merged, latency_ms)
  1208. return merged
    # ==================== Statistics ====================
  1210. def get_stats(self) -> Dict[str, int]:
  1211. """获取 RAG 统计"""
  1212. with self._get_conn() as conn:
  1213. cursor = conn.cursor()
  1214. cursor.execute("SELECT COUNT(*) FROM vectors")
  1215. vectors = cursor.fetchone()[0]
  1216. cursor.execute("SELECT COUNT(DISTINCT term) FROM bm25_index")
  1217. terms = cursor.fetchone()[0]
  1218. cursor.execute("SELECT MAX(chapter) FROM vectors")
  1219. max_chapter = cursor.fetchone()[0] or 0
  1220. return {
  1221. "vectors": vectors,
  1222. "terms": terms,
  1223. "max_chapter": max_chapter
  1224. }
# ==================== CLI interface ====================
  1226. def main():
  1227. import argparse
  1228. import sys
  1229. from .cli_output import print_success, print_error
  1230. from .cli_args import normalize_global_project_root, load_json_arg
  1231. parser = argparse.ArgumentParser(description="RAG Adapter CLI")
  1232. parser.add_argument("--project-root", type=str, help="项目根目录")
  1233. subparsers = parser.add_subparsers(dest="command")
  1234. # 获取统计
  1235. subparsers.add_parser("stats")
  1236. # 写入索引
  1237. index_parser = subparsers.add_parser("index-chapter")
  1238. index_parser.add_argument("--chapter", type=int, required=True)
  1239. index_parser.add_argument("--scenes", required=True, help="JSON 格式的场景列表")
  1240. index_parser.add_argument("--summary", required=False, help="章节摘要文本")
  1241. # 搜索
  1242. search_parser = subparsers.add_parser("search")
  1243. search_parser.add_argument("--query", required=True)
  1244. search_parser.add_argument(
  1245. "--mode",
  1246. choices=["auto", "vector", "bm25", "hybrid", "graph_hybrid", "backtrack"],
  1247. default="hybrid",
  1248. )
  1249. search_parser.add_argument("--top-k", type=int, default=5)
  1250. search_parser.add_argument("--chunk-type", choices=["scene", "summary"], default=None)
  1251. search_parser.add_argument(
  1252. "--center-entities",
  1253. required=False,
  1254. help="中心实体列表(JSON 数组或逗号分隔)",
  1255. )
  1256. argv = normalize_global_project_root(sys.argv[1:])
  1257. args = parser.parse_args(argv)
  1258. command_started_at = time.perf_counter()
  1259. # 初始化
  1260. config = None
  1261. if args.project_root:
  1262. # 允许传入“工作区根目录”,统一解析到真正的 book project_root(必须包含 .webnovel/state.json)
  1263. from project_locator import resolve_project_root
  1264. from .config import DataModulesConfig
  1265. resolved_root = resolve_project_root(args.project_root)
  1266. config = DataModulesConfig.from_project_root(resolved_root)
  1267. adapter = RAGAdapter(config)
  1268. tool_name = f"rag_adapter:{args.command or 'unknown'}"
  1269. def _append_timing(success: bool, *, error_code: str | None = None, error_message: str | None = None, chapter: int | None = None):
  1270. elapsed_ms = int((time.perf_counter() - command_started_at) * 1000)
  1271. safe_append_perf_timing(
  1272. adapter.config.project_root,
  1273. tool_name=tool_name,
  1274. success=success,
  1275. elapsed_ms=elapsed_ms,
  1276. chapter=chapter,
  1277. error_code=error_code,
  1278. error_message=error_message,
  1279. )
  1280. def emit_success(data=None, message: str = "ok", chapter: int | None = None):
  1281. print_success(data, message=message)
  1282. safe_log_tool_call(adapter.index_manager, tool_name=tool_name, success=True)
  1283. _append_timing(True, chapter=chapter)
  1284. def emit_error(code: str, message: str, suggestion: str | None = None, chapter: int | None = None):
  1285. print_error(code, message, suggestion=suggestion)
  1286. safe_log_tool_call(
  1287. adapter.index_manager,
  1288. tool_name=tool_name,
  1289. success=False,
  1290. error_code=code,
  1291. error_message=message,
  1292. )
  1293. _append_timing(False, error_code=code, error_message=message, chapter=chapter)
  1294. if args.command == "stats":
  1295. stats = adapter.get_stats()
  1296. emit_success(stats, message="stats")
  1297. elif args.command == "index-chapter":
  1298. scenes = load_json_arg(args.scenes)
  1299. chunks = []
  1300. # summary chunk
  1301. summary_text = args.summary
  1302. if not summary_text and config:
  1303. summary_path = config.webnovel_dir / "summaries" / f"ch{args.chapter:04d}.md"
  1304. if summary_path.exists():
  1305. summary_text = summary_path.read_text(encoding="utf-8")
  1306. parent_chunk_id = None
  1307. if summary_text:
  1308. parent_chunk_id = f"ch{args.chapter:04d}_summary"
  1309. chunks.append(
  1310. {
  1311. "chapter": args.chapter,
  1312. "scene_index": 0,
  1313. "content": summary_text,
  1314. "chunk_type": "summary",
  1315. "chunk_id": parent_chunk_id,
  1316. "source_file": f"summaries/ch{args.chapter:04d}.md",
  1317. }
  1318. )
  1319. for s in scenes:
  1320. scene_index = s.get("index", 0)
  1321. chunk_id = f"ch{args.chapter:04d}_s{int(scene_index)}"
  1322. chunks.append(
  1323. {
  1324. "chapter": args.chapter,
  1325. "scene_index": scene_index,
  1326. "content": s.get("content", ""),
  1327. "chunk_type": "scene",
  1328. "parent_chunk_id": parent_chunk_id,
  1329. "chunk_id": chunk_id,
  1330. "source_file": f"正文/第{args.chapter:04d}章.md#scene_{int(scene_index)}",
  1331. }
  1332. )
  1333. stored = asyncio.run(adapter.store_chunks(chunks))
  1334. skipped = len(chunks) - stored
  1335. result = {"stored": stored, "skipped": skipped, "total": len(chunks)}
  1336. if skipped > 0:
  1337. emit_success(result, message="indexed_with_warnings", chapter=args.chapter)
  1338. else:
  1339. emit_success(result, message="indexed", chapter=args.chapter)
  1340. elif args.command == "search":
  1341. center_entities: List[str] | None = None
  1342. if getattr(args, "center_entities", None):
  1343. raw = str(args.center_entities).strip()
  1344. if raw:
  1345. try:
  1346. parsed = json.loads(raw)
  1347. if isinstance(parsed, list):
  1348. center_entities = [str(x).strip() for x in parsed if str(x).strip()]
  1349. except Exception:
  1350. center_entities = [x.strip() for x in re.split(r"[,,;;\s]+", raw) if x.strip()]
  1351. if args.mode == "vector":
  1352. results = asyncio.run(adapter.vector_search(args.query, args.top_k, chunk_type=args.chunk_type))
  1353. elif args.mode == "bm25":
  1354. results = adapter.bm25_search(args.query, args.top_k, chunk_type=args.chunk_type)
  1355. elif args.mode == "backtrack":
  1356. results = asyncio.run(adapter.search_with_backtrack(args.query, args.top_k))
  1357. elif args.mode == "graph_hybrid":
  1358. results = asyncio.run(
  1359. adapter.graph_hybrid_search(
  1360. args.query,
  1361. args.top_k,
  1362. chunk_type=args.chunk_type,
  1363. center_entities=center_entities,
  1364. )
  1365. )
  1366. elif args.mode == "auto":
  1367. results = asyncio.run(
  1368. adapter.search(
  1369. args.query,
  1370. args.top_k,
  1371. strategy="auto",
  1372. chunk_type=args.chunk_type,
  1373. center_entities=center_entities,
  1374. )
  1375. )
  1376. else:
  1377. results = asyncio.run(adapter.hybrid_search(args.query, args.top_k, args.top_k, args.top_k, chunk_type=args.chunk_type))
  1378. payload = [r.__dict__ for r in results]
  1379. degraded_reason = adapter.degraded_mode_reason
  1380. if degraded_reason:
  1381. warnings = [{"code": "DEGRADED_MODE", "reason": degraded_reason}]
  1382. print_success(payload, message="search_results", warnings=warnings)
  1383. safe_log_tool_call(adapter.index_manager, tool_name=tool_name, success=True)
  1384. _append_timing(True)
  1385. else:
  1386. emit_success(payload, message="search_results")
  1387. else:
  1388. emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")
# Script entry point: only runs when the module is executed directly.
if __name__ == "__main__":
    import sys
    # On Windows, switch stdio to UTF-8 before any output is produced.
    # NOTE(review): enable_windows_utf8_stdio is defined or imported outside
    # this section; behaviour inferred from its name only - confirm.
    if sys.platform == "win32":
        enable_windows_utf8_stdio()
    main()