test_rag_adapter.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. RAGAdapter tests
  5. """
  6. import sys
  7. import json
  8. import asyncio
  9. import logging
  10. import sqlite3
  11. from contextlib import closing
  12. import pytest
  13. import data_modules.rag_adapter as rag_module
  14. from data_modules.rag_adapter import RAGAdapter
  15. from data_modules.config import DataModulesConfig
  16. from data_modules.index_manager import EntityMeta, RelationshipMeta
  17. class StubClient:
  18. async def embed(self, texts):
  19. return [[1.0, 0.0] for _ in texts]
  20. async def embed_batch(self, texts, skip_failures=True):
  21. return [[1.0, 0.0] for _ in texts]
  22. async def rerank(self, query, documents, top_n=None):
  23. top_n = top_n or len(documents)
  24. return [{"index": i, "relevance_score": 1.0 / (i + 1)} for i in range(min(top_n, len(documents)))]
  25. class StubClientWithFailures(StubClient):
  26. async def embed_batch(self, texts, skip_failures=True):
  27. if len(texts) == 1:
  28. return [None]
  29. return [None, [1.0, 0.0]]
  30. class StubEmbedClient401:
  31. def __init__(self):
  32. self.last_error_status = 401
  33. self.last_error_message = "auth failed"
  34. class StubClientAuthFailure(StubClient):
  35. def __init__(self):
  36. self._embed_client = StubEmbedClient401()
  37. async def embed(self, texts):
  38. return None
  39. @pytest.fixture
  40. def temp_project(tmp_path, monkeypatch):
  41. cfg = DataModulesConfig.from_project_root(tmp_path)
  42. cfg.ensure_dirs()
  43. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  44. return cfg
  45. @pytest.mark.asyncio
  46. async def test_store_and_search(temp_project):
  47. adapter = RAGAdapter(temp_project)
  48. chunks = [
  49. {"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"},
  50. {"chapter": 1, "scene_index": 2, "content": "药老传授炼药技巧"},
  51. ]
  52. stored = await adapter.store_chunks(chunks)
  53. assert stored == 2
  54. vec_results = await adapter.vector_search("萧炎", top_k=2)
  55. assert len(vec_results) == 2
  56. bm25_results = adapter.bm25_search("萧炎", top_k=2)
  57. assert len(bm25_results) >= 1
  58. stats = adapter.get_stats()
  59. assert stats["vectors"] == 2
  60. @pytest.mark.asyncio
  61. async def test_store_chunks_with_embedding_failure(tmp_path, monkeypatch):
  62. cfg = DataModulesConfig.from_project_root(tmp_path)
  63. cfg.ensure_dirs()
  64. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientWithFailures())
  65. adapter = RAGAdapter(cfg)
  66. chunks = [
  67. {"chapter": 1, "scene_index": 1, "content": "短内容"},
  68. {"chapter": 1, "scene_index": 2, "content": "稍长内容用于索引"},
  69. ]
  70. stored = await adapter.store_chunks(chunks)
  71. assert stored == 1
  72. @pytest.mark.asyncio
  73. async def test_hybrid_search_full_scan(temp_project):
  74. adapter = RAGAdapter(temp_project)
  75. await adapter.store_chunks(
  76. [{"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}]
  77. )
  78. results = await adapter.hybrid_search("萧炎", vector_top_k=5, bm25_top_k=5, rerank_top_n=1)
  79. assert results
  80. assert results[0].source == "hybrid"
  81. @pytest.mark.asyncio
  82. async def test_hybrid_search_prefilter(tmp_path, monkeypatch):
  83. cfg = DataModulesConfig.from_project_root(tmp_path)
  84. cfg.ensure_dirs()
  85. cfg.vector_full_scan_max_vectors = 0
  86. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  87. adapter = RAGAdapter(cfg)
  88. await adapter.store_chunks(
  89. [
  90. {"chapter": 1, "scene_index": 1, "content": "萧炎修炼"},
  91. {"chapter": 2, "scene_index": 1, "content": "药老出场"},
  92. ]
  93. )
  94. results = await adapter.hybrid_search("药老", vector_top_k=2, bm25_top_k=2, rerank_top_n=1)
  95. assert results
  96. @pytest.mark.asyncio
  97. async def test_search_respects_chapter_filter_across_strategies(tmp_path, monkeypatch):
  98. cfg = DataModulesConfig.from_project_root(tmp_path)
  99. cfg.ensure_dirs()
  100. cfg.vector_full_scan_max_vectors = 0 # 强制走预筛选分支
  101. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  102. adapter = RAGAdapter(cfg)
  103. await adapter.store_chunks(
  104. [
  105. {"chapter": 1, "scene_index": 1, "content": "前文线索,尚未涉及关键宝物"},
  106. {"chapter": 2, "scene_index": 1, "content": "秘宝现世,引发争夺"},
  107. {"chapter": 3, "scene_index": 1, "content": "秘宝大战彻底爆发"},
  108. ]
  109. )
  110. vector_results = await adapter.vector_search("秘宝", top_k=5, chapter=1)
  111. assert vector_results
  112. assert all((r.chapter or 0) <= 1 for r in vector_results)
  113. bm25_results = adapter.bm25_search("秘宝", top_k=5, chapter=1)
  114. assert bm25_results
  115. assert all((r.chapter or 0) <= 1 for r in bm25_results)
  116. hybrid_results = await adapter.hybrid_search(
  117. "秘宝",
  118. vector_top_k=5,
  119. bm25_top_k=5,
  120. rerank_top_n=3,
  121. chapter=1,
  122. )
  123. assert hybrid_results
  124. assert all((r.chapter or 0) <= 1 for r in hybrid_results)
  125. @pytest.mark.asyncio
  126. async def test_graph_hybrid_search_with_entity_expansion(tmp_path, monkeypatch):
  127. cfg = DataModulesConfig.from_project_root(tmp_path)
  128. cfg.ensure_dirs()
  129. cfg.graph_rag_enabled = True
  130. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  131. adapter = RAGAdapter(cfg)
  132. adapter.index_manager.upsert_entity(
  133. EntityMeta(
  134. id="xiaoyan",
  135. type="角色",
  136. canonical_name="萧炎",
  137. current={},
  138. first_appearance=1,
  139. last_appearance=2,
  140. )
  141. )
  142. adapter.index_manager.upsert_entity(
  143. EntityMeta(
  144. id="yaolao",
  145. type="角色",
  146. canonical_name="药老",
  147. current={},
  148. first_appearance=1,
  149. last_appearance=2,
  150. )
  151. )
  152. adapter.index_manager.register_alias("萧炎", "xiaoyan", "角色")
  153. adapter.index_manager.register_alias("药老", "yaolao", "角色")
  154. adapter.index_manager.upsert_relationship(
  155. RelationshipMeta(
  156. from_entity="xiaoyan",
  157. to_entity="yaolao",
  158. type="师徒",
  159. description="收徒",
  160. chapter=1,
  161. )
  162. )
  163. await adapter.store_chunks(
  164. [
  165. {"chapter": 1, "scene_index": 1, "content": "萧炎拜药老为师,正式成为师徒"},
  166. {"chapter": 2, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"},
  167. ]
  168. )
  169. results = await adapter.graph_hybrid_search(
  170. "萧炎和药老关系",
  171. top_k=2,
  172. center_entities=["萧炎", "药老"],
  173. )
  174. assert results
  175. assert any("药老" in r.content for r in results)
  176. assert all(r.source == "graph_hybrid" for r in results)
  177. @pytest.mark.asyncio
  178. async def test_search_auto_uses_graph_strategy_when_enabled(tmp_path, monkeypatch):
  179. cfg = DataModulesConfig.from_project_root(tmp_path)
  180. cfg.ensure_dirs()
  181. cfg.graph_rag_enabled = True
  182. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  183. adapter = RAGAdapter(cfg)
  184. adapter.index_manager.upsert_entity(
  185. EntityMeta(
  186. id="xiaoyan",
  187. type="角色",
  188. canonical_name="萧炎",
  189. current={},
  190. first_appearance=1,
  191. last_appearance=1,
  192. )
  193. )
  194. adapter.index_manager.register_alias("萧炎", "xiaoyan", "角色")
  195. await adapter.store_chunks(
  196. [{"chapter": 1, "scene_index": 1, "content": "萧炎突破斗师"}]
  197. )
  198. results = await adapter.search("萧炎关系", top_k=1, strategy="auto")
  199. assert results
  200. assert results[0].source in {"graph_hybrid", "hybrid"}
  201. @pytest.mark.asyncio
  202. async def test_search_with_backtrack(temp_project):
  203. adapter = RAGAdapter(temp_project)
  204. chunks = [
  205. {
  206. "chapter": 1,
  207. "scene_index": 0,
  208. "content": "章节摘要",
  209. "chunk_type": "summary",
  210. "chunk_id": "ch0001_summary",
  211. "source_file": "summaries/ch0001.md",
  212. },
  213. {
  214. "chapter": 1,
  215. "scene_index": 1,
  216. "content": "场景内容",
  217. "chunk_type": "scene",
  218. "chunk_id": "ch0001_s1",
  219. "parent_chunk_id": "ch0001_summary",
  220. "source_file": "正文/第0001章.md#scene_1",
  221. },
  222. ]
  223. await adapter.store_chunks(chunks)
  224. results = await adapter.search_with_backtrack("场景", top_k=1)
  225. assert any(r.chunk_type == "summary" for r in results)
  226. def test_vector_helpers(temp_project):
  227. adapter = RAGAdapter(temp_project)
  228. emb = [1.0, 0.0]
  229. data = adapter._serialize_embedding(emb)
  230. assert adapter._deserialize_embedding(data) == emb
  231. assert adapter._cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0
  232. def test_recent_and_fetch_vectors(temp_project):
  233. adapter = RAGAdapter(temp_project)
  234. with adapter._get_conn() as conn:
  235. cursor = conn.cursor()
  236. cursor.execute(
  237. "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
  238. ("ch0001_s1", 1, 1, "内容", b"", None, "scene", "正文/第0001章.md#scene_1"),
  239. )
  240. cursor.execute(
  241. "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
  242. ("ch0002_s1", 2, 1, "后文内容", b"", None, "scene", "正文/第0002章.md#scene_1"),
  243. )
  244. conn.commit()
  245. assert adapter._get_vectors_count() == 2
  246. assert adapter._get_recent_chunk_ids(1) == ["ch0002_s1"]
  247. assert adapter._get_recent_chunk_ids(10, chapter=1) == ["ch0001_s1"]
  248. rows = adapter._fetch_vectors_by_chunk_ids(["ch0001_s1"])
  249. assert len(rows) == 1
  250. def test_init_db_migrates_legacy_vectors_schema(tmp_path, monkeypatch):
  251. cfg = DataModulesConfig.from_project_root(tmp_path)
  252. cfg.ensure_dirs()
  253. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  254. # 旧结构:缺少 parent_chunk_id/chunk_type/source_file/created_at
  255. with closing(sqlite3.connect(str(cfg.vector_db))) as conn:
  256. cursor = conn.cursor()
  257. cursor.execute(
  258. """
  259. CREATE TABLE vectors (
  260. chunk_id TEXT PRIMARY KEY,
  261. chapter INTEGER,
  262. scene_index INTEGER,
  263. content TEXT,
  264. embedding BLOB
  265. )
  266. """
  267. )
  268. cursor.execute(
  269. """
  270. INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding)
  271. VALUES (?, ?, ?, ?, ?)
  272. """,
  273. ("ch0001_s1", 1, 1, "旧数据", b""),
  274. )
  275. conn.commit()
  276. adapter = RAGAdapter(cfg)
  277. with adapter._get_conn() as conn:
  278. cursor = conn.cursor()
  279. cursor.execute("PRAGMA table_info(vectors)")
  280. cols = {row[1] for row in cursor.fetchall()}
  281. assert {"parent_chunk_id", "chunk_type", "source_file", "created_at"}.issubset(cols)
  282. cursor.execute("SELECT COUNT(*) FROM vectors")
  283. assert cursor.fetchone()[0] == 1
  284. cursor.execute("SELECT chunk_type FROM vectors WHERE chunk_id = ?", ("ch0001_s1",))
  285. row = cursor.fetchone()
  286. assert row is not None
  287. assert row[0] == "scene"
  288. backup_dir = cfg.webnovel_dir / "backups"
  289. backups = list(backup_dir.glob("vectors.db.schema_migration.v*.bak"))
  290. assert backups
  291. def test_rag_adapter_cli(temp_project, monkeypatch, capsys):
  292. # stats
  293. def run_cli(args):
  294. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  295. rag_module.main()
  296. root = str(temp_project.project_root)
  297. run_cli(["--project-root", root, "stats"])
  298. # index-chapter
  299. run_cli(
  300. [
  301. "--project-root",
  302. root,
  303. "index-chapter",
  304. "--chapter",
  305. "1",
  306. "--scenes",
  307. json.dumps([{"index": 1, "summary": "摘要", "content": "内容"}], ensure_ascii=False),
  308. ]
  309. )
  310. # search
  311. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "bm25", "--top-k", "5"])
  312. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "vector", "--top-k", "5"])
  313. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "hybrid", "--top-k", "5"])
  314. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "auto", "--top-k", "5"])
  315. capsys.readouterr()
  316. def test_rag_adapter_log_query_failure_is_reported(temp_project, monkeypatch, caplog):
  317. adapter = RAGAdapter(temp_project)
  318. def _raise_log_error(*args, **kwargs):
  319. raise RuntimeError("log write failed")
  320. monkeypatch.setattr(adapter.index_manager, "log_rag_query", _raise_log_error)
  321. with caplog.at_level(logging.WARNING):
  322. adapter._log_query("q", "vector", [], 1)
  323. message_text = "\n".join(record.getMessage() for record in caplog.records)
  324. assert "failed to log rag query" in message_text
  325. def test_rag_adapter_cli_search_shows_degraded_warning(temp_project, monkeypatch, capsys):
  326. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientAuthFailure())
  327. def run_cli(args):
  328. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  329. rag_module.main()
  330. root = str(temp_project.project_root)
  331. run_cli(["--project-root", root, "search", "--query", "测试", "--mode", "vector", "--top-k", "3"])
  332. captured = capsys.readouterr()
  333. payload = json.loads(captured.out.strip().splitlines()[-1])
  334. assert payload.get("status") == "success"
  335. warnings = payload.get("warnings") or []
  336. assert warnings
  337. assert warnings[0].get("code") == "DEGRADED_MODE"
  338. assert warnings[0].get("reason") == "embedding_auth_failed"