|
|
@@ -21,9 +21,11 @@ from collections import Counter
|
|
|
import re
|
|
|
from contextlib import contextmanager
|
|
|
import itertools
|
|
|
+import time
|
|
|
|
|
|
from .config import get_config
|
|
|
from .api_client import get_client
|
|
|
+from .index_manager import IndexManager
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
@@ -35,6 +37,9 @@ class SearchResult:
|
|
|
content: str
|
|
|
score: float
|
|
|
source: str # "vector" | "bm25" | "hybrid"
|
|
|
+ parent_chunk_id: str | None = None
|
|
|
+ chunk_type: str | None = None
|
|
|
+ source_file: str | None = None
|
|
|
|
|
|
|
|
|
class RAGAdapter:
|
|
|
@@ -43,6 +48,7 @@ class RAGAdapter:
|
|
|
def __init__(self, config=None):
|
|
|
self.config = config or get_config()
|
|
|
self.api_client = get_client(config)
|
|
|
+ self.index_manager = IndexManager(self.config)
|
|
|
self._init_db()
|
|
|
|
|
|
def _init_db(self):
|
|
|
@@ -52,6 +58,29 @@ class RAGAdapter:
|
|
|
with self._get_conn() as conn:
|
|
|
cursor = conn.cursor()
|
|
|
|
|
|
+ def _table_columns(table_name: str) -> set[str]:
|
|
|
+ cursor.execute(f"PRAGMA table_info({table_name})")
|
|
|
+ return {row[1] for row in cursor.fetchall()}
|
|
|
+
|
|
|
+ required_cols = {
|
|
|
+ "chunk_id",
|
|
|
+ "chapter",
|
|
|
+ "scene_index",
|
|
|
+ "content",
|
|
|
+ "embedding",
|
|
|
+ "parent_chunk_id",
|
|
|
+ "chunk_type",
|
|
|
+ "source_file",
|
|
|
+ "created_at",
|
|
|
+ }
|
|
|
+
|
|
|
+ if "vectors" in {r[0] for r in cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")}: # type: ignore
|
|
|
+ cols = _table_columns("vectors")
|
|
|
+ if not required_cols.issubset(cols):
|
|
|
+ cursor.execute("DROP TABLE IF EXISTS vectors")
|
|
|
+ cursor.execute("DROP TABLE IF EXISTS bm25_index")
|
|
|
+ cursor.execute("DROP TABLE IF EXISTS doc_stats")
|
|
|
+
|
|
|
# 向量存储表
|
|
|
cursor.execute("""
|
|
|
CREATE TABLE IF NOT EXISTS vectors (
|
|
|
@@ -60,6 +89,9 @@ class RAGAdapter:
|
|
|
scene_index INTEGER,
|
|
|
content TEXT,
|
|
|
embedding BLOB,
|
|
|
+ parent_chunk_id TEXT,
|
|
|
+ chunk_type TEXT DEFAULT 'scene',
|
|
|
+ source_file TEXT,
|
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
|
)
|
|
|
""")
|
|
|
@@ -84,6 +116,8 @@ class RAGAdapter:
|
|
|
|
|
|
# 创建索引
|
|
|
cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_chapter ON vectors(chapter)")
|
|
|
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_parent ON vectors(parent_chunk_id)")
|
|
|
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_vectors_type ON vectors(chunk_type)")
|
|
|
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bm25_term ON bm25_index(term)")
|
|
|
|
|
|
conn.commit()
|
|
|
@@ -104,15 +138,21 @@ class RAGAdapter:
|
|
|
row = cursor.fetchone()
|
|
|
return int(row[0] or 0) if row else 0
|
|
|
|
|
|
- def _get_recent_chunk_ids(self, limit: int) -> List[str]:
|
|
|
+ def _get_recent_chunk_ids(self, limit: int, chunk_type: str | None = None) -> List[str]:
|
|
|
if limit <= 0:
|
|
|
return []
|
|
|
with self._get_conn() as conn:
|
|
|
cursor = conn.cursor()
|
|
|
- cursor.execute(
|
|
|
- "SELECT chunk_id FROM vectors ORDER BY chapter DESC, scene_index DESC LIMIT ?",
|
|
|
- (int(limit),),
|
|
|
- )
|
|
|
+ if chunk_type:
|
|
|
+ cursor.execute(
|
|
|
+ "SELECT chunk_id FROM vectors WHERE chunk_type = ? ORDER BY chapter DESC, scene_index DESC LIMIT ?",
|
|
|
+ (chunk_type, int(limit)),
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ cursor.execute(
|
|
|
+ "SELECT chunk_id FROM vectors ORDER BY chapter DESC, scene_index DESC LIMIT ?",
|
|
|
+ (int(limit),),
|
|
|
+ )
|
|
|
return [str(r[0]) for r in cursor.fetchall() if r and r[0]]
|
|
|
|
|
|
def _fetch_vectors_by_chunk_ids(self, chunk_ids: List[str]) -> List[Tuple]:
|
|
|
@@ -134,7 +174,7 @@ class RAGAdapter:
|
|
|
for batch in _chunks(chunk_ids):
|
|
|
placeholders = ",".join(["?"] * len(batch))
|
|
|
cursor.execute(
|
|
|
- f"SELECT chunk_id, chapter, scene_index, content, embedding FROM vectors WHERE chunk_id IN ({placeholders})",
|
|
|
+ f"SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors WHERE chunk_id IN ({placeholders})",
|
|
|
tuple(batch),
|
|
|
)
|
|
|
rows.extend(cursor.fetchall())
|
|
|
@@ -149,7 +189,16 @@ class RAGAdapter:
|
|
|
) -> List[SearchResult]:
|
|
|
results: List[SearchResult] = []
|
|
|
for row in rows:
|
|
|
- chunk_id, chapter, scene_index, content, embedding_bytes = row
|
|
|
+ (
|
|
|
+ chunk_id,
|
|
|
+ chapter,
|
|
|
+ scene_index,
|
|
|
+ content,
|
|
|
+ embedding_bytes,
|
|
|
+ parent_chunk_id,
|
|
|
+ chunk_type,
|
|
|
+ source_file,
|
|
|
+ ) = row
|
|
|
if not embedding_bytes:
|
|
|
continue
|
|
|
embedding = self._deserialize_embedding(embedding_bytes)
|
|
|
@@ -162,6 +211,9 @@ class RAGAdapter:
|
|
|
content=content,
|
|
|
score=score,
|
|
|
source="vector",
|
|
|
+ parent_chunk_id=parent_chunk_id,
|
|
|
+ chunk_type=chunk_type,
|
|
|
+ source_file=source_file,
|
|
|
)
|
|
|
)
|
|
|
|
|
|
@@ -179,7 +231,10 @@ class RAGAdapter:
|
|
|
{
|
|
|
"chapter": 100,
|
|
|
"scene_index": 1,
|
|
|
- "content": "场景内容..."
|
|
|
+ "content": "场景内容...",
|
|
|
+ "chunk_type": "scene",
|
|
|
+ "parent_chunk_id": "ch0100_summary",
|
|
|
+ "source_file": "正文/第0100章.md#scene_1"
|
|
|
}
|
|
|
]
|
|
|
|
|
|
@@ -189,7 +244,7 @@ class RAGAdapter:
|
|
|
return 0
|
|
|
|
|
|
# 提取内容用于嵌入
|
|
|
- contents = [c["content"] for c in chunks]
|
|
|
+ contents = [c.get("content", "") for c in chunks]
|
|
|
|
|
|
# 调用 API 获取嵌入向量(可能包含 None 表示失败)
|
|
|
embeddings = await self.api_client.embed_batch(contents)
|
|
|
@@ -207,37 +262,48 @@ class RAGAdapter:
|
|
|
if embedding is None:
|
|
|
# 嵌入失败,跳过该 chunk(仅存储 BM25 索引供关键词检索)
|
|
|
skipped += 1
|
|
|
- chunk_id = f"ch{chunk['chapter']}_s{chunk['scene_index']}"
|
|
|
- self._update_bm25_index(cursor, chunk_id, chunk["content"])
|
|
|
+ chunk_id = chunk.get("chunk_id")
|
|
|
+ if not chunk_id:
|
|
|
+ if chunk.get("chunk_type") == "summary":
|
|
|
+ chunk_id = f"ch{int(chunk['chapter']):04d}_summary"
|
|
|
+ else:
|
|
|
+ chunk_id = f"ch{int(chunk['chapter']):04d}_s{int(chunk['scene_index'])}"
|
|
|
+ self._update_bm25_index(cursor, chunk_id, chunk.get("content", ""))
|
|
|
continue
|
|
|
|
|
|
- chunk_id = f"ch{chunk['chapter']}_s{chunk['scene_index']}"
|
|
|
+ chunk_type = chunk.get("chunk_type") or "scene"
|
|
|
+ chunk_id = chunk.get("chunk_id")
|
|
|
+ if not chunk_id:
|
|
|
+ if chunk_type == "summary":
|
|
|
+ chunk_id = f"ch{int(chunk['chapter']):04d}_summary"
|
|
|
+ else:
|
|
|
+ chunk_id = f"ch{int(chunk['chapter']):04d}_s{int(chunk['scene_index'])}"
|
|
|
|
|
|
# 将向量序列化为 bytes
|
|
|
embedding_bytes = self._serialize_embedding(embedding)
|
|
|
|
|
|
cursor.execute("""
|
|
|
INSERT OR REPLACE INTO vectors
|
|
|
- (chunk_id, chapter, scene_index, content, embedding)
|
|
|
- VALUES (?, ?, ?, ?, ?)
|
|
|
+ (chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file)
|
|
|
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
|
""", (
|
|
|
chunk_id,
|
|
|
chunk["chapter"],
|
|
|
- chunk["scene_index"],
|
|
|
- chunk["content"],
|
|
|
- embedding_bytes
|
|
|
+ chunk.get("scene_index", 0) if chunk_type == "scene" else 0,
|
|
|
+ chunk.get("content", ""),
|
|
|
+ embedding_bytes,
|
|
|
+ chunk.get("parent_chunk_id"),
|
|
|
+ chunk_type,
|
|
|
+ chunk.get("source_file"),
|
|
|
))
|
|
|
|
|
|
# 同时更新 BM25 索引
|
|
|
- self._update_bm25_index(cursor, chunk_id, chunk["content"])
|
|
|
+ self._update_bm25_index(cursor, chunk_id, chunk.get("content", ""))
|
|
|
|
|
|
stored += 1
|
|
|
|
|
|
conn.commit()
|
|
|
|
|
|
- if skipped > 0:
|
|
|
- print(f"[WARN] store_chunks: {skipped} chunks skipped due to embedding failure (BM25 only)")
|
|
|
-
|
|
|
return stored
|
|
|
|
|
|
def _serialize_embedding(self, embedding: List[float]) -> bytes:
|
|
|
@@ -251,6 +317,27 @@ class RAGAdapter:
|
|
|
count = len(data) // 4
|
|
|
return list(struct.unpack(f"{count}f", data))
|
|
|
|
|
|
+ def _log_query(
|
|
|
+ self,
|
|
|
+ query: str,
|
|
|
+ query_type: str,
|
|
|
+ results: List[SearchResult],
|
|
|
+ latency_ms: int,
|
|
|
+ chapter: int | None = None,
|
|
|
+ ) -> None:
|
|
|
+ try:
|
|
|
+ hit_sources = Counter([r.chunk_type or "unknown" for r in results])
|
|
|
+ self.index_manager.log_rag_query(
|
|
|
+ query=query,
|
|
|
+ query_type=query_type,
|
|
|
+ results_count=len(results),
|
|
|
+ hit_sources=json.dumps(hit_sources, ensure_ascii=False),
|
|
|
+ latency_ms=latency_ms,
|
|
|
+ chapter=chapter,
|
|
|
+ )
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
# ==================== BM25 索引 ====================
|
|
|
|
|
|
def _tokenize(self, text: str) -> List[str]:
|
|
|
@@ -296,10 +383,14 @@ class RAGAdapter:
|
|
|
async def vector_search(
|
|
|
self,
|
|
|
query: str,
|
|
|
- top_k: int = None
|
|
|
+ top_k: int = None,
|
|
|
+ chunk_type: str | None = None,
|
|
|
+ log_query: bool = True,
|
|
|
+ chapter: int | None = None,
|
|
|
) -> List[SearchResult]:
|
|
|
"""向量相似度搜索"""
|
|
|
top_k = top_k or self.config.vector_top_k
|
|
|
+ start_time = time.perf_counter()
|
|
|
|
|
|
# 获取查询向量
|
|
|
query_embeddings = await self.api_client.embed([query])
|
|
|
@@ -311,11 +402,30 @@ class RAGAdapter:
|
|
|
# 从数据库读取所有向量并计算相似度
|
|
|
with self._get_conn() as conn:
|
|
|
cursor = conn.cursor()
|
|
|
- cursor.execute("SELECT chunk_id, chapter, scene_index, content, embedding FROM vectors")
|
|
|
+ if chunk_type:
|
|
|
+ cursor.execute(
|
|
|
+ "SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors WHERE chunk_type = ?",
|
|
|
+ (chunk_type,),
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ cursor.execute(
|
|
|
+ "SELECT chunk_id, chapter, scene_index, content, embedding, parent_chunk_id, chunk_type, source_file FROM vectors"
|
|
|
+ )
|
|
|
|
|
|
results = []
|
|
|
for row in cursor.fetchall():
|
|
|
- chunk_id, chapter, scene_index, content, embedding_bytes = row
|
|
|
+ (
|
|
|
+ chunk_id,
|
|
|
+ chapter,
|
|
|
+ scene_index,
|
|
|
+ content,
|
|
|
+ embedding_bytes,
|
|
|
+ parent_chunk_id,
|
|
|
+ chunk_type_value,
|
|
|
+ source_file,
|
|
|
+ ) = row
|
|
|
+ if not embedding_bytes:
|
|
|
+ continue
|
|
|
embedding = self._deserialize_embedding(embedding_bytes)
|
|
|
|
|
|
# 计算余弦相似度
|
|
|
@@ -327,12 +437,19 @@ class RAGAdapter:
|
|
|
scene_index=scene_index,
|
|
|
content=content,
|
|
|
score=score,
|
|
|
- source="vector"
|
|
|
+ source="vector",
|
|
|
+ parent_chunk_id=parent_chunk_id,
|
|
|
+ chunk_type=chunk_type_value,
|
|
|
+ source_file=source_file,
|
|
|
))
|
|
|
|
|
|
# 排序并返回 top_k
|
|
|
results.sort(key=lambda x: x.score, reverse=True)
|
|
|
- return results[:top_k]
|
|
|
+ results = results[:top_k]
|
|
|
+ if log_query:
|
|
|
+ latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
+ self._log_query(query, "vector", results, latency_ms, chapter=chapter)
|
|
|
+ return results
|
|
|
|
|
|
def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
|
|
|
"""计算余弦相似度"""
|
|
|
@@ -350,10 +467,14 @@ class RAGAdapter:
|
|
|
query: str,
|
|
|
top_k: int = None,
|
|
|
k1: float = 1.5,
|
|
|
- b: float = 0.75
|
|
|
+ b: float = 0.75,
|
|
|
+ chunk_type: str | None = None,
|
|
|
+ log_query: bool = True,
|
|
|
+ chapter: int | None = None,
|
|
|
) -> List[SearchResult]:
|
|
|
"""BM25 关键词搜索"""
|
|
|
top_k = top_k or self.config.bm25_top_k
|
|
|
+ start_time = time.perf_counter()
|
|
|
|
|
|
query_terms = self._tokenize(query)
|
|
|
if not query_terms:
|
|
|
@@ -400,11 +521,24 @@ class RAGAdapter:
|
|
|
# 获取文档内容
|
|
|
results = []
|
|
|
for chunk_id, score in doc_scores.items():
|
|
|
- cursor.execute("""
|
|
|
- SELECT chapter, scene_index, content
|
|
|
- FROM vectors
|
|
|
- WHERE chunk_id = ?
|
|
|
- """, (chunk_id,))
|
|
|
+ if chunk_type:
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
|
|
|
+ FROM vectors
|
|
|
+ WHERE chunk_id = ? AND chunk_type = ?
|
|
|
+ """,
|
|
|
+ (chunk_id, chunk_type),
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ cursor.execute(
|
|
|
+ """
|
|
|
+ SELECT chapter, scene_index, content, parent_chunk_id, chunk_type, source_file
|
|
|
+ FROM vectors
|
|
|
+ WHERE chunk_id = ?
|
|
|
+ """,
|
|
|
+ (chunk_id,),
|
|
|
+ )
|
|
|
row = cursor.fetchone()
|
|
|
if row:
|
|
|
results.append(SearchResult(
|
|
|
@@ -413,11 +547,18 @@ class RAGAdapter:
|
|
|
scene_index=row[1],
|
|
|
content=row[2],
|
|
|
score=score,
|
|
|
- source="bm25"
|
|
|
+ source="bm25",
|
|
|
+ parent_chunk_id=row[3],
|
|
|
+ chunk_type=row[4],
|
|
|
+ source_file=row[5],
|
|
|
))
|
|
|
|
|
|
results.sort(key=lambda x: x.score, reverse=True)
|
|
|
- return results[:top_k]
|
|
|
+ results = results[:top_k]
|
|
|
+ if log_query:
|
|
|
+ latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
+ self._log_query(query, "bm25", results, latency_ms, chapter=chapter)
|
|
|
+ return results
|
|
|
|
|
|
# ==================== 混合检索 ====================
|
|
|
|
|
|
@@ -426,7 +567,9 @@ class RAGAdapter:
|
|
|
query: str,
|
|
|
vector_top_k: int = None,
|
|
|
bm25_top_k: int = None,
|
|
|
- rerank_top_n: int = None
|
|
|
+ rerank_top_n: int = None,
|
|
|
+ chunk_type: str | None = None,
|
|
|
+ log_query: bool = True,
|
|
|
) -> List[SearchResult]:
|
|
|
"""
|
|
|
混合检索:向量 + BM25 + RRF 融合 + Rerank
|
|
|
@@ -440,6 +583,7 @@ class RAGAdapter:
|
|
|
vector_top_k = vector_top_k or self.config.vector_top_k
|
|
|
bm25_top_k = bm25_top_k or self.config.bm25_top_k
|
|
|
rerank_top_n = rerank_top_n or self.config.rerank_top_n
|
|
|
+ start_time = time.perf_counter()
|
|
|
|
|
|
# 小规模:全表向量扫描(召回更稳);大规模:预筛选避免 O(n) 扫描拖慢
|
|
|
vectors_count = await asyncio.to_thread(self._get_vectors_count)
|
|
|
@@ -448,8 +592,8 @@ class RAGAdapter:
|
|
|
if use_full_scan:
|
|
|
# 并行执行向量和 BM25 检索
|
|
|
vector_results, bm25_results = await asyncio.gather(
|
|
|
- self.vector_search(query, vector_top_k),
|
|
|
- asyncio.to_thread(self.bm25_search, query, bm25_top_k)
|
|
|
+ self.vector_search(query, vector_top_k, chunk_type=chunk_type, log_query=False),
|
|
|
+ asyncio.to_thread(self.bm25_search, query, bm25_top_k, 1.5, 0.75, chunk_type, False),
|
|
|
)
|
|
|
else:
|
|
|
bm25_candidates = max(
|
|
|
@@ -464,8 +608,8 @@ class RAGAdapter:
|
|
|
int(rerank_top_n) * 10,
|
|
|
)
|
|
|
|
|
|
- bm25_task = asyncio.to_thread(self.bm25_search, query, bm25_candidates)
|
|
|
- recent_task = asyncio.to_thread(self._get_recent_chunk_ids, recent_candidates)
|
|
|
+ bm25_task = asyncio.to_thread(self.bm25_search, query, bm25_candidates, 1.5, 0.75, chunk_type, False)
|
|
|
+ recent_task = asyncio.to_thread(self._get_recent_chunk_ids, recent_candidates, chunk_type)
|
|
|
embed_task = self.api_client.embed([query])
|
|
|
|
|
|
bm25_candidates_results, recent_ids, query_embeddings = await asyncio.gather(
|
|
|
@@ -482,6 +626,8 @@ class RAGAdapter:
|
|
|
candidate_ids.update(recent_ids)
|
|
|
|
|
|
rows = await asyncio.to_thread(self._fetch_vectors_by_chunk_ids, list(candidate_ids))
|
|
|
+ if chunk_type:
|
|
|
+ rows = [r for r in rows if len(r) > 6 and r[6] == chunk_type]
|
|
|
vector_results = await asyncio.to_thread(
|
|
|
self._vector_search_rows,
|
|
|
query_embedding,
|
|
|
@@ -517,7 +663,11 @@ class RAGAdapter:
|
|
|
candidates = [item["result"] for item in sorted_results[:rerank_top_n * 2]]
|
|
|
|
|
|
if not candidates:
|
|
|
- return []
|
|
|
+ final_results: List[SearchResult] = []
|
|
|
+ latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
+ if log_query:
|
|
|
+ self._log_query(query, "hybrid", final_results, latency_ms)
|
|
|
+ return final_results
|
|
|
|
|
|
# 调用 Rerank API
|
|
|
documents = [c.content for c in candidates]
|
|
|
@@ -525,7 +675,11 @@ class RAGAdapter:
|
|
|
|
|
|
if not rerank_results:
|
|
|
# Rerank 失败,返回 RRF 结果
|
|
|
- return [item["result"] for item in sorted_results[:rerank_top_n]]
|
|
|
+ final_results = [item["result"] for item in sorted_results[:rerank_top_n]]
|
|
|
+ latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
+ if log_query:
|
|
|
+ self._log_query(query, "hybrid", final_results, latency_ms)
|
|
|
+ return final_results
|
|
|
|
|
|
# 组装最终结果
|
|
|
final_results = []
|
|
|
@@ -537,8 +691,73 @@ class RAGAdapter:
|
|
|
result.source = "hybrid"
|
|
|
final_results.append(result)
|
|
|
|
|
|
+ latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
+ if log_query:
|
|
|
+ self._log_query(query, "hybrid", final_results, latency_ms)
|
|
|
return final_results
|
|
|
|
|
|
+ def _get_chunks_by_ids(self, chunk_ids: List[str]) -> List[SearchResult]:
|
|
|
+ rows = self._fetch_vectors_by_chunk_ids(chunk_ids)
|
|
|
+ results: List[SearchResult] = []
|
|
|
+ for row in rows:
|
|
|
+ (
|
|
|
+ chunk_id,
|
|
|
+ chapter,
|
|
|
+ scene_index,
|
|
|
+ content,
|
|
|
+ _embedding_bytes,
|
|
|
+ parent_chunk_id,
|
|
|
+ chunk_type,
|
|
|
+ source_file,
|
|
|
+ ) = row
|
|
|
+ results.append(
|
|
|
+ SearchResult(
|
|
|
+ chunk_id=chunk_id,
|
|
|
+ chapter=chapter,
|
|
|
+ scene_index=scene_index,
|
|
|
+ content=content,
|
|
|
+ score=0.0,
|
|
|
+ source="parent",
|
|
|
+ parent_chunk_id=parent_chunk_id,
|
|
|
+ chunk_type=chunk_type,
|
|
|
+ source_file=source_file,
|
|
|
+ )
|
|
|
+ )
|
|
|
+ return results
|
|
|
+
|
|
|
+ def _merge_results(
|
|
|
+ self,
|
|
|
+ parents: List[SearchResult],
|
|
|
+ children: List[SearchResult],
|
|
|
+ ) -> List[SearchResult]:
|
|
|
+ parent_map = {p.chunk_id: p for p in parents}
|
|
|
+ merged: List[SearchResult] = []
|
|
|
+ seen = set()
|
|
|
+ for child in children:
|
|
|
+ parent_id = child.parent_chunk_id
|
|
|
+ if parent_id and parent_id in parent_map and parent_id not in seen:
|
|
|
+ merged.append(parent_map[parent_id])
|
|
|
+ seen.add(parent_id)
|
|
|
+ merged.append(child)
|
|
|
+ return merged
|
|
|
+
|
|
|
+ async def search_with_backtrack(self, query: str, top_k: int = 5) -> List[SearchResult]:
|
|
|
+ start_time = time.perf_counter()
|
|
|
+ child_results = await self.hybrid_search(
|
|
|
+ query,
|
|
|
+ vector_top_k=top_k * 2,
|
|
|
+ bm25_top_k=top_k * 2,
|
|
|
+ rerank_top_n=top_k,
|
|
|
+ chunk_type="scene",
|
|
|
+ log_query=False,
|
|
|
+ )
|
|
|
+ parent_ids = sorted({r.parent_chunk_id for r in child_results if r.parent_chunk_id})
|
|
|
+ parents = self._get_chunks_by_ids(parent_ids) if parent_ids else []
|
|
|
+ merged = self._merge_results(parents, child_results[:top_k])
|
|
|
+ latency_ms = int((time.perf_counter() - start_time) * 1000)
|
|
|
+ self._log_query(query, "backtrack", merged, latency_ms)
|
|
|
+ return merged
|
|
|
+
|
|
|
# ==================== 统计 ====================
|
|
|
|
|
|
def get_stats(self) -> Dict[str, int]:
|
|
|
@@ -566,6 +785,7 @@ class RAGAdapter:
|
|
|
|
|
|
def main():
|
|
|
import argparse
|
|
|
+ from .cli_output import print_success, print_error
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="RAG Adapter CLI")
|
|
|
parser.add_argument("--project-root", type=str, help="项目根目录")
|
|
|
@@ -575,16 +795,18 @@ def main():
|
|
|
# 获取统计
|
|
|
subparsers.add_parser("stats")
|
|
|
|
|
|
- # 搜索
|
|
|
- search_parser = subparsers.add_parser("search")
|
|
|
- search_parser.add_argument("--query", required=True, help="搜索查询")
|
|
|
- search_parser.add_argument("--mode", choices=["vector", "bm25", "hybrid"], default="hybrid")
|
|
|
- search_parser.add_argument("--top-k", type=int, default=10)
|
|
|
-
|
|
|
- # 索引章节
|
|
|
+ # 写入索引
|
|
|
index_parser = subparsers.add_parser("index-chapter")
|
|
|
index_parser.add_argument("--chapter", type=int, required=True)
|
|
|
index_parser.add_argument("--scenes", required=True, help="JSON 格式的场景列表")
|
|
|
+ index_parser.add_argument("--summary", required=False, help="章节摘要文本")
|
|
|
+
|
|
|
+ # 搜索
|
|
|
+ search_parser = subparsers.add_parser("search")
|
|
|
+ search_parser.add_argument("--query", required=True)
|
|
|
+ search_parser.add_argument("--mode", choices=["vector", "bm25", "hybrid", "backtrack"], default="hybrid")
|
|
|
+ search_parser.add_argument("--top-k", type=int, default=5)
|
|
|
+ search_parser.add_argument("--chunk-type", choices=["scene", "summary"], default=None)
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
@@ -592,46 +814,88 @@ def main():
|
|
|
config = None
|
|
|
if args.project_root:
|
|
|
from .config import DataModulesConfig
|
|
|
+
|
|
|
config = DataModulesConfig.from_project_root(args.project_root)
|
|
|
|
|
|
adapter = RAGAdapter(config)
|
|
|
+ tool_name = f"rag_adapter:{args.command or 'unknown'}"
|
|
|
+
|
|
|
+ def emit_success(data=None, message: str = "ok"):
|
|
|
+ print_success(data, message=message)
|
|
|
+ try:
|
|
|
+ adapter.index_manager.log_tool_call(tool_name, True)
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
+ def emit_error(code: str, message: str, suggestion: str | None = None):
|
|
|
+ print_error(code, message, suggestion=suggestion)
|
|
|
+ try:
|
|
|
+ adapter.index_manager.log_tool_call(tool_name, False, error_code=code, error_message=message)
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
|
|
|
if args.command == "stats":
|
|
|
stats = adapter.get_stats()
|
|
|
- print(json.dumps(stats, ensure_ascii=False, indent=2))
|
|
|
+ emit_success(stats, message="stats")
|
|
|
|
|
|
- elif args.command == "search":
|
|
|
- async def do_search():
|
|
|
- if args.mode == "vector":
|
|
|
- results = await adapter.vector_search(args.query, args.top_k)
|
|
|
- elif args.mode == "bm25":
|
|
|
- results = adapter.bm25_search(args.query, args.top_k)
|
|
|
- else:
|
|
|
- results = await adapter.hybrid_search(args.query)
|
|
|
+ elif args.command == "index-chapter":
|
|
|
+ scenes = json.loads(args.scenes)
|
|
|
+ chunks = []
|
|
|
+
|
|
|
+ # summary chunk
|
|
|
+ summary_text = args.summary
|
|
|
+ if not summary_text and config:
|
|
|
+ summary_path = config.webnovel_dir / "summaries" / f"ch{args.chapter:04d}.md"
|
|
|
+ if summary_path.exists():
|
|
|
+ summary_text = summary_path.read_text(encoding="utf-8")
|
|
|
+
|
|
|
+ parent_chunk_id = None
|
|
|
+ if summary_text:
|
|
|
+ parent_chunk_id = f"ch{args.chapter:04d}_summary"
|
|
|
+ chunks.append(
|
|
|
+ {
|
|
|
+ "chapter": args.chapter,
|
|
|
+ "scene_index": 0,
|
|
|
+ "content": summary_text,
|
|
|
+ "chunk_type": "summary",
|
|
|
+ "chunk_id": parent_chunk_id,
|
|
|
+ "source_file": f"summaries/ch{args.chapter:04d}.md",
|
|
|
+ }
|
|
|
+ )
|
|
|
|
|
|
- print(f"搜索结果 ({len(results)} 条):")
|
|
|
- for r in results:
|
|
|
- print(f"\n[{r.source}] 第 {r.chapter} 章 场景 {r.scene_index} (score: {r.score:.4f})")
|
|
|
- print(f" {r.content[:100]}...")
|
|
|
+ for s in scenes:
|
|
|
+ scene_index = s.get("index", 0)
|
|
|
+ chunk_id = f"ch{args.chapter:04d}_s{int(scene_index)}"
|
|
|
+ chunks.append(
|
|
|
+ {
|
|
|
+ "chapter": args.chapter,
|
|
|
+ "scene_index": scene_index,
|
|
|
+ "content": s.get("content", ""),
|
|
|
+ "chunk_type": "scene",
|
|
|
+ "parent_chunk_id": parent_chunk_id,
|
|
|
+ "chunk_id": chunk_id,
|
|
|
+ "source_file": f"正文/第{args.chapter:04d}章.md#scene_{int(scene_index)}",
|
|
|
+ }
|
|
|
+ )
|
|
|
|
|
|
- asyncio.run(do_search())
|
|
|
+ stored = asyncio.run(adapter.store_chunks(chunks))
|
|
|
+ emit_success({"stored": stored, "chunks": len(chunks)}, message="indexed")
|
|
|
|
|
|
- elif args.command == "index-chapter":
|
|
|
- scenes = json.loads(args.scenes)
|
|
|
- chunks = [
|
|
|
- {
|
|
|
- "chapter": args.chapter,
|
|
|
- "scene_index": s.get("index", i),
|
|
|
- "content": s.get("summary", "") + "\n" + s.get("content", "")
|
|
|
- }
|
|
|
- for i, s in enumerate(scenes)
|
|
|
- ]
|
|
|
+ elif args.command == "search":
|
|
|
+ if args.mode == "vector":
|
|
|
+ results = asyncio.run(adapter.vector_search(args.query, args.top_k, chunk_type=args.chunk_type))
|
|
|
+ elif args.mode == "bm25":
|
|
|
+ results = adapter.bm25_search(args.query, args.top_k, chunk_type=args.chunk_type)
|
|
|
+ elif args.mode == "backtrack":
|
|
|
+ results = asyncio.run(adapter.search_with_backtrack(args.query, args.top_k))
|
|
|
+ else:
|
|
|
+ results = asyncio.run(adapter.hybrid_search(args.query, args.top_k, args.top_k, args.top_k, chunk_type=args.chunk_type))
|
|
|
|
|
|
- async def do_index():
|
|
|
- stored = await adapter.store_chunks(chunks)
|
|
|
- print(f"✓ 已索引 {stored} 个场景")
|
|
|
+ payload = [r.__dict__ for r in results]
|
|
|
+ emit_success(payload, message="search_results")
|
|
|
|
|
|
- asyncio.run(do_index())
|
|
|
+ else:
|
|
|
+ emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|