test_rag_adapter.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. RAGAdapter tests
  5. """
  6. import sys
  7. import json
  8. import asyncio
  9. import pytest
  10. import data_modules.rag_adapter as rag_module
  11. from data_modules.rag_adapter import RAGAdapter
  12. from data_modules.config import DataModulesConfig
  13. class StubClient:
  14. async def embed(self, texts):
  15. return [[1.0, 0.0] for _ in texts]
  16. async def embed_batch(self, texts, skip_failures=True):
  17. return [[1.0, 0.0] for _ in texts]
  18. async def rerank(self, query, documents, top_n=None):
  19. top_n = top_n or len(documents)
  20. return [{"index": i, "relevance_score": 1.0 / (i + 1)} for i in range(min(top_n, len(documents)))]
  21. class StubClientWithFailures(StubClient):
  22. async def embed_batch(self, texts, skip_failures=True):
  23. if len(texts) == 1:
  24. return [None]
  25. return [None, [1.0, 0.0]]
  26. @pytest.fixture
  27. def temp_project(tmp_path, monkeypatch):
  28. cfg = DataModulesConfig.from_project_root(tmp_path)
  29. cfg.ensure_dirs()
  30. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  31. return cfg
  32. @pytest.mark.asyncio
  33. async def test_store_and_search(temp_project):
  34. adapter = RAGAdapter(temp_project)
  35. chunks = [
  36. {"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"},
  37. {"chapter": 1, "scene_index": 2, "content": "药老传授炼药技巧"},
  38. ]
  39. stored = await adapter.store_chunks(chunks)
  40. assert stored == 2
  41. vec_results = await adapter.vector_search("萧炎", top_k=2)
  42. assert len(vec_results) == 2
  43. bm25_results = adapter.bm25_search("萧炎", top_k=2)
  44. assert len(bm25_results) >= 1
  45. stats = adapter.get_stats()
  46. assert stats["vectors"] == 2
  47. @pytest.mark.asyncio
  48. async def test_store_chunks_with_embedding_failure(tmp_path, monkeypatch):
  49. cfg = DataModulesConfig.from_project_root(tmp_path)
  50. cfg.ensure_dirs()
  51. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientWithFailures())
  52. adapter = RAGAdapter(cfg)
  53. chunks = [
  54. {"chapter": 1, "scene_index": 1, "content": "短内容"},
  55. {"chapter": 1, "scene_index": 2, "content": "稍长内容用于索引"},
  56. ]
  57. stored = await adapter.store_chunks(chunks)
  58. assert stored == 1
  59. @pytest.mark.asyncio
  60. async def test_hybrid_search_full_scan(temp_project):
  61. adapter = RAGAdapter(temp_project)
  62. await adapter.store_chunks(
  63. [{"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}]
  64. )
  65. results = await adapter.hybrid_search("萧炎", vector_top_k=5, bm25_top_k=5, rerank_top_n=1)
  66. assert results
  67. assert results[0].source == "hybrid"
  68. @pytest.mark.asyncio
  69. async def test_hybrid_search_prefilter(tmp_path, monkeypatch):
  70. cfg = DataModulesConfig.from_project_root(tmp_path)
  71. cfg.ensure_dirs()
  72. cfg.vector_full_scan_max_vectors = 0
  73. monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient())
  74. adapter = RAGAdapter(cfg)
  75. await adapter.store_chunks(
  76. [
  77. {"chapter": 1, "scene_index": 1, "content": "萧炎修炼"},
  78. {"chapter": 2, "scene_index": 1, "content": "药老出场"},
  79. ]
  80. )
  81. results = await adapter.hybrid_search("药老", vector_top_k=2, bm25_top_k=2, rerank_top_n=1)
  82. assert results
  83. def test_vector_helpers(temp_project):
  84. adapter = RAGAdapter(temp_project)
  85. emb = [1.0, 0.0]
  86. data = adapter._serialize_embedding(emb)
  87. assert adapter._deserialize_embedding(data) == emb
  88. assert adapter._cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0
  89. def test_recent_and_fetch_vectors(temp_project):
  90. adapter = RAGAdapter(temp_project)
  91. with adapter._get_conn() as conn:
  92. cursor = conn.cursor()
  93. cursor.execute(
  94. "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding) VALUES (?, ?, ?, ?, ?)",
  95. ("ch1_s1", 1, 1, "内容", b""),
  96. )
  97. conn.commit()
  98. assert adapter._get_vectors_count() == 1
  99. assert adapter._get_recent_chunk_ids(1) == ["ch1_s1"]
  100. rows = adapter._fetch_vectors_by_chunk_ids(["ch1_s1"])
  101. assert len(rows) == 1
  102. def test_rag_adapter_cli(temp_project, monkeypatch, capsys):
  103. # stats
  104. def run_cli(args):
  105. monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
  106. rag_module.main()
  107. root = str(temp_project.project_root)
  108. run_cli(["--project-root", root, "stats"])
  109. # index-chapter
  110. run_cli(
  111. [
  112. "--project-root",
  113. root,
  114. "index-chapter",
  115. "--chapter",
  116. "1",
  117. "--scenes",
  118. json.dumps([{"index": 1, "summary": "摘要", "content": "内容"}], ensure_ascii=False),
  119. ]
  120. )
  121. # search
  122. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "bm25", "--top-k", "5"])
  123. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "vector", "--top-k", "5"])
  124. run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "hybrid", "--top-k", "5"])
  125. capsys.readouterr()