test_rag_adapter.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. RAGAdapter tests
  5. """
  6. import sys
  7. import json
  8. import asyncio
  9. import logging
  10. import pytest
  11. import data_modules.rag_adapter as rag_module
  12. from data_modules.rag_adapter import RAGAdapter
  13. from data_modules.config import DataModulesConfig
  14. class StubClient:
  15. async def embed(self, texts):
  16. return [[1.0, 0.0] for _ in texts]
  17. async def embed_batch(self, texts, skip_failures=True):
  18. return [[1.0, 0.0] for _ in texts]
  19. async def rerank(self, query, documents, top_n=None):
  20. top_n = top_n or len(documents)
  21. return [{"index": i, "relevance_score": 1.0 / (i + 1)} for i in range(min(top_n, len(documents)))]
  22. class StubClientWithFailures(StubClient):
  23. async def embed_batch(self, texts, skip_failures=True):
  24. if len(texts) == 1:
  25. return [None]
  26. return [None, [1.0, 0.0]]
  27. @pytest.fixture
  28. def temp_project(tmp_path, monkeypatch):
  29. cfg = DataModulesConfig.from_project_root(tmp_path)
  30. cfg.ensure_dirs()
  31. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  32. return cfg
  33. @pytest.mark.asyncio
  34. async def test_store_and_search(temp_project):
  35. adapter = RAGAdapter(temp_project)
  36. chunks = [
  37. {"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"},
  38. {"chapter": 1, "scene_index": 2, "content": "药老传授炼药技巧"},
  39. ]
  40. stored = await adapter.store_chunks(chunks)
  41. assert stored == 2
  42. vec_results = await adapter.vector_search("萧炎", top_k=2)
  43. assert len(vec_results) == 2
  44. bm25_results = adapter.bm25_search("萧炎", top_k=2)
  45. assert len(bm25_results) >= 1
  46. stats = adapter.get_stats()
  47. assert stats["vectors"] == 2
  48. @pytest.mark.asyncio
  49. async def test_store_chunks_with_embedding_failure(tmp_path, monkeypatch):
  50. cfg = DataModulesConfig.from_project_root(tmp_path)
  51. cfg.ensure_dirs()
  52. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientWithFailures())
  53. adapter = RAGAdapter(cfg)
  54. chunks = [
  55. {"chapter": 1, "scene_index": 1, "content": "短内容"},
  56. {"chapter": 1, "scene_index": 2, "content": "稍长内容用于索引"},
  57. ]
  58. stored = await adapter.store_chunks(chunks)
  59. assert stored == 1
  60. @pytest.mark.asyncio
  61. async def test_hybrid_search_full_scan(temp_project):
  62. adapter = RAGAdapter(temp_project)
  63. await adapter.store_chunks(
  64. [{"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}]
  65. )
  66. results = await adapter.hybrid_search("萧炎", vector_top_k=5, bm25_top_k=5, rerank_top_n=1)
  67. assert results
  68. assert results[0].source == "hybrid"
  69. @pytest.mark.asyncio
  70. async def test_hybrid_search_prefilter(tmp_path, monkeypatch):
  71. cfg = DataModulesConfig.from_project_root(tmp_path)
  72. cfg.ensure_dirs()
  73. cfg.vector_full_scan_max_vectors = 0
  74. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  75. adapter = RAGAdapter(cfg)
  76. await adapter.store_chunks(
  77. [
  78. {"chapter": 1, "scene_index": 1, "content": "萧炎修炼"},
  79. {"chapter": 2, "scene_index": 1, "content": "药老出场"},
  80. ]
  81. )
  82. results = await adapter.hybrid_search("药老", vector_top_k=2, bm25_top_k=2, rerank_top_n=1)
  83. assert results
  84. @pytest.mark.asyncio
  85. async def test_search_with_backtrack(temp_project):
  86. adapter = RAGAdapter(temp_project)
  87. chunks = [
  88. {
  89. "chapter": 1,
  90. "scene_index": 0,
  91. "content": "章节摘要",
  92. "chunk_type": "summary",
  93. "chunk_id": "ch0001_summary",
  94. "source_file": "summaries/ch0001.md",
  95. },
  96. {
  97. "chapter": 1,
  98. "scene_index": 1,
  99. "content": "场景内容",
  100. "chunk_type": "scene",
  101. "chunk_id": "ch0001_s1",
  102. "parent_chunk_id": "ch0001_summary",
  103. "source_file": "正文/第0001章.md#scene_1",
  104. },
  105. ]
  106. await adapter.store_chunks(chunks)
  107. results = await adapter.search_with_backtrack("场景", top_k=1)
  108. assert any(r.chunk_type == "summary" for r in results)
  109. def test_vector_helpers(temp_project):
  110. adapter = RAGAdapter(temp_project)
  111. emb = [1.0, 0.0]
  112. data = adapter._serialize_embedding(emb)
  113. assert adapter._deserialize_embedding(data) == emb
  114. assert adapter._cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0
  115. def test_recent_and_fetch_vectors(temp_project):
  116. adapter = RAGAdapter(temp_project)
  117. with adapter._get_conn() as conn:
  118. cursor = conn.cursor()
  119. cursor.execute(
  120. "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
  121. ("ch0001_s1", 1, 1, "内容", b"", None, "scene", "正文/第0001章.md#scene_1"),
  122. )
  123. conn.commit()
  124. assert adapter._get_vectors_count() == 1
  125. assert adapter._get_recent_chunk_ids(1) == ["ch0001_s1"]
  126. rows = adapter._fetch_vectors_by_chunk_ids(["ch0001_s1"])
  127. assert len(rows) == 1
  128. def test_rag_adapter_cli(temp_project, monkeypatch, capsys):
  129. # stats
  130. def run_cli(args):
  131. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  132. rag_module.main()
  133. root = str(temp_project.project_root)
  134. run_cli(["--project-root", root, "stats"])
  135. # index-chapter
  136. run_cli(
  137. [
  138. "--project-root",
  139. root,
  140. "index-chapter",
  141. "--chapter",
  142. "1",
  143. "--scenes",
  144. json.dumps([{"index": 1, "summary": "摘要", "content": "内容"}], ensure_ascii=False),
  145. ]
  146. )
  147. # search
  148. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "bm25", "--top-k", "5"])
  149. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "vector", "--top-k", "5"])
  150. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "hybrid", "--top-k", "5"])
  151. capsys.readouterr()
  152. def test_rag_adapter_log_query_failure_is_reported(temp_project, monkeypatch, caplog):
  153. adapter = RAGAdapter(temp_project)
  154. def _raise_log_error(*args, **kwargs):
  155. raise RuntimeError("log write failed")
  156. monkeypatch.setattr(adapter.index_manager, "log_rag_query", _raise_log_error)
  157. with caplog.at_level(logging.WARNING):
  158. adapter._log_query("q", "vector", [], 1)
  159. message_text = "\n".join(record.getMessage() for record in caplog.records)
  160. assert "failed to log rag query" in message_text