test_rag_adapter.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. RAGAdapter tests
  5. """
  6. import sys
  7. import json
  8. import asyncio
  9. import logging
  10. import pytest
  11. import data_modules.rag_adapter as rag_module
  12. from data_modules.rag_adapter import RAGAdapter
  13. from data_modules.config import DataModulesConfig
  14. class StubClient:
  15. async def embed(self, texts):
  16. return [[1.0, 0.0] for _ in texts]
  17. async def embed_batch(self, texts, skip_failures=True):
  18. return [[1.0, 0.0] for _ in texts]
  19. async def rerank(self, query, documents, top_n=None):
  20. top_n = top_n or len(documents)
  21. return [{"index": i, "relevance_score": 1.0 / (i + 1)} for i in range(min(top_n, len(documents)))]
  22. class StubClientWithFailures(StubClient):
  23. async def embed_batch(self, texts, skip_failures=True):
  24. if len(texts) == 1:
  25. return [None]
  26. return [None, [1.0, 0.0]]
  27. class StubEmbedClient401:
  28. def __init__(self):
  29. self.last_error_status = 401
  30. self.last_error_message = "auth failed"
  31. class StubClientAuthFailure(StubClient):
  32. def __init__(self):
  33. self._embed_client = StubEmbedClient401()
  34. async def embed(self, texts):
  35. return None
  36. @pytest.fixture
  37. def temp_project(tmp_path, monkeypatch):
  38. cfg = DataModulesConfig.from_project_root(tmp_path)
  39. cfg.ensure_dirs()
  40. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  41. return cfg
  42. @pytest.mark.asyncio
  43. async def test_store_and_search(temp_project):
  44. adapter = RAGAdapter(temp_project)
  45. chunks = [
  46. {"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"},
  47. {"chapter": 1, "scene_index": 2, "content": "药老传授炼药技巧"},
  48. ]
  49. stored = await adapter.store_chunks(chunks)
  50. assert stored == 2
  51. vec_results = await adapter.vector_search("萧炎", top_k=2)
  52. assert len(vec_results) == 2
  53. bm25_results = adapter.bm25_search("萧炎", top_k=2)
  54. assert len(bm25_results) >= 1
  55. stats = adapter.get_stats()
  56. assert stats["vectors"] == 2
  57. @pytest.mark.asyncio
  58. async def test_store_chunks_with_embedding_failure(tmp_path, monkeypatch):
  59. cfg = DataModulesConfig.from_project_root(tmp_path)
  60. cfg.ensure_dirs()
  61. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientWithFailures())
  62. adapter = RAGAdapter(cfg)
  63. chunks = [
  64. {"chapter": 1, "scene_index": 1, "content": "短内容"},
  65. {"chapter": 1, "scene_index": 2, "content": "稍长内容用于索引"},
  66. ]
  67. stored = await adapter.store_chunks(chunks)
  68. assert stored == 1
  69. @pytest.mark.asyncio
  70. async def test_hybrid_search_full_scan(temp_project):
  71. adapter = RAGAdapter(temp_project)
  72. await adapter.store_chunks(
  73. [{"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}]
  74. )
  75. results = await adapter.hybrid_search("萧炎", vector_top_k=5, bm25_top_k=5, rerank_top_n=1)
  76. assert results
  77. assert results[0].source == "hybrid"
  78. @pytest.mark.asyncio
  79. async def test_hybrid_search_prefilter(tmp_path, monkeypatch):
  80. cfg = DataModulesConfig.from_project_root(tmp_path)
  81. cfg.ensure_dirs()
  82. cfg.vector_full_scan_max_vectors = 0
  83. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  84. adapter = RAGAdapter(cfg)
  85. await adapter.store_chunks(
  86. [
  87. {"chapter": 1, "scene_index": 1, "content": "萧炎修炼"},
  88. {"chapter": 2, "scene_index": 1, "content": "药老出场"},
  89. ]
  90. )
  91. results = await adapter.hybrid_search("药老", vector_top_k=2, bm25_top_k=2, rerank_top_n=1)
  92. assert results
  93. @pytest.mark.asyncio
  94. async def test_search_with_backtrack(temp_project):
  95. adapter = RAGAdapter(temp_project)
  96. chunks = [
  97. {
  98. "chapter": 1,
  99. "scene_index": 0,
  100. "content": "章节摘要",
  101. "chunk_type": "summary",
  102. "chunk_id": "ch0001_summary",
  103. "source_file": "summaries/ch0001.md",
  104. },
  105. {
  106. "chapter": 1,
  107. "scene_index": 1,
  108. "content": "场景内容",
  109. "chunk_type": "scene",
  110. "chunk_id": "ch0001_s1",
  111. "parent_chunk_id": "ch0001_summary",
  112. "source_file": "正文/第0001章.md#scene_1",
  113. },
  114. ]
  115. await adapter.store_chunks(chunks)
  116. results = await adapter.search_with_backtrack("场景", top_k=1)
  117. assert any(r.chunk_type == "summary" for r in results)
  118. def test_vector_helpers(temp_project):
  119. adapter = RAGAdapter(temp_project)
  120. emb = [1.0, 0.0]
  121. data = adapter._serialize_embedding(emb)
  122. assert adapter._deserialize_embedding(data) == emb
  123. assert adapter._cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0
  124. def test_recent_and_fetch_vectors(temp_project):
  125. adapter = RAGAdapter(temp_project)
  126. with adapter._get_conn() as conn:
  127. cursor = conn.cursor()
  128. cursor.execute(
  129. "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
  130. ("ch0001_s1", 1, 1, "内容", b"", None, "scene", "正文/第0001章.md#scene_1"),
  131. )
  132. conn.commit()
  133. assert adapter._get_vectors_count() == 1
  134. assert adapter._get_recent_chunk_ids(1) == ["ch0001_s1"]
  135. rows = adapter._fetch_vectors_by_chunk_ids(["ch0001_s1"])
  136. assert len(rows) == 1
  137. def test_rag_adapter_cli(temp_project, monkeypatch, capsys):
  138. # stats
  139. def run_cli(args):
  140. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  141. rag_module.main()
  142. root = str(temp_project.project_root)
  143. run_cli(["--project-root", root, "stats"])
  144. # index-chapter
  145. run_cli(
  146. [
  147. "--project-root",
  148. root,
  149. "index-chapter",
  150. "--chapter",
  151. "1",
  152. "--scenes",
  153. json.dumps([{"index": 1, "summary": "摘要", "content": "内容"}], ensure_ascii=False),
  154. ]
  155. )
  156. # search
  157. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "bm25", "--top-k", "5"])
  158. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "vector", "--top-k", "5"])
  159. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "hybrid", "--top-k", "5"])
  160. capsys.readouterr()
  161. def test_rag_adapter_log_query_failure_is_reported(temp_project, monkeypatch, caplog):
  162. adapter = RAGAdapter(temp_project)
  163. def _raise_log_error(*args, **kwargs):
  164. raise RuntimeError("log write failed")
  165. monkeypatch.setattr(adapter.index_manager, "log_rag_query", _raise_log_error)
  166. with caplog.at_level(logging.WARNING):
  167. adapter._log_query("q", "vector", [], 1)
  168. message_text = "\n".join(record.getMessage() for record in caplog.records)
  169. assert "failed to log rag query" in message_text
  170. def test_rag_adapter_cli_search_shows_degraded_warning(temp_project, monkeypatch, capsys):
  171. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientAuthFailure())
  172. def run_cli(args):
  173. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  174. rag_module.main()
  175. root = str(temp_project.project_root)
  176. run_cli(["--project-root", root, "search", "--query", "测试", "--mode", "vector", "--top-k", "3"])
  177. captured = capsys.readouterr()
  178. payload = json.loads(captured.out.strip().splitlines()[-1])
  179. assert payload.get("status") == "success"
  180. warnings = payload.get("warnings") or []
  181. assert warnings
  182. assert warnings[0].get("code") == "DEGRADED_MODE"
  183. assert warnings[0].get("reason") == "embedding_auth_failed"