#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# story_system_engine.py
from __future__ import annotations

import csv
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

from reference_search import search as search_reference

from .story_contracts import merge_anti_patterns

  10. ANTI_PATTERN_SOURCE_FIELDS = {
  11. "场景写法": ["毒点"],
  12. "写作技法": ["毒点"],
  13. "爽点与节奏": ["毒点"],
  14. "人设与关系": ["毒点"],
  15. "桥段套路": ["毒点"],
  16. "题材与调性推理": ["毒点"],
  17. "命名规则": ["毒点"],
  18. "金手指与设定": ["毒点"],
  19. }
  20. class StorySystemEngine:
  21. def __init__(self, csv_dir: str | Path):
  22. self.csv_dir = Path(csv_dir)
  23. def build(self, query: str, genre: Optional[str], chapter: Optional[int]) -> Dict[str, Any]:
  24. route = self._route(query=query, genre=genre)
  25. search_query = self._expand_query(query, route.get("default_query", ""))
  26. base_context = self._collect_tables(
  27. search_query,
  28. route["recommended_base_tables"],
  29. genre=route["genre_filter"],
  30. top_k=1,
  31. )
  32. dynamic_context = self._collect_tables(
  33. search_query,
  34. route["recommended_dynamic_tables"],
  35. genre=route["genre_filter"],
  36. top_k=2,
  37. )
  38. source_trace = route["source_trace"] + self._build_source_trace(base_context, dynamic_context)
  39. anti_patterns = merge_anti_patterns(
  40. route["route_anti_patterns"],
  41. self._extract_anti_patterns(base_context),
  42. self._extract_anti_patterns(dynamic_context),
  43. )
  44. return {
  45. "meta": {"query": query, "chapter": chapter, "explicit_genre": genre or ""},
  46. "master_setting": {
  47. "meta": {
  48. "schema_version": "story-system/v1",
  49. "contract_type": "MASTER_SETTING",
  50. "generator_version": "phase1",
  51. "query": query,
  52. },
  53. "route": route["meta"],
  54. "master_constraints": {
  55. "core_tone": route["core_tone"],
  56. "pacing_strategy": route["pacing_strategy"],
  57. },
  58. "base_context": base_context,
  59. "source_trace": source_trace,
  60. "override_policy": {
  61. "locked": ["route.primary_genre", "master_constraints.core_tone"],
  62. "append_only": ["anti_patterns"],
  63. "override_allowed": [],
  64. },
  65. },
  66. "chapter_brief": (
  67. {
  68. "meta": {
  69. "schema_version": "story-system/v1",
  70. "contract_type": "CHAPTER_BRIEF",
  71. "generator_version": "phase1",
  72. "chapter": chapter,
  73. },
  74. "override_allowed": {
  75. "chapter_focus": self._suggest_chapter_focus(query, dynamic_context),
  76. },
  77. "dynamic_context": dynamic_context,
  78. "source_trace": source_trace,
  79. }
  80. if chapter is not None
  81. else None
  82. ),
  83. "anti_patterns": anti_patterns,
  84. }
  85. def _route(self, query: str, genre: Optional[str]) -> Dict[str, Any]:
  86. route_rows = self._load_csv_rows("题材与调性推理")
  87. query_text = self._normalize_text(" ".join([query or "", genre or ""]))
  88. matched = None
  89. route_source = "empty_csv_fallback"
  90. for row in route_rows:
  91. aliases = (
  92. self._split_multi_value(row.get("关键词"))
  93. + self._split_multi_value(row.get("意图与同义词"))
  94. + self._split_multi_value(row.get("题材别名"))
  95. )
  96. if any(alias and self._normalize_text(alias) in query_text for alias in aliases):
  97. matched = row
  98. route_source = "keyword_or_alias_match"
  99. break
  100. if matched is None and genre:
  101. matched = self._fallback_row_for_genre(route_rows, genre)
  102. if matched is not None:
  103. route_source = "explicit_genre_fallback"
  104. if matched is None and route_rows:
  105. matched = route_rows[0]
  106. route_source = "default_seed_fallback"
  107. if matched is None:
  108. return self._empty_route(query=query, genre=genre)
  109. primary_genre = str(matched.get("题材/流派") or genre or "").strip()
  110. genre_filter = str(matched.get("适用题材") or genre or primary_genre).strip()
  111. return {
  112. "meta": {
  113. "primary_genre": primary_genre,
  114. "route_source": route_source,
  115. "genre_filter": genre_filter,
  116. "recommended_base_tables": self._split_multi_value(matched.get("推荐基础检索表")),
  117. "recommended_dynamic_tables": self._split_multi_value(matched.get("推荐动态检索表")),
  118. },
  119. "core_tone": str(matched.get("核心调性") or "").strip(),
  120. "pacing_strategy": str(matched.get("节奏策略") or "").strip(),
  121. "route_anti_patterns": self._extract_route_anti_patterns(matched),
  122. "recommended_base_tables": self._split_multi_value(matched.get("推荐基础检索表")),
  123. "recommended_dynamic_tables": self._split_multi_value(matched.get("推荐动态检索表")),
  124. "genre_filter": genre_filter,
  125. "default_query": str(matched.get("默认查询词") or "").strip(),
  126. "source_trace": [{"table": "题材与调性推理", "id": matched.get("编号", ""), "reason": route_source}],
  127. }
  128. def _collect_tables(self, query: str, tables: List[str], genre: str, top_k: int) -> List[Dict[str, Any]]:
  129. rows: List[Dict[str, Any]] = []
  130. for table_name in tables:
  131. result = search_reference(
  132. csv_dir=self.csv_dir,
  133. skill="write",
  134. query=query,
  135. table=table_name,
  136. genre=genre or None,
  137. max_results=top_k,
  138. )
  139. raw_rows = {str(row.get("编号") or ""): row for row in self._load_csv_rows(table_name)}
  140. for item in result.get("data", {}).get("results", []):
  141. row_id = str(item.get("编号") or "")
  142. full_row = dict(raw_rows.get(row_id) or {})
  143. full_row["_table"] = str(item.get("表") or table_name)
  144. full_row["编号"] = row_id
  145. full_row["核心摘要"] = str(
  146. full_row.get("核心摘要")
  147. or item.get("内容摘要")
  148. or item.get("核心摘要")
  149. or ""
  150. ).strip()
  151. rows.append(full_row)
  152. return rows
  153. def _extract_anti_patterns(self, rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  154. extracted: List[Dict[str, Any]] = []
  155. for row in rows:
  156. table_name = str(row.get("_table") or "")
  157. for field_name in ANTI_PATTERN_SOURCE_FIELDS.get(table_name, []):
  158. for text in self._split_multi_value(row.get(field_name)):
  159. extracted.append(
  160. {
  161. "text": text,
  162. "source_table": table_name,
  163. "source_id": row.get("编号", ""),
  164. }
  165. )
  166. return extracted
  167. def _suggest_chapter_focus(self, query: str, dynamic_rows: List[Dict[str, Any]]) -> str:
  168. for row in dynamic_rows:
  169. summary = str(row.get("核心摘要") or "").strip()
  170. if summary:
  171. return summary
  172. return query
  173. def _build_source_trace(self, *groups: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  174. trace: List[Dict[str, Any]] = []
  175. for group in groups:
  176. for row in group:
  177. trace.append(
  178. {
  179. "table": row.get("_table", ""),
  180. "id": row.get("编号", ""),
  181. "summary": row.get("核心摘要", ""),
  182. }
  183. )
  184. return trace
  185. def _load_csv_rows(self, table_name: str) -> List[Dict[str, Any]]:
  186. csv_path = self.csv_dir / f"{table_name}.csv"
  187. if not csv_path.is_file():
  188. return []
  189. with csv_path.open("r", encoding="utf-8-sig", newline="") as f:
  190. return list(csv.DictReader(f))
  191. def _normalize_text(self, text: str) -> str:
  192. return str(text or "").strip().lower()
  193. def _split_multi_value(self, raw: Any) -> List[str]:
  194. return [item.strip() for item in re.split(r"[|;;]+", str(raw or "")) if item.strip()]
  195. def _expand_query(self, query: str, default_query: str) -> str:
  196. items: List[str] = []
  197. for candidate in [query, *self._split_multi_value(default_query)]:
  198. text = str(candidate or "").strip()
  199. if text and text not in items:
  200. items.append(text)
  201. return " ".join(items)
  202. def _fallback_row_for_genre(self, rows: List[Dict[str, Any]], genre: str) -> Dict[str, Any] | None:
  203. genre_text = self._normalize_text(genre)
  204. for row in rows:
  205. candidates = self._split_multi_value(row.get("适用题材")) + self._split_multi_value(row.get("题材/流派"))
  206. if any(self._normalize_text(candidate) == genre_text for candidate in candidates):
  207. return row
  208. return None
  209. def _extract_route_anti_patterns(self, row: Dict[str, Any]) -> List[Dict[str, Any]]:
  210. return [
  211. {"text": text, "source_table": "题材与调性推理", "source_id": row.get("编号", "")}
  212. for text in self._split_multi_value(row.get("毒点"))
  213. ]
  214. def _empty_route(self, query: str, genre: Optional[str]) -> Dict[str, Any]:
  215. fallback_genre = str(genre or "未命中题材").strip()
  216. route_source = "explicit_genre_fallback" if genre else "empty_csv_fallback"
  217. return {
  218. "meta": {
  219. "primary_genre": fallback_genre,
  220. "route_source": route_source,
  221. "genre_filter": fallback_genre,
  222. "recommended_base_tables": ["命名规则", "人设与关系"],
  223. "recommended_dynamic_tables": ["桥段套路", "爽点与节奏", "场景写法"],
  224. },
  225. "core_tone": "",
  226. "pacing_strategy": "",
  227. "route_anti_patterns": [],
  228. "recommended_base_tables": ["命名规则", "人设与关系"],
  229. "recommended_dynamic_tables": ["桥段套路", "爽点与节奏", "场景写法"],
  230. "genre_filter": fallback_genre,
  231. "default_query": "",
  232. "source_trace": [{"table": "题材与调性推理", "id": "", "reason": f"{route_source}:{query}"}],
  233. }