test_rag_adapter.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. RAGAdapter tests
  5. """
  6. import sys
  7. import json
  8. import asyncio
  9. import logging
  10. import sqlite3
  11. import pytest
  12. import data_modules.rag_adapter as rag_module
  13. from data_modules.rag_adapter import RAGAdapter
  14. from data_modules.config import DataModulesConfig
  15. class StubClient:
  16. async def embed(self, texts):
  17. return [[1.0, 0.0] for _ in texts]
  18. async def embed_batch(self, texts, skip_failures=True):
  19. return [[1.0, 0.0] for _ in texts]
  20. async def rerank(self, query, documents, top_n=None):
  21. top_n = top_n or len(documents)
  22. return [{"index": i, "relevance_score": 1.0 / (i + 1)} for i in range(min(top_n, len(documents)))]
  23. class StubClientWithFailures(StubClient):
  24. async def embed_batch(self, texts, skip_failures=True):
  25. if len(texts) == 1:
  26. return [None]
  27. return [None, [1.0, 0.0]]
  28. class StubEmbedClient401:
  29. def __init__(self):
  30. self.last_error_status = 401
  31. self.last_error_message = "auth failed"
  32. class StubClientAuthFailure(StubClient):
  33. def __init__(self):
  34. self._embed_client = StubEmbedClient401()
  35. async def embed(self, texts):
  36. return None
  37. @pytest.fixture
  38. def temp_project(tmp_path, monkeypatch):
  39. cfg = DataModulesConfig.from_project_root(tmp_path)
  40. cfg.ensure_dirs()
  41. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  42. return cfg
  43. @pytest.mark.asyncio
  44. async def test_store_and_search(temp_project):
  45. adapter = RAGAdapter(temp_project)
  46. chunks = [
  47. {"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"},
  48. {"chapter": 1, "scene_index": 2, "content": "药老传授炼药技巧"},
  49. ]
  50. stored = await adapter.store_chunks(chunks)
  51. assert stored == 2
  52. vec_results = await adapter.vector_search("萧炎", top_k=2)
  53. assert len(vec_results) == 2
  54. bm25_results = adapter.bm25_search("萧炎", top_k=2)
  55. assert len(bm25_results) >= 1
  56. stats = adapter.get_stats()
  57. assert stats["vectors"] == 2
  58. @pytest.mark.asyncio
  59. async def test_store_chunks_with_embedding_failure(tmp_path, monkeypatch):
  60. cfg = DataModulesConfig.from_project_root(tmp_path)
  61. cfg.ensure_dirs()
  62. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientWithFailures())
  63. adapter = RAGAdapter(cfg)
  64. chunks = [
  65. {"chapter": 1, "scene_index": 1, "content": "短内容"},
  66. {"chapter": 1, "scene_index": 2, "content": "稍长内容用于索引"},
  67. ]
  68. stored = await adapter.store_chunks(chunks)
  69. assert stored == 1
  70. @pytest.mark.asyncio
  71. async def test_hybrid_search_full_scan(temp_project):
  72. adapter = RAGAdapter(temp_project)
  73. await adapter.store_chunks(
  74. [{"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}]
  75. )
  76. results = await adapter.hybrid_search("萧炎", vector_top_k=5, bm25_top_k=5, rerank_top_n=1)
  77. assert results
  78. assert results[0].source == "hybrid"
  79. @pytest.mark.asyncio
  80. async def test_hybrid_search_prefilter(tmp_path, monkeypatch):
  81. cfg = DataModulesConfig.from_project_root(tmp_path)
  82. cfg.ensure_dirs()
  83. cfg.vector_full_scan_max_vectors = 0
  84. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  85. adapter = RAGAdapter(cfg)
  86. await adapter.store_chunks(
  87. [
  88. {"chapter": 1, "scene_index": 1, "content": "萧炎修炼"},
  89. {"chapter": 2, "scene_index": 1, "content": "药老出场"},
  90. ]
  91. )
  92. results = await adapter.hybrid_search("药老", vector_top_k=2, bm25_top_k=2, rerank_top_n=1)
  93. assert results
  94. @pytest.mark.asyncio
  95. async def test_search_with_backtrack(temp_project):
  96. adapter = RAGAdapter(temp_project)
  97. chunks = [
  98. {
  99. "chapter": 1,
  100. "scene_index": 0,
  101. "content": "章节摘要",
  102. "chunk_type": "summary",
  103. "chunk_id": "ch0001_summary",
  104. "source_file": "summaries/ch0001.md",
  105. },
  106. {
  107. "chapter": 1,
  108. "scene_index": 1,
  109. "content": "场景内容",
  110. "chunk_type": "scene",
  111. "chunk_id": "ch0001_s1",
  112. "parent_chunk_id": "ch0001_summary",
  113. "source_file": "正文/第0001章.md#scene_1",
  114. },
  115. ]
  116. await adapter.store_chunks(chunks)
  117. results = await adapter.search_with_backtrack("场景", top_k=1)
  118. assert any(r.chunk_type == "summary" for r in results)
  119. def test_vector_helpers(temp_project):
  120. adapter = RAGAdapter(temp_project)
  121. emb = [1.0, 0.0]
  122. data = adapter._serialize_embedding(emb)
  123. assert adapter._deserialize_embedding(data) == emb
  124. assert adapter._cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0
  125. def test_recent_and_fetch_vectors(temp_project):
  126. adapter = RAGAdapter(temp_project)
  127. with adapter._get_conn() as conn:
  128. cursor = conn.cursor()
  129. cursor.execute(
  130. "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
  131. ("ch0001_s1", 1, 1, "内容", b"", None, "scene", "正文/第0001章.md#scene_1"),
  132. )
  133. conn.commit()
  134. assert adapter._get_vectors_count() == 1
  135. assert adapter._get_recent_chunk_ids(1) == ["ch0001_s1"]
  136. rows = adapter._fetch_vectors_by_chunk_ids(["ch0001_s1"])
  137. assert len(rows) == 1
  138. def test_init_db_migrates_legacy_vectors_schema(tmp_path, monkeypatch):
  139. cfg = DataModulesConfig.from_project_root(tmp_path)
  140. cfg.ensure_dirs()
  141. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  142. # 旧结构:缺少 parent_chunk_id/chunk_type/source_file/created_at
  143. with sqlite3.connect(str(cfg.vector_db)) as conn:
  144. cursor = conn.cursor()
  145. cursor.execute(
  146. """
  147. CREATE TABLE vectors (
  148. chunk_id TEXT PRIMARY KEY,
  149. chapter INTEGER,
  150. scene_index INTEGER,
  151. content TEXT,
  152. embedding BLOB
  153. )
  154. """
  155. )
  156. cursor.execute(
  157. """
  158. INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding)
  159. VALUES (?, ?, ?, ?, ?)
  160. """,
  161. ("ch0001_s1", 1, 1, "旧数据", b""),
  162. )
  163. conn.commit()
  164. adapter = RAGAdapter(cfg)
  165. with adapter._get_conn() as conn:
  166. cursor = conn.cursor()
  167. cursor.execute("PRAGMA table_info(vectors)")
  168. cols = {row[1] for row in cursor.fetchall()}
  169. assert {"parent_chunk_id", "chunk_type", "source_file", "created_at"}.issubset(cols)
  170. cursor.execute("SELECT COUNT(*) FROM vectors")
  171. assert cursor.fetchone()[0] == 1
  172. cursor.execute("SELECT chunk_type FROM vectors WHERE chunk_id = ?", ("ch0001_s1",))
  173. row = cursor.fetchone()
  174. assert row is not None
  175. assert row[0] == "scene"
  176. backup_dir = cfg.webnovel_dir / "backups"
  177. backups = list(backup_dir.glob("vectors.db.schema_migration.v*.bak"))
  178. assert backups
  179. def test_rag_adapter_cli(temp_project, monkeypatch, capsys):
  180. # stats
  181. def run_cli(args):
  182. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  183. rag_module.main()
  184. root = str(temp_project.project_root)
  185. run_cli(["--project-root", root, "stats"])
  186. # index-chapter
  187. run_cli(
  188. [
  189. "--project-root",
  190. root,
  191. "index-chapter",
  192. "--chapter",
  193. "1",
  194. "--scenes",
  195. json.dumps([{"index": 1, "summary": "摘要", "content": "内容"}], ensure_ascii=False),
  196. ]
  197. )
  198. # search
  199. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "bm25", "--top-k", "5"])
  200. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "vector", "--top-k", "5"])
  201. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "hybrid", "--top-k", "5"])
  202. capsys.readouterr()
  203. def test_rag_adapter_log_query_failure_is_reported(temp_project, monkeypatch, caplog):
  204. adapter = RAGAdapter(temp_project)
  205. def _raise_log_error(*args, **kwargs):
  206. raise RuntimeError("log write failed")
  207. monkeypatch.setattr(adapter.index_manager, "log_rag_query", _raise_log_error)
  208. with caplog.at_level(logging.WARNING):
  209. adapter._log_query("q", "vector", [], 1)
  210. message_text = "\n".join(record.getMessage() for record in caplog.records)
  211. assert "failed to log rag query" in message_text
  212. def test_rag_adapter_cli_search_shows_degraded_warning(temp_project, monkeypatch, capsys):
  213. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientAuthFailure())
  214. def run_cli(args):
  215. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  216. rag_module.main()
  217. root = str(temp_project.project_root)
  218. run_cli(["--project-root", root, "search", "--query", "测试", "--mode", "vector", "--top-k", "3"])
  219. captured = capsys.readouterr()
  220. payload = json.loads(captured.out.strip().splitlines()[-1])
  221. assert payload.get("status") == "success"
  222. warnings = payload.get("warnings") or []
  223. assert warnings
  224. assert warnings[0].get("code") == "DEGRADED_MODE"
  225. assert warnings[0].get("reason") == "embedding_auth_failed"