#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
RAGAdapter tests
"""
import asyncio
import json
import sys

import pytest

import data_modules.rag_adapter as rag_module
from data_modules.rag_adapter import RAGAdapter
from data_modules.config import DataModulesConfig
  13. class StubClient:
  14. async def embed(self, texts):
  15. return [[1.0, 0.0] for _ in texts]
  16. async def embed_batch(self, texts, skip_failures=True):
  17. return [[1.0, 0.0] for _ in texts]
  18. async def rerank(self, query, documents, top_n=None):
  19. top_n = top_n or len(documents)
  20. return [{"index": i, "relevance_score": 1.0 / (i + 1)} for i in range(min(top_n, len(documents)))]
  21. class StubClientWithFailures(StubClient):
  22. async def embed_batch(self, texts, skip_failures=True):
  23. if len(texts) == 1:
  24. return [None]
  25. return [None, [1.0, 0.0]]
  26. @pytest.fixture
  27. def temp_project(tmp_path, monkeypatch):
  28. cfg = DataModulesConfig.from_project_root(tmp_path)
  29. cfg.ensure_dirs()
  30. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  31. return cfg
  32. @pytest.mark.asyncio
  33. async def test_store_and_search(temp_project):
  34. adapter = RAGAdapter(temp_project)
  35. chunks = [
  36. {"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"},
  37. {"chapter": 1, "scene_index": 2, "content": "药老传授炼药技巧"},
  38. ]
  39. stored = await adapter.store_chunks(chunks)
  40. assert stored == 2
  41. vec_results = await adapter.vector_search("萧炎", top_k=2)
  42. assert len(vec_results) == 2
  43. bm25_results = adapter.bm25_search("萧炎", top_k=2)
  44. assert len(bm25_results) >= 1
  45. stats = adapter.get_stats()
  46. assert stats["vectors"] == 2
  47. @pytest.mark.asyncio
  48. async def test_store_chunks_with_embedding_failure(tmp_path, monkeypatch):
  49. cfg = DataModulesConfig.from_project_root(tmp_path)
  50. cfg.ensure_dirs()
  51. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientWithFailures())
  52. adapter = RAGAdapter(cfg)
  53. chunks = [
  54. {"chapter": 1, "scene_index": 1, "content": "短内容"},
  55. {"chapter": 1, "scene_index": 2, "content": "稍长内容用于索引"},
  56. ]
  57. stored = await adapter.store_chunks(chunks)
  58. assert stored == 1
  59. @pytest.mark.asyncio
  60. async def test_hybrid_search_full_scan(temp_project):
  61. adapter = RAGAdapter(temp_project)
  62. await adapter.store_chunks(
  63. [{"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}]
  64. )
  65. results = await adapter.hybrid_search("萧炎", vector_top_k=5, bm25_top_k=5, rerank_top_n=1)
  66. assert results
  67. assert results[0].source == "hybrid"
  68. @pytest.mark.asyncio
  69. async def test_hybrid_search_prefilter(tmp_path, monkeypatch):
  70. cfg = DataModulesConfig.from_project_root(tmp_path)
  71. cfg.ensure_dirs()
  72. cfg.vector_full_scan_max_vectors = 0
  73. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  74. adapter = RAGAdapter(cfg)
  75. await adapter.store_chunks(
  76. [
  77. {"chapter": 1, "scene_index": 1, "content": "萧炎修炼"},
  78. {"chapter": 2, "scene_index": 1, "content": "药老出场"},
  79. ]
  80. )
  81. results = await adapter.hybrid_search("药老", vector_top_k=2, bm25_top_k=2, rerank_top_n=1)
  82. assert results
  83. @pytest.mark.asyncio
  84. async def test_search_with_backtrack(temp_project):
  85. adapter = RAGAdapter(temp_project)
  86. chunks = [
  87. {
  88. "chapter": 1,
  89. "scene_index": 0,
  90. "content": "章节摘要",
  91. "chunk_type": "summary",
  92. "chunk_id": "ch0001_summary",
  93. "source_file": "summaries/ch0001.md",
  94. },
  95. {
  96. "chapter": 1,
  97. "scene_index": 1,
  98. "content": "场景内容",
  99. "chunk_type": "scene",
  100. "chunk_id": "ch0001_s1",
  101. "parent_chunk_id": "ch0001_summary",
  102. "source_file": "正文/第0001章.md#scene_1",
  103. },
  104. ]
  105. await adapter.store_chunks(chunks)
  106. results = await adapter.search_with_backtrack("场景", top_k=1)
  107. assert any(r.chunk_type == "summary" for r in results)
  108. def test_vector_helpers(temp_project):
  109. adapter = RAGAdapter(temp_project)
  110. emb = [1.0, 0.0]
  111. data = adapter._serialize_embedding(emb)
  112. assert adapter._deserialize_embedding(data) == emb
  113. assert adapter._cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0
  114. def test_recent_and_fetch_vectors(temp_project):
  115. adapter = RAGAdapter(temp_project)
  116. with adapter._get_conn() as conn:
  117. cursor = conn.cursor()
  118. cursor.execute(
  119. "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
  120. ("ch0001_s1", 1, 1, "内容", b"", None, "scene", "正文/第0001章.md#scene_1"),
  121. )
  122. conn.commit()
  123. assert adapter._get_vectors_count() == 1
  124. assert adapter._get_recent_chunk_ids(1) == ["ch0001_s1"]
  125. rows = adapter._fetch_vectors_by_chunk_ids(["ch0001_s1"])
  126. assert len(rows) == 1
  127. def test_rag_adapter_cli(temp_project, monkeypatch, capsys):
  128. # stats
  129. def run_cli(args):
  130. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  131. rag_module.main()
  132. root = str(temp_project.project_root)
  133. run_cli(["--project-root", root, "stats"])
  134. # index-chapter
  135. run_cli(
  136. [
  137. "--project-root",
  138. root,
  139. "index-chapter",
  140. "--chapter",
  141. "1",
  142. "--scenes",
  143. json.dumps([{"index": 1, "summary": "摘要", "content": "内容"}], ensure_ascii=False),
  144. ]
  145. )
  146. # search
  147. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "bm25", "--top-k", "5"])
  148. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "vector", "--top-k", "5"])
  149. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "hybrid", "--top-k", "5"])
  150. capsys.readouterr()