# test_rag_adapter.py (17 KB)
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. RAGAdapter tests
  5. """
  6. import sys
  7. import json
  8. import asyncio
  9. import logging
  10. import sqlite3
  11. from contextlib import closing
  12. import pytest
  13. import data_modules.rag_adapter as rag_module
  14. from data_modules.rag_adapter import RAGAdapter
  15. from data_modules.config import DataModulesConfig
  16. from data_modules.index_manager import EntityMeta, RelationshipMeta
  17. class StubClient:
  18. async def embed(self, texts):
  19. return [[1.0, 0.0] for _ in texts]
  20. async def embed_batch(self, texts, skip_failures=True):
  21. return [[1.0, 0.0] for _ in texts]
  22. async def rerank(self, query, documents, top_n=None):
  23. top_n = top_n or len(documents)
  24. return [{"index": i, "relevance_score": 1.0 / (i + 1)} for i in range(min(top_n, len(documents)))]
  25. class StubClientWithFailures(StubClient):
  26. async def embed_batch(self, texts, skip_failures=True):
  27. if len(texts) == 1:
  28. return [None]
  29. return [None, [1.0, 0.0]]
  30. class StubEmbedClient401:
  31. def __init__(self):
  32. self.last_error_status = 401
  33. self.last_error_message = "auth failed"
  34. class StubClientAuthFailure(StubClient):
  35. def __init__(self):
  36. self._embed_client = StubEmbedClient401()
  37. async def embed(self, texts):
  38. return None
  39. class StubClientRerankFailure(StubClient):
  40. async def rerank(self, query, documents, top_n=None):
  41. return []
  42. @pytest.fixture
  43. def temp_project(tmp_path, monkeypatch):
  44. cfg = DataModulesConfig.from_project_root(tmp_path)
  45. cfg.ensure_dirs()
  46. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  47. return cfg
  48. @pytest.mark.asyncio
  49. async def test_store_and_search(temp_project):
  50. adapter = RAGAdapter(temp_project)
  51. chunks = [
  52. {"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"},
  53. {"chapter": 1, "scene_index": 2, "content": "药老传授炼药技巧"},
  54. ]
  55. stored = await adapter.store_chunks(chunks)
  56. assert stored == 2
  57. vec_results = await adapter.vector_search("萧炎", top_k=2)
  58. assert len(vec_results) == 2
  59. bm25_results = adapter.bm25_search("萧炎", top_k=2)
  60. assert len(bm25_results) >= 1
  61. stats = adapter.get_stats()
  62. assert stats["vectors"] == 2
  63. @pytest.mark.asyncio
  64. async def test_store_chunks_with_embedding_failure(tmp_path, monkeypatch):
  65. cfg = DataModulesConfig.from_project_root(tmp_path)
  66. cfg.ensure_dirs()
  67. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientWithFailures())
  68. adapter = RAGAdapter(cfg)
  69. chunks = [
  70. {"chapter": 1, "scene_index": 1, "content": "短内容"},
  71. {"chapter": 1, "scene_index": 2, "content": "稍长内容用于索引"},
  72. ]
  73. stored = await adapter.store_chunks(chunks)
  74. assert stored == 1
  75. @pytest.mark.asyncio
  76. async def test_hybrid_search_full_scan(temp_project):
  77. adapter = RAGAdapter(temp_project)
  78. await adapter.store_chunks(
  79. [{"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}]
  80. )
  81. results = await adapter.hybrid_search("萧炎", vector_top_k=5, bm25_top_k=5, rerank_top_n=1)
  82. assert results
  83. assert results[0].source == "hybrid"
  84. @pytest.mark.asyncio
  85. async def test_hybrid_search_prefilter(tmp_path, monkeypatch):
  86. cfg = DataModulesConfig.from_project_root(tmp_path)
  87. cfg.ensure_dirs()
  88. cfg.vector_full_scan_max_vectors = 0
  89. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  90. adapter = RAGAdapter(cfg)
  91. await adapter.store_chunks(
  92. [
  93. {"chapter": 1, "scene_index": 1, "content": "萧炎修炼"},
  94. {"chapter": 2, "scene_index": 1, "content": "药老出场"},
  95. ]
  96. )
  97. results = await adapter.hybrid_search("药老", vector_top_k=2, bm25_top_k=2, rerank_top_n=1)
  98. assert results
  99. @pytest.mark.asyncio
  100. async def test_search_respects_chapter_filter_across_strategies(tmp_path, monkeypatch):
  101. cfg = DataModulesConfig.from_project_root(tmp_path)
  102. cfg.ensure_dirs()
  103. cfg.vector_full_scan_max_vectors = 0 # 强制走预筛选分支
  104. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  105. adapter = RAGAdapter(cfg)
  106. await adapter.store_chunks(
  107. [
  108. {"chapter": 1, "scene_index": 1, "content": "前文线索,尚未涉及关键宝物"},
  109. {"chapter": 2, "scene_index": 1, "content": "秘宝现世,引发争夺"},
  110. {"chapter": 3, "scene_index": 1, "content": "秘宝大战彻底爆发"},
  111. ]
  112. )
  113. vector_results = await adapter.vector_search("秘宝", top_k=5, chapter=1)
  114. assert vector_results
  115. assert all((r.chapter or 0) <= 1 for r in vector_results)
  116. bm25_results = adapter.bm25_search("秘宝", top_k=5, chapter=1)
  117. assert bm25_results
  118. assert all((r.chapter or 0) <= 1 for r in bm25_results)
  119. hybrid_results = await adapter.hybrid_search(
  120. "秘宝",
  121. vector_top_k=5,
  122. bm25_top_k=5,
  123. rerank_top_n=3,
  124. chapter=1,
  125. )
  126. assert hybrid_results
  127. assert all((r.chapter or 0) <= 1 for r in hybrid_results)
  128. @pytest.mark.asyncio
  129. async def test_graph_hybrid_search_with_entity_expansion(tmp_path, monkeypatch):
  130. cfg = DataModulesConfig.from_project_root(tmp_path)
  131. cfg.ensure_dirs()
  132. cfg.graph_rag_enabled = True
  133. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  134. adapter = RAGAdapter(cfg)
  135. adapter.index_manager.upsert_entity(
  136. EntityMeta(
  137. id="xiaoyan",
  138. type="角色",
  139. canonical_name="萧炎",
  140. current={},
  141. first_appearance=1,
  142. last_appearance=2,
  143. )
  144. )
  145. adapter.index_manager.upsert_entity(
  146. EntityMeta(
  147. id="yaolao",
  148. type="角色",
  149. canonical_name="药老",
  150. current={},
  151. first_appearance=1,
  152. last_appearance=2,
  153. )
  154. )
  155. adapter.index_manager.register_alias("萧炎", "xiaoyan", "角色")
  156. adapter.index_manager.register_alias("药老", "yaolao", "角色")
  157. adapter.index_manager.upsert_relationship(
  158. RelationshipMeta(
  159. from_entity="xiaoyan",
  160. to_entity="yaolao",
  161. type="师徒",
  162. description="收徒",
  163. chapter=1,
  164. )
  165. )
  166. await adapter.store_chunks(
  167. [
  168. {"chapter": 1, "scene_index": 1, "content": "萧炎拜药老为师,正式成为师徒"},
  169. {"chapter": 2, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"},
  170. ]
  171. )
  172. results = await adapter.graph_hybrid_search(
  173. "萧炎和药老关系",
  174. top_k=2,
  175. center_entities=["萧炎", "药老"],
  176. )
  177. assert results
  178. assert any("药老" in r.content for r in results)
  179. assert all(r.source == "graph_hybrid" for r in results)
  180. @pytest.mark.asyncio
  181. async def test_search_auto_uses_graph_strategy_when_enabled(tmp_path, monkeypatch):
  182. cfg = DataModulesConfig.from_project_root(tmp_path)
  183. cfg.ensure_dirs()
  184. cfg.graph_rag_enabled = True
  185. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  186. adapter = RAGAdapter(cfg)
  187. adapter.index_manager.upsert_entity(
  188. EntityMeta(
  189. id="xiaoyan",
  190. type="角色",
  191. canonical_name="萧炎",
  192. current={},
  193. first_appearance=1,
  194. last_appearance=1,
  195. )
  196. )
  197. adapter.index_manager.register_alias("萧炎", "xiaoyan", "角色")
  198. await adapter.store_chunks(
  199. [{"chapter": 1, "scene_index": 1, "content": "萧炎突破斗师"}]
  200. )
  201. results = await adapter.search("萧炎关系", top_k=1, strategy="auto")
  202. assert results
  203. assert results[0].source in {"graph_hybrid", "hybrid"}
  204. @pytest.mark.asyncio
  205. async def test_graph_hybrid_search_fallback_when_graph_disabled(tmp_path, monkeypatch):
  206. cfg = DataModulesConfig.from_project_root(tmp_path)
  207. cfg.ensure_dirs()
  208. cfg.graph_rag_enabled = False
  209. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  210. adapter = RAGAdapter(cfg)
  211. await adapter.store_chunks(
  212. [{"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"}]
  213. )
  214. modes = []
  215. def _record_log(query, mode, results, latency_ms, chapter=None):
  216. modes.append(mode)
  217. monkeypatch.setattr(adapter, "_log_query", _record_log)
  218. results = await adapter.graph_hybrid_search("萧炎关系", top_k=1)
  219. assert results
  220. assert modes
  221. assert modes[-1] == "graph_hybrid_fallback"
  222. assert all(r.source == "hybrid" for r in results)
  223. @pytest.mark.asyncio
  224. async def test_graph_hybrid_search_rerank_failure_uses_candidates(tmp_path, monkeypatch):
  225. cfg = DataModulesConfig.from_project_root(tmp_path)
  226. cfg.ensure_dirs()
  227. cfg.graph_rag_enabled = True
  228. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientRerankFailure())
  229. adapter = RAGAdapter(cfg)
  230. adapter.index_manager.upsert_entity(
  231. EntityMeta(
  232. id="xiaoyan",
  233. type="角色",
  234. canonical_name="萧炎",
  235. current={},
  236. first_appearance=1,
  237. last_appearance=2,
  238. )
  239. )
  240. adapter.index_manager.upsert_entity(
  241. EntityMeta(
  242. id="yaolao",
  243. type="角色",
  244. canonical_name="药老",
  245. current={},
  246. first_appearance=1,
  247. last_appearance=2,
  248. )
  249. )
  250. adapter.index_manager.register_alias("萧炎", "xiaoyan", "角色")
  251. adapter.index_manager.register_alias("药老", "yaolao", "角色")
  252. adapter.index_manager.upsert_relationship(
  253. RelationshipMeta(
  254. from_entity="xiaoyan",
  255. to_entity="yaolao",
  256. type="师徒",
  257. description="收徒",
  258. chapter=1,
  259. )
  260. )
  261. await adapter.store_chunks(
  262. [
  263. {"chapter": 1, "scene_index": 1, "content": "萧炎拜药老为师,正式成为师徒"},
  264. {"chapter": 2, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"},
  265. ]
  266. )
  267. results = await adapter.graph_hybrid_search(
  268. "萧炎和药老关系",
  269. top_k=2,
  270. center_entities=["萧炎", "药老"],
  271. )
  272. assert results
  273. assert len(results) <= 2
  274. assert all(r.source == "graph_hybrid" for r in results)
  275. @pytest.mark.asyncio
  276. async def test_search_unknown_strategy_falls_back_to_hybrid(tmp_path, monkeypatch):
  277. cfg = DataModulesConfig.from_project_root(tmp_path)
  278. cfg.ensure_dirs()
  279. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  280. adapter = RAGAdapter(cfg)
  281. await adapter.store_chunks(
  282. [{"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"}]
  283. )
  284. results = await adapter.search("萧炎", top_k=1, strategy="not_exists")
  285. assert results
  286. assert all(r.source == "hybrid" for r in results)
  287. @pytest.mark.asyncio
  288. async def test_search_with_backtrack(temp_project):
  289. adapter = RAGAdapter(temp_project)
  290. chunks = [
  291. {
  292. "chapter": 1,
  293. "scene_index": 0,
  294. "content": "章节摘要",
  295. "chunk_type": "summary",
  296. "chunk_id": "ch0001_summary",
  297. "source_file": "summaries/ch0001.md",
  298. },
  299. {
  300. "chapter": 1,
  301. "scene_index": 1,
  302. "content": "场景内容",
  303. "chunk_type": "scene",
  304. "chunk_id": "ch0001_s1",
  305. "parent_chunk_id": "ch0001_summary",
  306. "source_file": "正文/第0001章.md#scene_1",
  307. },
  308. ]
  309. await adapter.store_chunks(chunks)
  310. results = await adapter.search_with_backtrack("场景", top_k=1)
  311. assert any(r.chunk_type == "summary" for r in results)
  312. def test_vector_helpers(temp_project):
  313. adapter = RAGAdapter(temp_project)
  314. emb = [1.0, 0.0]
  315. data = adapter._serialize_embedding(emb)
  316. assert adapter._deserialize_embedding(data) == emb
  317. assert adapter._cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0
  318. def test_recent_and_fetch_vectors(temp_project):
  319. adapter = RAGAdapter(temp_project)
  320. with adapter._get_conn() as conn:
  321. cursor = conn.cursor()
  322. cursor.execute(
  323. "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
  324. ("ch0001_s1", 1, 1, "内容", b"", None, "scene", "正文/第0001章.md#scene_1"),
  325. )
  326. cursor.execute(
  327. "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
  328. ("ch0002_s1", 2, 1, "后文内容", b"", None, "scene", "正文/第0002章.md#scene_1"),
  329. )
  330. conn.commit()
  331. assert adapter._get_vectors_count() == 2
  332. assert adapter._get_recent_chunk_ids(1) == ["ch0002_s1"]
  333. assert adapter._get_recent_chunk_ids(10, chapter=1) == ["ch0001_s1"]
  334. rows = adapter._fetch_vectors_by_chunk_ids(["ch0001_s1"])
  335. assert len(rows) == 1
  336. def test_init_db_migrates_legacy_vectors_schema(tmp_path, monkeypatch):
  337. cfg = DataModulesConfig.from_project_root(tmp_path)
  338. cfg.ensure_dirs()
  339. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  340. # 旧结构:缺少 parent_chunk_id/chunk_type/source_file/created_at
  341. with closing(sqlite3.connect(str(cfg.vector_db))) as conn:
  342. cursor = conn.cursor()
  343. cursor.execute(
  344. """
  345. CREATE TABLE vectors (
  346. chunk_id TEXT PRIMARY KEY,
  347. chapter INTEGER,
  348. scene_index INTEGER,
  349. content TEXT,
  350. embedding BLOB
  351. )
  352. """
  353. )
  354. cursor.execute(
  355. """
  356. INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding)
  357. VALUES (?, ?, ?, ?, ?)
  358. """,
  359. ("ch0001_s1", 1, 1, "旧数据", b""),
  360. )
  361. conn.commit()
  362. adapter = RAGAdapter(cfg)
  363. with adapter._get_conn() as conn:
  364. cursor = conn.cursor()
  365. cursor.execute("PRAGMA table_info(vectors)")
  366. cols = {row[1] for row in cursor.fetchall()}
  367. assert {"parent_chunk_id", "chunk_type", "source_file", "created_at"}.issubset(cols)
  368. cursor.execute("SELECT COUNT(*) FROM vectors")
  369. assert cursor.fetchone()[0] == 1
  370. cursor.execute("SELECT chunk_type FROM vectors WHERE chunk_id = ?", ("ch0001_s1",))
  371. row = cursor.fetchone()
  372. assert row is not None
  373. assert row[0] == "scene"
  374. backup_dir = cfg.webnovel_dir / "backups"
  375. backups = list(backup_dir.glob("vectors.db.schema_migration.v*.bak"))
  376. assert backups
  377. def test_rag_adapter_cli(temp_project, monkeypatch, capsys):
  378. # stats
  379. def run_cli(args):
  380. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  381. rag_module.main()
  382. root = str(temp_project.project_root)
  383. run_cli(["--project-root", root, "stats"])
  384. # index-chapter
  385. run_cli(
  386. [
  387. "--project-root",
  388. root,
  389. "index-chapter",
  390. "--chapter",
  391. "1",
  392. "--scenes",
  393. json.dumps([{"index": 1, "summary": "摘要", "content": "内容"}], ensure_ascii=False),
  394. ]
  395. )
  396. # search
  397. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "bm25", "--top-k", "5"])
  398. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "vector", "--top-k", "5"])
  399. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "hybrid", "--top-k", "5"])
  400. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "auto", "--top-k", "5"])
  401. capsys.readouterr()
  402. def test_rag_adapter_log_query_failure_is_reported(temp_project, monkeypatch, caplog):
  403. adapter = RAGAdapter(temp_project)
  404. def _raise_log_error(*args, **kwargs):
  405. raise RuntimeError("log write failed")
  406. monkeypatch.setattr(adapter.index_manager, "log_rag_query", _raise_log_error)
  407. with caplog.at_level(logging.WARNING):
  408. adapter._log_query("q", "vector", [], 1)
  409. message_text = "\n".join(record.getMessage() for record in caplog.records)
  410. assert "failed to log rag query" in message_text
  411. def test_rag_adapter_cli_search_shows_degraded_warning(temp_project, monkeypatch, capsys):
  412. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientAuthFailure())
  413. def run_cli(args):
  414. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  415. rag_module.main()
  416. root = str(temp_project.project_root)
  417. run_cli(["--project-root", root, "search", "--query", "测试", "--mode", "vector", "--top-k", "3"])
  418. captured = capsys.readouterr()
  419. payload = json.loads(captured.out.strip().splitlines()[-1])
  420. assert payload.get("status") == "success"
  421. warnings = payload.get("warnings") or []
  422. assert warnings
  423. assert warnings[0].get("code") == "DEGRADED_MODE"
  424. assert warnings[0].get("reason") == "embedding_auth_failed"