#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ RAGAdapter tests """ import sys import json import asyncio import logging import pytest import data_modules.rag_adapter as rag_module from data_modules.rag_adapter import RAGAdapter from data_modules.config import DataModulesConfig class StubClient: async def embed(self, texts): return [[1.0, 0.0] for _ in texts] async def embed_batch(self, texts, skip_failures=True): return [[1.0, 0.0] for _ in texts] async def rerank(self, query, documents, top_n=None): top_n = top_n or len(documents) return [{"index": i, "relevance_score": 1.0 / (i + 1)} for i in range(min(top_n, len(documents)))] class StubClientWithFailures(StubClient): async def embed_batch(self, texts, skip_failures=True): if len(texts) == 1: return [None] return [None, [1.0, 0.0]] class StubEmbedClient401: def __init__(self): self.last_error_status = 401 self.last_error_message = "auth failed" class StubClientAuthFailure(StubClient): def __init__(self): self._embed_client = StubEmbedClient401() async def embed(self, texts): return None @pytest.fixture def temp_project(tmp_path, monkeypatch): cfg = DataModulesConfig.from_project_root(tmp_path) cfg.ensure_dirs() monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient()) return cfg @pytest.mark.asyncio async def test_store_and_search(temp_project): adapter = RAGAdapter(temp_project) chunks = [ {"chapter": 1, "scene_index": 1, "content": "萧炎在天云宗修炼斗气"}, {"chapter": 1, "scene_index": 2, "content": "药老传授炼药技巧"}, ] stored = await adapter.store_chunks(chunks) assert stored == 2 vec_results = await adapter.vector_search("萧炎", top_k=2) assert len(vec_results) == 2 bm25_results = adapter.bm25_search("萧炎", top_k=2) assert len(bm25_results) >= 1 stats = adapter.get_stats() assert stats["vectors"] == 2 @pytest.mark.asyncio async def test_store_chunks_with_embedding_failure(tmp_path, monkeypatch): cfg = DataModulesConfig.from_project_root(tmp_path) cfg.ensure_dirs() monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientWithFailures()) adapter = RAGAdapter(cfg) chunks = [ {"chapter": 1, "scene_index": 1, "content": "短内容"}, {"chapter": 1, "scene_index": 2, "content": "稍长内容用于索引"}, ] stored = await adapter.store_chunks(chunks) assert stored == 1 @pytest.mark.asyncio async def test_hybrid_search_full_scan(temp_project): adapter = RAGAdapter(temp_project) await adapter.store_chunks( [{"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}] ) results = await adapter.hybrid_search("萧炎", vector_top_k=5, bm25_top_k=5, rerank_top_n=1) assert results assert results[0].source == "hybrid" @pytest.mark.asyncio async def test_hybrid_search_prefilter(tmp_path, monkeypatch): cfg = DataModulesConfig.from_project_root(tmp_path) cfg.ensure_dirs() cfg.vector_full_scan_max_vectors = 0 monkeypatch.setattr(rag_module, "get_client", lambda config: StubClient()) adapter = RAGAdapter(cfg) await adapter.store_chunks( [ {"chapter": 1, "scene_index": 1, "content": "萧炎修炼"}, {"chapter": 2, "scene_index": 1, "content": "药老出场"}, ] ) results = await adapter.hybrid_search("药老", vector_top_k=2, bm25_top_k=2, rerank_top_n=1) assert results @pytest.mark.asyncio async def test_search_with_backtrack(temp_project): adapter = RAGAdapter(temp_project) chunks = [ { "chapter": 1, "scene_index": 0, "content": "章节摘要", "chunk_type": "summary", "chunk_id": "ch0001_summary", "source_file": "summaries/ch0001.md", }, { "chapter": 1, "scene_index": 1, "content": "场景内容", "chunk_type": "scene", "chunk_id": "ch0001_s1", "parent_chunk_id": "ch0001_summary", "source_file": "正文/第0001章.md#scene_1", }, ] await adapter.store_chunks(chunks) results = await adapter.search_with_backtrack("场景", top_k=1) assert any(r.chunk_type == "summary" for r in results) def test_vector_helpers(temp_project): adapter = RAGAdapter(temp_project) emb = [1.0, 0.0] data = adapter._serialize_embedding(emb) assert adapter._deserialize_embedding(data) == emb assert adapter._cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0 def test_recent_and_fetch_vectors(temp_project): adapter = RAGAdapter(temp_project) with adapter._get_conn() as conn: cursor = conn.cursor() cursor.execute( "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", ("ch0001_s1", 1, 1, "内容", b"", None, "scene", "正文/第0001章.md#scene_1"), ) conn.commit() assert adapter._get_vectors_count() == 1 assert adapter._get_recent_chunk_ids(1) == ["ch0001_s1"] rows = adapter._fetch_vectors_by_chunk_ids(["ch0001_s1"]) assert len(rows) == 1 def test_rag_adapter_cli(temp_project, monkeypatch, capsys): # stats def run_cli(args): monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args) rag_module.main() root = str(temp_project.project_root) run_cli(["--project-root", root, "stats"]) # index-chapter run_cli( [ "--project-root", root, "index-chapter", "--chapter", "1", "--scenes", json.dumps([{"index": 1, "summary": "摘要", "content": "内容"}], ensure_ascii=False), ] ) # search run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "bm25", "--top-k", "5"]) run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "vector", "--top-k", "5"]) run_cli(["--project-root", root, "search", "--query", "内容", "--mode", "hybrid", "--top-k", "5"]) capsys.readouterr() def test_rag_adapter_log_query_failure_is_reported(temp_project, monkeypatch, caplog): adapter = RAGAdapter(temp_project) def _raise_log_error(*args, **kwargs): raise RuntimeError("log write failed") monkeypatch.setattr(adapter.index_manager, "log_rag_query", _raise_log_error) with caplog.at_level(logging.WARNING): adapter._log_query("q", "vector", [], 1) message_text = "\n".join(record.getMessage() for record in caplog.records) assert "failed to log rag query" in message_text def test_rag_adapter_cli_search_shows_degraded_warning(temp_project, monkeypatch, capsys): monkeypatch.setattr(rag_module, "get_client", lambda config: StubClientAuthFailure()) def run_cli(args): monkeypatch.setattr(sys, "argv", ["rag_adapter"] + args) rag_module.main() root = str(temp_project.project_root) run_cli(["--project-root", root, "search", "--query", "测试", "--mode", "vector", "--top-k", "3"]) captured = capsys.readouterr() payload = json.loads(captured.out.strip().splitlines()[-1]) assert payload.get("status") == "success" warnings = payload.get("warnings") or [] assert warnings assert warnings[0].get("code") == "DEGRADED_MODE" assert warnings[0].get("reason") == "embedding_auth_failed"