| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- RAGAdapter tests
- """
- import sys
- import json
- import asyncio
- import logging
- import pytest
- import data_modules.rag_adapter as rag_module
- from data_modules.rag_adapter import RAGAdapter
- from data_modules.config import DataModulesConfig
class StubClient:
    """Deterministic stand-in for the embedding / rerank API client.

    Every text is embedded as the fixed vector ``[1.0, 0.0]`` and rerank
    scores decay as ``1 / (position + 1)`` in input order, so tests get
    reproducible results without any network access.
    """

    async def embed(self, texts):
        """Return a fresh ``[1.0, 0.0]`` vector for each entry of ``texts``."""
        vectors = []
        for _text in texts:
            vectors.append([1.0, 0.0])
        return vectors

    async def embed_batch(self, texts, skip_failures=True):
        """Batch counterpart of ``embed``; ``skip_failures`` is accepted but unused."""
        # Deliberately not delegating to ``embed``: subclasses override the
        # two methods independently.
        vectors = []
        for _text in texts:
            vectors.append([1.0, 0.0])
        return vectors

    async def rerank(self, query, documents, top_n=None):
        """Return up to ``top_n`` rerank entries in original document order.

        A falsy ``top_n`` (``None`` or ``0``) means "all documents". Each
        entry is ``{"index": i, "relevance_score": 1 / (i + 1)}``; ``query``
        is ignored.
        """
        limit = top_n if top_n else len(documents)
        kept = min(limit, len(documents))
        ranked = []
        for position in range(kept):
            ranked.append(
                {"index": position, "relevance_score": 1.0 / (position + 1)}
            )
        return ranked
class StubClientWithFailures(StubClient):
    """Stub client whose batch embedding always fails for the first text.

    ``None`` marks a failed embedding. A single-text batch yields ``[None]``;
    any other batch yields the fixed two-element list ``[None, [1.0, 0.0]]``
    regardless of how many texts were passed.
    """

    async def embed_batch(self, texts, skip_failures=True):
        """Return a canned result with a failed (``None``) first embedding."""
        single_text = len(texts) == 1
        return [None] if single_text else [None, [1.0, 0.0]]
class StubEmbedClient401:
    """Minimal embed-client stand-in that records an HTTP 401 auth failure."""

    def __init__(self):
        """Preset the last-error fields to a 401 / "auth failed" pair."""
        self.last_error_status, self.last_error_message = 401, "auth failed"
class StubClientAuthFailure(StubClient):
    """Stub client simulating an embedding call rejected for authentication.

    ``embed`` yields ``None`` and the attached ``_embed_client`` reports a
    401 status, while ``embed_batch`` and ``rerank`` keep the inherited
    ``StubClient`` behaviour.
    """

    def __init__(self):
        """Attach an embed client whose last error is a 401."""
        failing_embed_client = StubEmbedClient401()
        self._embed_client = failing_embed_client

    async def embed(self, texts):
        """Always return ``None`` to signal that embedding failed."""
        return None
@pytest.fixture
def temp_project(tmp_path, monkeypatch):
    """Provide a config rooted at ``tmp_path`` with the API client stubbed out.

    The project directories are created, and ``rag_module.get_client`` is
    patched to hand back a ``StubClient`` for the duration of the test.
    """

    def _stub_client_factory(config):
        # ``config`` is accepted for signature compatibility and ignored.
        return StubClient()

    project_cfg = DataModulesConfig.from_project_root(tmp_path)
    project_cfg.ensure_dirs()
    monkeypatch.setattr(rag_module, "get_client", _stub_client_factory)
    return project_cfg
@pytest.mark.asyncio
async def test_store_and_search(temp_project):
    """Store two chunks, then check vector search, BM25 search and stats."""
    rag = RAGAdapter(temp_project)

    first_scene = {"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"}
    second_scene = {"chapter": 1, "scene_index": 2, "content": "药老传授炼药技巧"}
    stored_count = await rag.store_chunks([first_scene, second_scene])
    assert stored_count == 2

    vector_hits = await rag.vector_search("萧炎", top_k=2)
    assert len(vector_hits) == 2

    bm25_hits = rag.bm25_search("萧炎", top_k=2)
    assert len(bm25_hits) >= 1

    adapter_stats = rag.get_stats()
    assert adapter_stats["vectors"] == 2
@pytest.mark.asyncio
async def test_store_chunks_with_embedding_failure(tmp_path, monkeypatch):
    """Only chunks whose embedding succeeded are counted as stored.

    ``StubClientWithFailures`` returns ``[None, [1.0, 0.0]]`` for a two-text
    batch, so exactly one of the two chunks is expected to be stored.
    """
    project_cfg = DataModulesConfig.from_project_root(tmp_path)
    project_cfg.ensure_dirs()

    def _failing_client_factory(config):
        return StubClientWithFailures()

    monkeypatch.setattr(rag_module, "get_client", _failing_client_factory)

    rag = RAGAdapter(project_cfg)
    short_chunk = {"chapter": 1, "scene_index": 1, "content": "短内容"}
    longer_chunk = {"chapter": 1, "scene_index": 2, "content": "稍长内容用于索引"}

    stored_count = await rag.store_chunks([short_chunk, longer_chunk])
    assert stored_count == 1
@pytest.mark.asyncio
async def test_hybrid_search_full_scan(temp_project):
    """Hybrid search over one stored chunk returns a hit tagged ``"hybrid"``."""
    rag = RAGAdapter(temp_project)
    only_chunk = {"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}
    await rag.store_chunks([only_chunk])

    hits = await rag.hybrid_search(
        "萧炎",
        vector_top_k=5,
        bm25_top_k=5,
        rerank_top_n=1,
    )

    assert hits
    top_hit = hits[0]
    assert top_hit.source == "hybrid"
@pytest.mark.asyncio
async def test_hybrid_search_prefilter(tmp_path, monkeypatch):
    """Hybrid search still returns results with the full-scan limit set to 0."""
    project_cfg = DataModulesConfig.from_project_root(tmp_path)
    project_cfg.ensure_dirs()
    # NOTE(review): presumably a limit of 0 forces the prefilter path instead
    # of a full vector scan — confirm against RAGAdapter.
    project_cfg.vector_full_scan_max_vectors = 0

    def _stub_client_factory(config):
        return StubClient()

    monkeypatch.setattr(rag_module, "get_client", _stub_client_factory)

    rag = RAGAdapter(project_cfg)
    chapter_one = {"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}
    chapter_two = {"chapter": 2, "scene_index": 1, "content": "药老出场"}
    await rag.store_chunks([chapter_one, chapter_two])

    hits = await rag.hybrid_search(
        "药老",
        vector_top_k=2,
        bm25_top_k=2,
        rerank_top_n=1,
    )
    assert hits
@pytest.mark.asyncio
async def test_search_with_backtrack(temp_project):
    """Backtracking search surfaces the parent summary chunk of a scene hit.

    A summary chunk and a scene chunk pointing at it via ``parent_chunk_id``
    are stored; the results must include at least one ``"summary"`` chunk.
    """
    rag = RAGAdapter(temp_project)

    summary_chunk = {
        "chapter": 1,
        "scene_index": 0,
        "content": "章节摘要",
        "chunk_type": "summary",
        "chunk_id": "ch0001_summary",
        "source_file": "summaries/ch0001.md",
    }
    scene_chunk = {
        "chapter": 1,
        "scene_index": 1,
        "content": "场景内容",
        "chunk_type": "scene",
        "chunk_id": "ch0001_s1",
        "parent_chunk_id": "ch0001_summary",
        "source_file": "正文/第0001章.md#scene_1",
    }
    await rag.store_chunks([summary_chunk, scene_chunk])

    hits = await rag.search_with_backtrack("场景", top_k=1)
    returned_types = [hit.chunk_type for hit in hits]
    assert "summary" in returned_types
def test_vector_helpers(temp_project):
    """Embedding (de)serialization round-trips; zero-vector cosine is 0.0."""
    rag = RAGAdapter(temp_project)

    original_vector = [1.0, 0.0]
    blob = rag._serialize_embedding(original_vector)
    restored_vector = rag._deserialize_embedding(blob)
    assert restored_vector == original_vector

    # A zero vector has no direction; the helper is expected to yield 0.0.
    zero_similarity = rag._cosine_similarity([0.0, 0.0], [1.0, 0.0])
    assert zero_similarity == 0.0
def test_recent_and_fetch_vectors(temp_project):
    """Insert one vector row directly and read it back via the private helpers."""
    rag = RAGAdapter(temp_project)

    insert_sql = "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
    # The embedding column holds an empty blob; only row bookkeeping is tested.
    row_values = ("ch0001_s1", 1, 1, "内容", b"", None, "scene", "正文/第0001章.md#scene_1")

    with rag._get_conn() as conn:
        cur = conn.cursor()
        cur.execute(insert_sql, row_values)
        conn.commit()

    assert rag._get_vectors_count() == 1
    assert rag._get_recent_chunk_ids(1) == ["ch0001_s1"]

    fetched_rows = rag._fetch_vectors_by_chunk_ids(["ch0001_s1"])
    assert len(fetched_rows) == 1
def test_rag_adapter_cli(temp_project, monkeypatch, capsys):
    """Smoke-test the CLI: stats, index-chapter, then search in each mode.

    There are no output assertions; the test passes when every invocation of
    ``rag_module.main()`` returns without raising.
    """

    def run_cli(args):
        # Emulate a command-line invocation by replacing ``sys.argv``.
        monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
        rag_module.main()

    root = str(temp_project.project_root)
    base_args = ["--project-root", root]

    # stats
    run_cli(base_args + ["stats"])

    # index-chapter
    scenes_json = json.dumps(
        [{"index": 1, "summary": "摘要", "content": "内容"}],
        ensure_ascii=False,
    )
    run_cli(base_args + ["index-chapter", "--chapter", "1", "--scenes", scenes_json])

    # search, once per retrieval mode
    for mode in ("bm25", "vector", "hybrid"):
        run_cli(
            base_args
            + ["search", "--query", "内容", "--mode", mode, "--top-k", "5"]
        )

    # Drain whatever the CLI printed.
    capsys.readouterr()
def test_rag_adapter_log_query_failure_is_reported(temp_project, monkeypatch, caplog):
    """A failing ``log_rag_query`` is logged by ``_log_query``, not raised.

    ``index_manager.log_rag_query`` is patched to raise; ``_log_query`` must
    return normally and emit a log record (WARNING or above) containing
    "failed to log rag query".
    """
    rag = RAGAdapter(temp_project)

    def _raise_log_error(*args, **kwargs):
        raise RuntimeError("log write failed")

    monkeypatch.setattr(rag.index_manager, "log_rag_query", _raise_log_error)

    with caplog.at_level(logging.WARNING):
        rag._log_query("q", "vector", [], 1)

    logged_messages = [record.getMessage() for record in caplog.records]
    combined_text = "\n".join(logged_messages)
    assert "failed to log rag query" in combined_text
def test_rag_adapter_cli_search_shows_degraded_warning(temp_project, monkeypatch, capsys):
    """CLI vector search reports a degraded-mode warning on embedding auth failure.

    With the client replaced by ``StubClientAuthFailure``, the last stdout
    line must be JSON with ``status == "success"`` and a first warning whose
    code is ``DEGRADED_MODE`` and reason is ``embedding_auth_failed``.
    """

    def _auth_failure_factory(config):
        return StubClientAuthFailure()

    # Overrides the StubClient patch installed by the ``temp_project`` fixture.
    monkeypatch.setattr(rag_module, "get_client", _auth_failure_factory)

    def run_cli(args):
        monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args)
        rag_module.main()

    root = str(temp_project.project_root)
    run_cli(
        ["--project-root", root, "search", "--query", "测试", "--mode", "vector", "--top-k", "3"]
    )

    captured = capsys.readouterr()
    # Only the final stdout line is parsed as the JSON payload.
    stdout_lines = captured.out.strip().splitlines()
    payload = json.loads(stdout_lines[-1])
    assert payload.get("status") == "success"

    warning_list = payload.get("warnings") or []
    assert warning_list
    first_warning = warning_list[0]
    assert first_warning.get("code") == "DEGRADED_MODE"
    assert first_warning.get("reason") == "embedding_auth_failed"
|