reference_search.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Reference CSV 检索工具。
  5. 在 references/csv/ 目录下的 CSV 文件中执行 BM25 关键词搜索,
  6. 支持按技能、题材过滤,返回 JSON 格式结果。
  7. 用法:
  8. python reference_search.py --skill write --query "角色命名" --genre 玄幻
  9. python reference_search.py --skill write --table 命名规则 --query "战斗描写" --max-results 3
  10. """
  11. from __future__ import annotations
  12. import argparse
  13. import csv
  14. import json
  15. import math
  16. import re
  17. import sys
  18. from pathlib import Path
  19. from typing import Any, Dict, List, Optional
  20. from genre_taxonomy import GENRE_CANONICAL, resolve_canonical_genre
  21. # ---------------------------------------------------------------------------
  22. # CSV loading
  23. # ---------------------------------------------------------------------------
  24. def _load_csv(path: Path) -> List[Dict[str, str]]:
  25. """Load a single CSV file (UTF-8 with BOM)."""
  26. with open(path, "r", encoding="utf-8-sig", newline="") as f:
  27. reader = csv.DictReader(f)
  28. return list(reader)
  29. def load_tables(csv_dir: Path, table: Optional[str] = None) -> Dict[str, List[Dict[str, str]]]:
  30. """
  31. Load CSV tables from *csv_dir*.
  32. If *table* is given, load only that file (``<table>.csv``).
  33. Otherwise load every ``.csv`` file in the directory.
  34. Returns ``{table_name: [row_dict, ...]}``.
  35. """
  36. tables: Dict[str, List[Dict[str, str]]] = {}
  37. if table:
  38. target = csv_dir / f"{table}.csv"
  39. if target.is_file():
  40. tables[table] = _load_csv(target)
  41. else:
  42. for p in sorted(csv_dir.glob("*.csv")):
  43. tables[p.stem] = _load_csv(p)
  44. return tables
  45. # ---------------------------------------------------------------------------
  46. # Filtering
  47. # ---------------------------------------------------------------------------
  48. _MULTI_VALUE_SPLIT_RE = re.compile(r"[|,,、;;]+")
  49. _INTERNAL_TABLE_ROLES = {"route", "reasoning"}
  50. def split_multi_value(cell: Any) -> List[str]:
  51. """Split list-like cells while remaining compatible with legacy comma data."""
  52. if not cell:
  53. return []
  54. return [part.strip() for part in _MULTI_VALUE_SPLIT_RE.split(str(cell)) if part.strip()]
  55. def _split_multi_value(cell: Any) -> List[str]:
  56. return split_multi_value(cell)
  57. def _skill_matches(row: Dict[str, str], skill: str) -> bool:
  58. """Return True if *skill* appears in the pipe-separated ``适用技能`` column."""
  59. return skill in _split_multi_value(row.get("适用技能", ""))
  60. def _genre_matches(row: Dict[str, str], genre: Optional[str]) -> bool:
  61. """Return True if *genre* is None, or matches ``适用题材`` (``全部`` always matches).
  62. Both the input *genre* and the cell values are resolved to canonical form
  63. before comparison, so platform tags and legacy values work transparently.
  64. """
  65. if genre is None:
  66. return True
  67. cell = row.get("适用题材", "")
  68. if cell.strip() == "全部":
  69. return True
  70. requested_genres = [
  71. resolved
  72. for raw in _split_multi_value(genre)
  73. for resolved in [resolve_genre(raw)]
  74. if resolved
  75. ]
  76. cell_genres = [resolve_genre(v) for v in _split_multi_value(cell)]
  77. return any(resolved in cell_genres for resolved in requested_genres)
  78. def _table_visible_for_search(table_name: str, skill: str, explicit_table: bool) -> bool:
  79. """Keep story-system internals out of normal cross-table skill searches."""
  80. if explicit_table or skill == "story-system":
  81. return True
  82. cfg = CSV_CONFIG.get(table_name) or {}
  83. return cfg.get("role") not in _INTERNAL_TABLE_ROLES
  84. # ---------------------------------------------------------------------------
  85. # Genre canonical resolution
  86. # ---------------------------------------------------------------------------
  87. def resolve_genre(genre: Optional[str]) -> Optional[str]:
  88. """Resolve a user-facing genre string to its canonical form.
  89. Accepts canonical genres, platform tags, and legacy values.
  90. Returns the canonical genre string, or the original input if unresolvable.
  91. """
  92. if genre is None:
  93. return None
  94. return resolve_canonical_genre(genre)
  95. # ---------------------------------------------------------------------------
  96. # CSV_CONFIG – per-table metadata registry
  97. # ---------------------------------------------------------------------------
# Per-table metadata registry, keyed by table name (the CSV file stem).
#
# Fields read in this module:
#   search_cols - {column: weight}; each column's tokens are repeated
#                 `weight` times when building a row's BM25 document terms.
#   output_cols - when a row has no "核心摘要" value, the non-blank values of
#                 these columns (minus the summary skip list) are joined to
#                 build the result summary.
#   role        - tables whose role is "route" or "reasoning" are hidden from
#                 ordinary cross-table searches.
# Fields NOT read in this module (presumably consumed by other tooling --
# confirm against callers): file, poison_col, contract_inject, prefix,
# required_cols.
CSV_CONFIG: Dict[str, Dict[str, Any]] = {
    "命名规则": {
        "file": "命名规则.csv",
        "search_cols": {"关键词": 3, "意图与同义词": 4, "核心摘要": 2},
        "output_cols": ["编号", "命名对象", "核心摘要", "大模型指令", "详细展开"],
        "poison_col": "毒点",
        "role": "base",
        "contract_inject": "MASTER_SETTING.base_context",
        "prefix": "NR",
        "required_cols": ["编号", "适用技能", "分类", "层级", "关键词", "适用题材", "核心摘要"],
    },
    "场景写法": {
        "file": "场景写法.csv",
        "search_cols": {"关键词": 3, "意图与同义词": 4, "核心摘要": 2},
        "output_cols": ["编号", "模式名称", "核心摘要", "大模型指令", "详细展开"],
        "poison_col": "毒点",
        "role": "base",
        "contract_inject": "CHAPTER_BRIEF.dynamic_context",
        "prefix": "SP",
        "required_cols": ["编号", "适用技能", "分类", "层级", "关键词", "适用题材", "核心摘要"],
    },
    "写作技法": {
        "file": "写作技法.csv",
        "search_cols": {"关键词": 3, "意图与同义词": 4, "核心摘要": 2},
        "output_cols": ["编号", "技法名称", "核心摘要", "大模型指令", "详细展开"],
        "poison_col": "毒点",
        "role": "base",
        "contract_inject": "CHAPTER_BRIEF.dynamic_context",
        "prefix": "WT",
        "required_cols": ["编号", "适用技能", "分类", "层级", "关键词", "适用题材", "核心摘要"],
    },
    "桥段套路": {
        "file": "桥段套路.csv",
        "search_cols": {"关键词": 3, "意图与同义词": 4, "核心摘要": 2},
        "output_cols": ["编号", "桥段名称", "核心摘要", "大模型指令", "详细展开"],
        "poison_col": "毒点",
        "role": "dynamic",
        "contract_inject": "CHAPTER_BRIEF.dynamic_context",
        "prefix": "TR",
        "required_cols": ["编号", "适用技能", "分类", "层级", "关键词", "适用题材", "核心摘要"],
    },
    "爽点与节奏": {
        "file": "爽点与节奏.csv",
        "search_cols": {"关键词": 3, "意图与同义词": 4, "核心摘要": 2},
        "output_cols": ["编号", "节奏类型", "核心摘要", "大模型指令", "详细展开"],
        "poison_col": "毒点",
        "role": "dynamic",
        "contract_inject": "CHAPTER_BRIEF.dynamic_context",
        "prefix": "PA",
        "required_cols": ["编号", "适用技能", "分类", "层级", "关键词", "适用题材", "核心摘要"],
    },
    "人设与关系": {
        "file": "人设与关系.csv",
        "search_cols": {"关键词": 3, "意图与同义词": 4, "核心摘要": 2},
        "output_cols": ["编号", "人设类型", "核心摘要", "大模型指令", "详细展开"],
        "poison_col": "毒点",
        "role": "base",
        "contract_inject": "MASTER_SETTING.base_context",
        "prefix": "CH",
        "required_cols": ["编号", "适用技能", "分类", "层级", "关键词", "适用题材", "核心摘要"],
    },
    "金手指与设定": {
        "file": "金手指与设定.csv",
        "search_cols": {"关键词": 3, "意图与同义词": 4, "核心摘要": 2},
        "output_cols": ["编号", "设定类型", "核心摘要", "大模型指令", "详细展开"],
        "poison_col": "毒点",
        "role": "base",
        "contract_inject": "MASTER_SETTING.base_context",
        "prefix": "SY",
        "required_cols": ["编号", "适用技能", "分类", "层级", "关键词", "适用题材", "核心摘要"],
    },
    # Routing table: hidden from normal searches by its "route" role.
    "题材与调性推理": {
        "file": "题材与调性推理.csv",
        "search_cols": {"关键词": 3, "意图与同义词": 4, "题材别名": 3},
        "output_cols": ["编号", "题材/流派", "canonical_genre", "核心调性", "推荐基础检索表", "推荐动态检索表"],
        "poison_col": "毒点",
        "role": "route",
        "contract_inject": "MASTER_SETTING.route",
        "prefix": "GR",
        "required_cols": ["编号", "适用技能", "题材/流派", "canonical_genre", "核心调性", "推荐基础检索表", "推荐动态检索表"],
    },
    # Reasoning table: hidden from normal searches by its "reasoning" role.
    # NOTE(review): required_cols omits "适用技能", yet search() keeps only rows
    # whose 适用技能 cell lists the requested skill -- confirm the CSV has it.
    "裁决规则": {
        "file": "裁决规则.csv",
        "search_cols": {"题材": 4},
        "output_cols": ["题材", "风格优先级", "爽点优先级", "节奏默认策略",
                        "毒点权重", "冲突裁决", "contract注入层", "反模式"],
        "poison_col": "",
        "role": "reasoning",
        "contract_inject": "CHAPTER_BRIEF.writing_guidance",
        "prefix": "RS",
        "required_cols": ["编号", "题材", "风格优先级", "爽点优先级", "节奏默认策略", "冲突裁决"],
    },
}
  191. # ---------------------------------------------------------------------------
  192. # BM25-lite scoring
  193. # ---------------------------------------------------------------------------
  194. _TOKEN_SPLIT_RE = re.compile(r"[\s|,,、/;;::()()【】\[\]<>《》""\"'''!?!?。…]+")
  195. _DEFAULT_SEARCH_WEIGHTS = {
  196. "意图与同义词": 4,
  197. "关键词": 3,
  198. "核心摘要": 2,
  199. "详细展开": 1,
  200. }
  201. def _tokenize(text: str) -> List[str]:
  202. """Split text into reusable search terms without requiring a segmenter."""
  203. if not text:
  204. return []
  205. tokens: List[str] = []
  206. for part in _TOKEN_SPLIT_RE.split(text):
  207. token = part.strip()
  208. if not token:
  209. continue
  210. # 过滤 don't -> t 这类单字符英文噪声,避免触发子串兜底误召回。
  211. if len(token) == 1 and token.isascii():
  212. continue
  213. tokens.append(token)
  214. return tokens
  215. def _build_doc_terms(row: Dict[str, str], search_weights: Optional[Dict[str, int]] = None) -> List[str]:
  216. """Build weighted BM25 terms from the configured search fields."""
  217. weights = search_weights or _DEFAULT_SEARCH_WEIGHTS
  218. terms: List[str] = []
  219. for field, weight in weights.items():
  220. field_terms = _tokenize(row.get(field, ""))
  221. if not field_terms:
  222. continue
  223. terms.extend(field_terms * weight)
  224. return terms
  225. def _bm25_score(query_terms: List[str], doc_terms: List[str],
  226. avg_dl: float, k1: float = 1.5, b: float = 0.75,
  227. idf_map: Optional[Dict[str, float]] = None) -> float:
  228. """
  229. Simplified BM25 score for a single document.
  230. *idf_map* maps each query term to its IDF value.
  231. """
  232. if not doc_terms:
  233. return 0.0
  234. dl = len(doc_terms)
  235. score = 0.0
  236. tf_map: Dict[str, int] = {}
  237. for t in doc_terms:
  238. tf_map[t] = tf_map.get(t, 0) + 1
  239. for qt in query_terms:
  240. tf = tf_map.get(qt, 0)
  241. if tf == 0:
  242. # Also check substring match (important for Chinese compound words)
  243. for dt in tf_map:
  244. if qt in dt or dt in qt:
  245. tf = max(tf, 1)
  246. break
  247. if tf == 0:
  248. continue
  249. idf = idf_map.get(qt, 1.0) if idf_map else 1.0
  250. numerator = tf * (k1 + 1)
  251. denominator = tf + k1 * (1 - b + b * dl / max(avg_dl, 1))
  252. score += idf * numerator / denominator
  253. return score
  254. def _compute_idf(query_terms: List[str], all_docs: List[List[str]]) -> Dict[str, float]:
  255. """Compute IDF for each query term across all documents."""
  256. n = len(all_docs)
  257. if n == 0:
  258. return {}
  259. idf: Dict[str, float] = {}
  260. for qt in query_terms:
  261. df = 0
  262. for doc in all_docs:
  263. for dt in doc:
  264. if qt in dt or dt in qt:
  265. df += 1
  266. break
  267. # BM25 IDF: log((N - df + 0.5) / (df + 0.5) + 1)
  268. idf[qt] = math.log((n - df + 0.5) / (df + 0.5) + 1)
  269. return idf
  270. # ---------------------------------------------------------------------------
  271. # Content summary builder
  272. # ---------------------------------------------------------------------------
# Columns whose non-blank values _build_summary joins (in this order) when a
# row has no "核心摘要" value and its table has no CSV_CONFIG entry.
_FALLBACK_CONTENT_COLUMNS = [
    "技法名称", "桥段名称", "人设类型", "节奏类型", "设定类型",
    "规则", "说明", "模式名称",
    "常见误区", "前置铺垫", "核心爽点", "转折设计",
    "核心动机", "行为逻辑", "互动模式", "忌讳写法",
    "情绪调动手法", "常见崩盘误区",
    "数值控制边界", "与剧情交互方式",
    "正例", "示例片段",
    "反例", "反面写法",
    "命名对象", "场景类型", "技法类型", "适用场景",
]
# Columns dropped from a table's output_cols when deriving summary columns:
# the row ID, the "大模型指令" field (returned separately in search results),
# "详细展开" (the last-resort fallback) and "核心摘要" (checked first).
_SUMMARY_SKIP_COLS = {"编号", "大模型指令", "详细展开", "核心摘要"}
  286. def _build_summary(row: Dict[str, str], table_name: Optional[str] = None) -> str:
  287. """Merge key content columns into a single summary string."""
  288. core_summary = row.get("核心摘要", "").strip()
  289. if core_summary:
  290. return core_summary
  291. # Derive fallback columns from CSV_CONFIG if available
  292. tbl_cfg = CSV_CONFIG.get(table_name) if table_name else None
  293. if tbl_cfg:
  294. cols = [c for c in tbl_cfg["output_cols"] if c not in _SUMMARY_SKIP_COLS]
  295. else:
  296. cols = _FALLBACK_CONTENT_COLUMNS
  297. parts: List[str] = []
  298. for col in cols:
  299. val = row.get(col, "").strip()
  300. if val:
  301. parts.append(val)
  302. if parts:
  303. return ";".join(parts)
  304. return row.get("详细展开", "").strip()
  305. # ---------------------------------------------------------------------------
  306. # Search entry point
  307. # ---------------------------------------------------------------------------
  308. def search(
  309. csv_dir: Path,
  310. skill: str,
  311. query: str,
  312. table: Optional[str] = None,
  313. genre: Optional[str] = None,
  314. max_results: int = 5,
  315. ) -> Dict[str, Any]:
  316. """
  317. Run a BM25 keyword search across CSV reference tables.
  318. Returns a result dict suitable for JSON serialisation.
  319. """
  320. if not csv_dir.is_dir():
  321. return {
  322. "status": "error",
  323. "error": {
  324. "code": "CSV_DIR_NOT_FOUND",
  325. "message": f"CSV directory not found: {csv_dir}",
  326. },
  327. }
  328. tables = load_tables(csv_dir, table=table)
  329. if not tables:
  330. return {
  331. "status": "success",
  332. "message": "search_results",
  333. "data": {
  334. "query": query,
  335. "skill": skill,
  336. "genre": genre,
  337. "total": 0,
  338. "results": [],
  339. },
  340. }
  341. # 1) Collect filtered rows with table name annotation
  342. candidates: List[tuple] = [] # (table_name, row)
  343. for tbl_name, rows in tables.items():
  344. if not _table_visible_for_search(tbl_name, skill, explicit_table=table is not None):
  345. continue
  346. for row in rows:
  347. if _skill_matches(row, skill) and _genre_matches(row, genre):
  348. candidates.append((tbl_name, row))
  349. if not candidates:
  350. return {
  351. "status": "success",
  352. "message": "search_results",
  353. "data": {
  354. "query": query,
  355. "skill": skill,
  356. "genre": genre,
  357. "total": 0,
  358. "results": [],
  359. },
  360. }
  361. # 2) Tokenize
  362. query_terms = _tokenize(query)
  363. doc_terms_list = []
  364. for tbl_name, row in candidates:
  365. tbl_cfg = CSV_CONFIG.get(tbl_name)
  366. weights = dict(tbl_cfg["search_cols"]) if tbl_cfg else None
  367. doc_terms_list.append(_build_doc_terms(row, weights))
  368. avg_dl = sum(len(d) for d in doc_terms_list) / len(doc_terms_list) if doc_terms_list else 1.0
  369. idf_map = _compute_idf(query_terms, doc_terms_list)
  370. # 3) Score
  371. scored: List[tuple] = []
  372. for idx, (tbl_name, row) in enumerate(candidates):
  373. score = _bm25_score(query_terms, doc_terms_list[idx], avg_dl, idf_map=idf_map)
  374. if score > 0:
  375. scored.append((score, tbl_name, row))
  376. scored.sort(key=lambda x: x[0], reverse=True)
  377. top = scored[:max_results]
  378. # 4) Format results
  379. results: List[Dict[str, Any]] = []
  380. for _score, tbl_name, row in top:
  381. results.append({
  382. "编号": row.get("编号", ""),
  383. "表": tbl_name,
  384. "分类": row.get("分类", ""),
  385. "层级": row.get("层级", ""),
  386. "适用题材": row.get("适用题材", ""),
  387. "内容摘要": _build_summary(row, table_name=tbl_name),
  388. "大模型指令": row.get("大模型指令", "").strip(),
  389. })
  390. return {
  391. "status": "success",
  392. "message": "search_results",
  393. "data": {
  394. "query": query,
  395. "skill": skill,
  396. "genre": genre,
  397. "total": len(results),
  398. "results": results,
  399. },
  400. }
  401. # ---------------------------------------------------------------------------
  402. # CLI
  403. # ---------------------------------------------------------------------------
  404. def _default_csv_dir() -> Path:
  405. """Auto-detect the csv directory relative to this script's location."""
  406. return Path(__file__).resolve().parent.parent / "references" / "csv"
  407. def main(argv: Optional[List[str]] = None) -> None:
  408. parser = argparse.ArgumentParser(
  409. description="BM25 keyword search over reference CSV files",
  410. )
  411. parser.add_argument("--skill", required=True, help="Filter by 适用技能 column")
  412. parser.add_argument("--table", default=None, help="Target specific CSV file name (without .csv)")
  413. parser.add_argument("--query", required=True, help="BM25 search keywords")
  414. parser.add_argument("--genre", default=None, help="Filter by 适用题材 column")
  415. parser.add_argument("--max-results", type=int, default=5, help="Max results (default 5)")
  416. parser.add_argument("--csv-dir", default=None, help="Override CSV directory path")
  417. args = parser.parse_args(argv)
  418. csv_dir = Path(args.csv_dir) if args.csv_dir else _default_csv_dir()
  419. result = search(
  420. csv_dir=csv_dir,
  421. skill=args.skill,
  422. query=args.query,
  423. table=args.table,
  424. genre=args.genre,
  425. max_results=args.max_results,
  426. )
  427. print(json.dumps(result, ensure_ascii=False))
  428. if __name__ == "__main__":
  429. main()