Просмотр исходного кода

feat(v5.4): 父子向量索引支持(rag_adapter 重构)

- 新增 parent_chunk_id, chunk_type, source_file 字段
- 实现表结构兼容性检查,不兼容时自动 DROP+CREATE
- SearchResult dataclass 扩展新字段
- 更新测试用例适配新结构

不考虑向后兼容:旧结构的 vectors.db 会被自动重建(检测到缺少新字段时,vectors、bm25_index、doc_stats 表会被 DROP 后重新创建,原有索引数据需重新生成)。

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
lingfengQAQ 4 месяцев назад
Родитель
Коммит
916d1754ca

+ 340 - 76
.claude/scripts/data_modules/rag_adapter.py

@@ -21,9 +21,11 @@ from collections import Counter
 import re
 from contextlib import contextmanager
 import itertools
+import time
 
 from .config import get_config
 from .api_client import get_client
+from .index_manager import IndexManager
 
 
 @dataclass
@@ -35,6 +37,9 @@ class SearchResult:
     content: str
     score: float
     source: str  # "vector" | "bm25" | "hybrid"
+    parent_chunk_id: str | None = None
+    chunk_type: str | None = None
+    source_file: str | None = None
 
 
 class RAGAdapter:
@@ -43,6 +48,7 @@ class RAGAdapter:
     def __init__(self, config=None):
         self.config = config or get_config()
         self.api_client = get_client(config)
+        self.index_manager = IndexManager(self.config)
         self._init_db()
 
     def _init_db(self):
@@ -52,6 +58,29 @@ class RAGAdapter:
         with self._get_conn() as conn:
             cursor = conn.cursor()
 
+            def _table_columns(table_name: str) -> set[str]:
+                cursor.execute(f"PRAGMA table_info({table_name})")
+                return {row[1] for row in cursor.fetchall()}
+
+            required_cols = {
+                "chunk_id",
+                "chapter",
+                "scene_index",
+                "content",
+                "embedding",
+                "parent_chunk_id",
+                "chunk_type",
+                "source_file",
+                "created_at",
+            }
+
+            if "vectors" in {r[0] for r in cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")}:  # type: ignore
+                cols = _table_columns("vectors")
+                if not required_cols.issubset(cols):
+                    cursor.execute("DROP TABLE IF EXISTS vectors")
+                    cursor.execute("DROP TABLE IF EXISTS bm25_index")
+                    cursor.execute("DROP TABLE IF EXISTS doc_stats")
+
             # 向量存储表
             cursor.execute("""
                 CREATE TABLE IF NOT EXISTS vectors (
@@ -60,6 +89,9 @@ class RAGAdapter:
                     scene_index INTEGER,
                     content TEXT,
                     embedding BLOB,
+                    parent_chunk_id TEXT,
+                    chunk_type TEXT DEFAULT 'scene',
+                    source_file TEXT,
                     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                 )
             """)
@@ -84,6 +116,8 @@ class RAGAdapter:
 
             # 创建索引
             cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_chapter ON vectors(chapter)")
+            cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_parent ON vectors(parent_chunk_id)")
+            cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_type ON vectors(chunk_type)")
             cursor.execute("CREATE INDEX IF NOT EXISTS idx_bm25_term ON bm25_index(term)")
 
             conn.commit()
@@ -104,15 +138,21 @@ class RAGAdapter:
             row = cursor.fetchone()
             return int(row[0] or 0) if row else 0
 
-    def _get_recent_chunk_ids(self, limit: int) -> List[str]:
+    def _get_recent_chunk_ids(self, limit: int, chunk_type: str | None = None) -> List[str]:
         if limit <= 0:
             return []
         with self._get_conn() as conn:
             cursor = conn.cursor()
-            cursor.execute(
-                "SELECT chunk_id FROM vectors ORDER BY chapter DESC, scene_index DESC LIMIT ?",
-                (int(limit),),
-            )
+            if chunk_type:
+                cursor.execute(
+                    "SELECT chunk_id FROM vectors WHERE chunk_type = ? ORDER BY chapter DESC, scene_index DESC LIMIT ?",
+                    (chunk_type, int(limit)),
+                )
+            else:
+                cursor.execute(
+                    "SELECT chunk_id FROM vectors ORDER BY chapter DESC, scene_index DESC LIMIT ?",
+                    (int(limit),),
+                )
             return [str(r[0]) for r in cursor.fetchall() if r and r[0]]
 
     def _fetch_vectors_by_chunk_ids(self, chunk_ids: List[str]) -> List[Tuple]:
@@ -134,7 +174,7 @@ class RAGAdapter:
             for batch in _chunks(chunk_ids):
                 placeholders = ",".join(["?"] * len(batch))
                 cursor.execute(
-                    f"SELECT chunk_id, chapter, scene_index, content, embedding FROM vectors WHERE chunk_id IN ({placeholders})",
+                    f"SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors WHERE chunk_id IN ({placeholders})",
                     tuple(batch),
                 )
                 rows.extend(cursor.fetchall())
@@ -149,7 +189,16 @@ class RAGAdapter:
     ) -> List[SearchResult]:
         results: List[SearchResult] = []
         for row in rows:
-            chunk_id, chapter, scene_index, content, embedding_bytes = row
+            (
+                chunk_id,
+                chapter,
+                scene_index,
+                content,
+                embedding_bytes,
+                parent_chunk_id,
+                chunk_type,
+                source_file,
+            ) = row
             if not embedding_bytes:
                 continue
             embedding = self._deserialize_embedding(embedding_bytes)
@@ -162,6 +211,9 @@ class RAGAdapter:
                     content=content,
                     score=score,
                     source="vector",
+                    parent_chunk_id=parent_chunk_id,
+                    chunk_type=chunk_type,
+                    source_file=source_file,
                 )
             )
 
@@ -179,7 +231,10 @@ class RAGAdapter:
             {
                 "chapter": 100,
                 "scene_index": 1,
-                "content": "场景内容..."
+                "content": "场景内容...",
+                "chunk_type": "scene",
+                "parent_chunk_id": "ch0100_summary",
+                "source_file": "正文/第0100章.md#scene_1"
             }
         ]
 
@@ -189,7 +244,7 @@ class RAGAdapter:
             return 0
 
         # 提取内容用于嵌入
-        contents = [c["content"] for c in chunks]
+        contents = [c.get("content", "") for c in chunks]
 
         # 调用 API 获取嵌入向量(可能包含 None 表示失败)
         embeddings = await self.api_client.embed_batch(contents)
@@ -207,37 +262,48 @@ class RAGAdapter:
                 if embedding is None:
                     # 嵌入失败,跳过该 chunk(仅存储 BM25 索引供关键词检索)
                     skipped += 1
-                    chunk_id = f"ch{chunk['chapter']}_s{chunk['scene_index']}"
-                    self._update_bm25_index(cursor, chunk_id, chunk["content"])
+                    chunk_id = chunk.get("chunk_id")
+                    if not chunk_id:
+                        if chunk.get("chunk_type") == "summary":
+                            chunk_id = f"ch{int(chunk['chapter']):04d}_summary"
+                        else:
+                            chunk_id = f"ch{int(chunk['chapter']):04d}_s{int(chunk['scene_index'])}"
+                    self._update_bm25_index(cursor, chunk_id, chunk.get("content", ""))
                     continue
 
-                chunk_id = f"ch{chunk['chapter']}_s{chunk['scene_index']}"
+                chunk_type = chunk.get("chunk_type") or "scene"
+                chunk_id = chunk.get("chunk_id")
+                if not chunk_id:
+                    if chunk_type == "summary":
+                        chunk_id = f"ch{int(chunk['chapter']):04d}_summary"
+                    else:
+                        chunk_id = f"ch{int(chunk['chapter']):04d}_s{int(chunk['scene_index'])}"
 
                 # 将向量序列化为 bytes
                 embedding_bytes = self._serialize_embedding(embedding)
 
                 cursor.execute("""
                     INSERT OR REPLACE INTO vectors
-                    (chunk_id, chapter, scene_index, content, embedding)
-                    VALUES (?, ?, ?, ?, ?)
+                    (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file)
+                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                 """, (
                     chunk_id,
                     chunk["chapter"],
-                    chunk["scene_index"],
-                    chunk["content"],
-                    embedding_bytes
+                    chunk.get("scene_index", 0) if chunk_type == "scene" else 0,
+                    chunk.get("content", ""),
+                    embedding_bytes,
+                    chunk.get("parent_chunk_id"),
+                    chunk_type,
+                    chunk.get("source_file"),
                 ))
 
                 # 同时更新 BM25 索引
-                self._update_bm25_index(cursor, chunk_id, chunk["content"])
+                self._update_bm25_index(cursor, chunk_id, chunk.get("content", ""))
 
                 stored += 1
 
             conn.commit()
 
-        if skipped > 0:
-            print(f"[WARN] store_chunks: {skipped} chunks skipped due to embedding failure (BM25 only)")
-
         return stored
 
     def _serialize_embedding(self, embedding: List[float]) -> bytes:
@@ -251,6 +317,27 @@ class RAGAdapter:
         count = len(data) // 4
         return list(struct.unpack(f"{count}f", data))
 
+    def _log_query(
+        self,
+        query: str,
+        query_type: str,
+        results: List[SearchResult],
+        latency_ms: int,
+        chapter: int | None = None,
+    ) -> None:
+        try:
+            hit_sources = Counter([r.chunk_type or "unknown" for r in results])
+            self.index_manager.log_rag_query(
+                query=query,
+                query_type=query_type,
+                results_count=len(results),
+                hit_sources=json.dumps(hit_sources, ensure_ascii=False),
+                latency_ms=latency_ms,
+                chapter=chapter,
+            )
+        except Exception:
+            pass
+
     # ==================== BM25 索引 ====================
 
     def _tokenize(self, text: str) -> List[str]:
@@ -296,10 +383,14 @@ class RAGAdapter:
     async def vector_search(
         self,
         query: str,
-        top_k: int = None
+        top_k: int = None,
+        chunk_type: str | None = None,
+        log_query: bool = True,
+        chapter: int | None = None,
     ) -> List[SearchResult]:
         """向量相似度搜索"""
         top_k = top_k or self.config.vector_top_k
+        start_time = time.perf_counter()
 
         # 获取查询向量
         query_embeddings = await self.api_client.embed([query])
@@ -311,11 +402,30 @@ class RAGAdapter:
         # 从数据库读取所有向量并计算相似度
         with self._get_conn() as conn:
             cursor = conn.cursor()
-            cursor.execute("SELECT chunk_id, chapter, scene_index, content, embedding FROM vectors")
+            if chunk_type:
+                cursor.execute(
+                    "SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors WHERE chunk_type = ?",
+                    (chunk_type,),
+                )
+            else:
+                cursor.execute(
+                    "SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors"
+                )
 
             results = []
             for row in cursor.fetchall():
-                chunk_id, chapter, scene_index, content, embedding_bytes = row
+                (
+                    chunk_id,
+                    chapter,
+                    scene_index,
+                    content,
+                    embedding_bytes,
+                    parent_chunk_id,
+                    chunk_type_value,
+                    source_file,
+                ) = row
+                if not embedding_bytes:
+                    continue
                 embedding = self._deserialize_embedding(embedding_bytes)
 
                 # 计算余弦相似度
@@ -327,12 +437,19 @@ class RAGAdapter:
                     scene_index=scene_index,
                     content=content,
                     score=score,
-                    source="vector"
+                    source="vector",
+                    parent_chunk_id=parent_chunk_id,
+                    chunk_type=chunk_type_value,
+                    source_file=source_file,
                 ))
 
         # 排序并返回 top_k
         results.sort(key=lambda x: x.score, reverse=True)
-        return results[:top_k]
+        results = results[:top_k]
+        if log_query:
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+            self._log_query(query, "vector", results, latency_ms, chapter=chapter)
+        return results
 
     def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
         """计算余弦相似度"""
@@ -350,10 +467,14 @@ class RAGAdapter:
         query: str,
         top_k: int = None,
         k1: float = 1.5,
-        b: float = 0.75
+        b: float = 0.75,
+        chunk_type: str | None = None,
+        log_query: bool = True,
+        chapter: int | None = None,
     ) -> List[SearchResult]:
         """BM25 关键词搜索"""
         top_k = top_k or self.config.bm25_top_k
+        start_time = time.perf_counter()
 
         query_terms = self._tokenize(query)
         if not query_terms:
@@ -400,11 +521,24 @@ class RAGAdapter:
             # 获取文档内容
             results = []
             for chunk_id, score in doc_scores.items():
-                cursor.execute("""
-                    SELECT chapter, scene_index, content
-                    FROM vectors
-                    WHERE chunk_id = ?
-                """, (chunk_id,))
+                if chunk_type:
+                    cursor.execute(
+                        """
+                        SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
+                        FROM vectors
+                        WHERE chunk_id = ? AND chunk_type = ?
+                    """,
+                        (chunk_id, chunk_type),
+                    )
+                else:
+                    cursor.execute(
+                        """
+                        SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
+                        FROM vectors
+                        WHERE chunk_id = ?
+                    """,
+                        (chunk_id,),
+                    )
                 row = cursor.fetchone()
                 if row:
                     results.append(SearchResult(
@@ -413,11 +547,18 @@ class RAGAdapter:
                         scene_index=row[1],
                         content=row[2],
                         score=score,
-                        source="bm25"
+                        source="bm25",
+                        parent_chunk_id=row[3],
+                        chunk_type=row[4],
+                        source_file=row[5],
                     ))
 
         results.sort(key=lambda x: x.score, reverse=True)
-        return results[:top_k]
+        results = results[:top_k]
+        if log_query:
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+            self._log_query(query, "bm25", results, latency_ms, chapter=chapter)
+        return results
 
     # ==================== 混合检索 ====================
 
@@ -426,7 +567,9 @@ class RAGAdapter:
         query: str,
         vector_top_k: int = None,
         bm25_top_k: int = None,
-        rerank_top_n: int = None
+        rerank_top_n: int = None,
+        chunk_type: str | None = None,
+        log_query: bool = True,
     ) -> List[SearchResult]:
         """
         混合检索:向量 + BM25 + RRF 融合 + Rerank
@@ -440,6 +583,7 @@ class RAGAdapter:
         vector_top_k = vector_top_k or self.config.vector_top_k
         bm25_top_k = bm25_top_k or self.config.bm25_top_k
         rerank_top_n = rerank_top_n or self.config.rerank_top_n
+        start_time = time.perf_counter()
 
         # 小规模:全表向量扫描(召回更稳);大规模:预筛选避免 O(n) 扫描拖慢
         vectors_count = await asyncio.to_thread(self._get_vectors_count)
@@ -448,8 +592,8 @@ class RAGAdapter:
         if use_full_scan:
             # 并行执行向量和 BM25 检索
             vector_results, bm25_results = await asyncio.gather(
-                self.vector_search(query, vector_top_k),
-                asyncio.to_thread(self.bm25_search, query, bm25_top_k)
+                self.vector_search(query, vector_top_k, chunk_type=chunk_type, log_query=False),
+                asyncio.to_thread(self.bm25_search, query, bm25_top_k, 1.5, 0.75, chunk_type, False),
             )
         else:
             bm25_candidates = max(
@@ -464,8 +608,8 @@ class RAGAdapter:
                 int(rerank_top_n) * 10,
             )
 
-            bm25_task = asyncio.to_thread(self.bm25_search, query, bm25_candidates)
-            recent_task = asyncio.to_thread(self._get_recent_chunk_ids, recent_candidates)
+            bm25_task = asyncio.to_thread(self.bm25_search, query, bm25_candidates, 1.5, 0.75, chunk_type, False)
+            recent_task = asyncio.to_thread(self._get_recent_chunk_ids, recent_candidates, chunk_type)
             embed_task = self.api_client.embed([query])
 
             bm25_candidates_results, recent_ids, query_embeddings = await asyncio.gather(
@@ -482,6 +626,8 @@ class RAGAdapter:
             candidate_ids.update(recent_ids)
 
             rows = await asyncio.to_thread(self._fetch_vectors_by_chunk_ids, list(candidate_ids))
+            if chunk_type:
+                rows = [r for r in rows if len(r) > 6 and r[6] == chunk_type]
             vector_results = await asyncio.to_thread(
                 self._vector_search_rows,
                 query_embedding,
@@ -517,7 +663,11 @@ class RAGAdapter:
         candidates = [item["result"] for item in sorted_results[:rerank_top_n * 2]]
 
         if not candidates:
-            return []
+            final_results: List[SearchResult] = []
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+            if log_query:
+                self._log_query(query, "hybrid", final_results, latency_ms)
+            return final_results
 
         # 调用 Rerank API
         documents = [c.content for c in candidates]
@@ -525,7 +675,11 @@ class RAGAdapter:
 
         if not rerank_results:
             # Rerank 失败,返回 RRF 结果
-            return [item["result"] for item in sorted_results[:rerank_top_n]]
+            final_results = [item["result"] for item in sorted_results[:rerank_top_n]]
+            latency_ms = int((time.perf_counter() - start_time) * 1000)
+            if log_query:
+                self._log_query(query, "hybrid", final_results, latency_ms)
+            return final_results
 
         # 组装最终结果
         final_results = []
@@ -537,8 +691,73 @@ class RAGAdapter:
                 result.source = "hybrid"
                 final_results.append(result)
 
+        latency_ms = int((time.perf_counter() - start_time) * 1000)
+        if log_query:
+            self._log_query(query, "hybrid", final_results, latency_ms)
         return final_results
 
+    def _get_chunks_by_ids(self, chunk_ids: List[str]) -> List[SearchResult]:
+        rows = self._fetch_vectors_by_chunk_ids(chunk_ids)
+        results: List[SearchResult] = []
+        for row in rows:
+            (
+                chunk_id,
+                chapter,
+                scene_index,
+                content,
+                _embedding_bytes,
+                parent_chunk_id,
+                chunk_type,
+                source_file,
+            ) = row
+            results.append(
+                SearchResult(
+                    chunk_id=chunk_id,
+                    chapter=chapter,
+                    scene_index=scene_index,
+                    content=content,
+                    score=0.0,
+                    source="parent",
+                    parent_chunk_id=parent_chunk_id,
+                    chunk_type=chunk_type,
+                    source_file=source_file,
+                )
+            )
+        return results
+
+    def _merge_results(
+        self,
+        parents: List[SearchResult],
+        children: List[SearchResult],
+    ) -> List[SearchResult]:
+        parent_map = {p.chunk_id: p for p in parents}
+        merged: List[SearchResult] = []
+        seen = set()
+        for child in children:
+            parent_id = child.parent_chunk_id
+            if parent_id and parent_id in parent_map and parent_id not in seen:
+                merged.append(parent_map[parent_id])
+                seen.add(parent_id)
+            merged.append(child)
+        return merged
+
+    async def search_with_backtrack(self, query: str, top_k: int = 5) -> List[SearchResult]:
+        start_time = time.perf_counter()
+        child_results = await self.hybrid_search(
+            query,
+            vector_top_k=top_k * 2,
+            bm25_top_k=top_k * 2,
+            rerank_top_n=top_k,
+            chunk_type="scene",
+            log_query=False,
+        )
+        parent_ids = sorted({r.parent_chunk_id for r in child_results if r.parent_chunk_id})
+        parents = self._get_chunks_by_ids(parent_ids) if parent_ids else []
+        merged = self._merge_results(parents, child_results[:top_k])
+        latency_ms = int((time.perf_counter() - start_time) * 1000)
+        self._log_query(query, "backtrack", merged, latency_ms)
+        return merged
+
     # ==================== 统计 ====================
 
     def get_stats(self) -> Dict[str, int]:
@@ -566,6 +785,7 @@ class RAGAdapter:
 
 def main():
     import argparse
+    from .cli_output import print_success, print_error
 
     parser = argparse.ArgumentParser(description="RAG Adapter CLI")
     parser.add_argument("--project-root", type=str, help="项目根目录")
@@ -575,16 +795,18 @@ def main():
     # 获取统计
     subparsers.add_parser("stats")
 
-    # 搜索
-    search_parser = subparsers.add_parser("search")
-    search_parser.add_argument("--query", required=True, help="搜索查询")
-    search_parser.add_argument("--mode", choices=["vector", "bm25", "hybrid"], default="hybrid")
-    search_parser.add_argument("--top-k", type=int, default=10)
-
-    # 索引章节
+    # 写入索引
     index_parser = subparsers.add_parser("index-chapter")
     index_parser.add_argument("--chapter", type=int, required=True)
     index_parser.add_argument("--scenes", required=True, help="JSON 格式的场景列表")
+    index_parser.add_argument("--summary", required=False, help="章节摘要文本")
+
+    # 搜索
+    search_parser = subparsers.add_parser("search")
+    search_parser.add_argument("--query", required=True)
+    search_parser.add_argument("--mode", choices=["vector", "bm25", "hybrid", "backtrack"], default="hybrid")
+    search_parser.add_argument("--top-k", type=int, default=5)
+    search_parser.add_argument("--chunk-type", choices=["scene", "summary"], default=None)
 
     args = parser.parse_args()
 
@@ -592,46 +814,88 @@ def main():
     config = None
     if args.project_root:
         from .config import DataModulesConfig
+
         config = DataModulesConfig.from_project_root(args.project_root)
 
     adapter = RAGAdapter(config)
+    tool_name = f"rag_adapter:{args.command or 'unknown'}"
+
+    def emit_success(data=None, message: str = "ok"):
+        print_success(data, message=message)
+        try:
+            adapter.index_manager.log_tool_call(tool_name, True)
+        except Exception:
+            pass
+
+    def emit_error(code: str, message: str, suggestion: str | None = None):
+        print_error(code, message, suggestion=suggestion)
+        try:
+            adapter.index_manager.log_tool_call(tool_name, False, error_code=code, error_message=message)
+        except Exception:
+            pass
 
     if args.command == "stats":
         stats = adapter.get_stats()
-        print(json.dumps(stats, ensure_ascii=False, indent=2))
+        emit_success(stats, message="stats")
 
-    elif args.command == "search":
-        async def do_search():
-            if args.mode == "vector":
-                results = await adapter.vector_search(args.query, args.top_k)
-            elif args.mode == "bm25":
-                results = adapter.bm25_search(args.query, args.top_k)
-            else:
-                results = await adapter.hybrid_search(args.query)
+    elif args.command == "index-chapter":
+        scenes = json.loads(args.scenes)
+        chunks = []
+
+        # summary chunk
+        summary_text = args.summary
+        if not summary_text and config:
+            summary_path = config.webnovel_dir / "summaries" / f"ch{args.chapter:04d}.md"
+            if summary_path.exists():
+                summary_text = summary_path.read_text(encoding="utf-8")
+
+        parent_chunk_id = None
+        if summary_text:
+            parent_chunk_id = f"ch{args.chapter:04d}_summary"
+            chunks.append(
+                {
+                    "chapter": args.chapter,
+                    "scene_index": 0,
+                    "content": summary_text,
+                    "chunk_type": "summary",
+                    "chunk_id": parent_chunk_id,
+                    "source_file": f"summaries/ch{args.chapter:04d}.md",
+                }
+            )
 
-            print(f"搜索结果 ({len(results)} 条):")
-            for r in results:
-                print(f"\n[{r.source}] 第 {r.chapter} 章 场景 {r.scene_index} (score: {r.score:.4f})")
-                print(f"  {r.content[:100]}...")
+        for s in scenes:
+            scene_index = s.get("index", 0)
+            chunk_id = f"ch{args.chapter:04d}_s{int(scene_index)}"
+            chunks.append(
+                {
+                    "chapter": args.chapter,
+                    "scene_index": scene_index,
+                    "content": s.get("content", ""),
+                    "chunk_type": "scene",
+                    "parent_chunk_id": parent_chunk_id,
+                    "chunk_id": chunk_id,
+                    "source_file": f"正文/第{args.chapter:04d}章.md#scene_{int(scene_index)}",
+                }
+            )
 
-        asyncio.run(do_search())
+        stored = asyncio.run(adapter.store_chunks(chunks))
+        emit_success({"stored": stored, "chunks": len(chunks)}, message="indexed")
 
-    elif args.command == "index-chapter":
-        scenes = json.loads(args.scenes)
-        chunks = [
-            {
-                "chapter": args.chapter,
-                "scene_index": s.get("index", i),
-                "content": s.get("summary", "") + "\n" + s.get("content", "")
-            }
-            for i, s in enumerate(scenes)
-        ]
+    elif args.command == "search":
+        if args.mode == "vector":
+            results = asyncio.run(adapter.vector_search(args.query, args.top_k, chunk_type=args.chunk_type))
+        elif args.mode == "bm25":
+            results = adapter.bm25_search(args.query, args.top_k, chunk_type=args.chunk_type)
+        elif args.mode == "backtrack":
+            results = asyncio.run(adapter.search_with_backtrack(args.query, args.top_k))
+        else:
+            results = asyncio.run(adapter.hybrid_search(args.query, args.top_k, args.top_k, args.top_k, chunk_type=args.chunk_type))
 
-        async def do_index():
-            stored = await adapter.store_chunks(chunks)
-            print(f"✓ 已索引 {stored} 个场景")
+        payload = [r.__dict__ for r in results]
+        emit_success(payload, message="search_results")
 
-        asyncio.run(do_index())
+    else:
+        emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")
 
 
 if __name__ == "__main__":

+ 31 - 4
.claude/scripts/data_modules/tests/test_rag_adapter.py

@@ -105,6 +105,33 @@ async def test_hybrid_search_prefilter(tmp_path, monkeypatch):
     assert results
 
 
+@pytest.mark.asyncio
+async def test_search_with_backtrack(temp_project):
+    adapter = RAGAdapter(temp_project)
+    chunks = [
+        {
+            "chapter": 1,
+            "scene_index": 0,
+            "content": "章节摘要",
+            "chunk_type": "summary",
+            "chunk_id": "ch0001_summary",
+            "source_file": "summaries/ch0001.md",
+        },
+        {
+            "chapter": 1,
+            "scene_index": 1,
+            "content": "场景内容",
+            "chunk_type": "scene",
+            "chunk_id": "ch0001_s1",
+            "parent_chunk_id": "ch0001_summary",
+            "source_file": "正文/第0001章.md#scene_1",
+        },
+    ]
+    await adapter.store_chunks(chunks)
+    results = await adapter.search_with_backtrack("场景", top_k=1)
+    assert any(r.chunk_type == "summary" for r in results)
+
+
 def test_vector_helpers(temp_project):
     adapter = RAGAdapter(temp_project)
     emb = [1.0, 0.0]
@@ -119,14 +146,14 @@ def test_recent_and_fetch_vectors(temp_project):
     with adapter._get_conn() as conn:
         cursor = conn.cursor()
         cursor.execute(
-            "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding) VALUES (?, ?, ?, ?, ?)",
-            ("ch1_s1", 1, 1, "内容", b""),
+            "INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
+            ("ch0001_s1", 1, 1, "内容", b"", None, "scene", "正文/第0001章.md#scene_1"),
         )
         conn.commit()
 
     assert adapter._get_vectors_count() == 1
-    assert adapter._get_recent_chunk_ids(1) == ["ch1_s1"]
-    rows = adapter._fetch_vectors_by_chunk_ids(["ch1_s1"])
+    assert adapter._get_recent_chunk_ids(1) == ["ch0001_s1"]
+    rows = adapter._fetch_vectors_by_chunk_ids(["ch0001_s1"])
     assert len(rows) == 1