style_sampler.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Style Sampler - 风格样本管理模块
  5. 管理高质量章节片段作为风格参考:
  6. - 风格样本存储
  7. - 按场景类型分类
  8. - 样本选择策略
  9. """
  10. import json
  11. import sqlite3
  12. import time
  13. from pathlib import Path
  14. from typing import Dict, List, Optional, Any
  15. from dataclasses import dataclass, asdict
  16. from datetime import datetime
  17. from enum import Enum
  18. from contextlib import contextmanager
  19. from .config import get_config
  20. from .observability import safe_append_perf_timing, safe_log_tool_call
  21. class SceneType(Enum):
  22. """场景类型"""
  23. BATTLE = "战斗"
  24. DIALOGUE = "对话"
  25. DESCRIPTION = "描写"
  26. TRANSITION = "过渡"
  27. EMOTION = "情感"
  28. TENSION = "紧张"
  29. COMEDY = "轻松"
  30. @dataclass
  31. class StyleSample:
  32. """风格样本"""
  33. id: str
  34. chapter: int
  35. scene_type: str
  36. content: str
  37. score: float
  38. tags: List[str]
  39. created_at: str = ""
  40. class StyleSampler:
  41. """风格样本管理器"""
  42. def __init__(self, config=None):
  43. self.config = config or get_config()
  44. self._init_db()
  45. def _init_db(self):
  46. """初始化数据库"""
  47. self.config.ensure_dirs()
  48. with self._get_conn() as conn:
  49. cursor = conn.cursor()
  50. cursor.execute("""
  51. CREATE TABLE IF NOT EXISTS samples (
  52. id TEXT PRIMARY KEY,
  53. chapter INTEGER,
  54. scene_type TEXT,
  55. content TEXT,
  56. score REAL,
  57. tags TEXT,
  58. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  59. )
  60. """)
  61. cursor.execute("CREATE INDEX IF NOT EXISTS idx_samples_type ON samples(scene_type)")
  62. cursor.execute("CREATE INDEX IF NOT EXISTS idx_samples_score ON samples(score DESC)")
  63. conn.commit()
  64. @contextmanager
  65. def _get_conn(self):
  66. """获取数据库连接(确保关闭,避免 Windows 下文件句柄泄漏导致无法清理临时目录)"""
  67. db_path = self.config.webnovel_dir / "style_samples.db"
  68. conn = sqlite3.connect(str(db_path))
  69. try:
  70. yield conn
  71. finally:
  72. conn.close()
  73. # ==================== 样本管理 ====================
  74. def add_sample(self, sample: StyleSample) -> bool:
  75. """添加风格样本"""
  76. with self._get_conn() as conn:
  77. cursor = conn.cursor()
  78. try:
  79. cursor.execute("""
  80. INSERT INTO samples
  81. (id, chapter, scene_type, content, score, tags, created_at)
  82. VALUES (?, ?, ?, ?, ?, ?, ?)
  83. """, (
  84. sample.id,
  85. sample.chapter,
  86. sample.scene_type,
  87. sample.content,
  88. sample.score,
  89. json.dumps(sample.tags, ensure_ascii=False),
  90. sample.created_at or datetime.now().isoformat()
  91. ))
  92. conn.commit()
  93. return True
  94. except sqlite3.IntegrityError:
  95. return False
  96. def get_samples_by_type(
  97. self,
  98. scene_type: str,
  99. limit: int = 5,
  100. min_score: float = 0.0
  101. ) -> List[StyleSample]:
  102. """按场景类型获取样本"""
  103. with self._get_conn() as conn:
  104. cursor = conn.cursor()
  105. cursor.execute("""
  106. SELECT id, chapter, scene_type, content, score, tags, created_at
  107. FROM samples
  108. WHERE scene_type = ? AND score >= ?
  109. ORDER BY score DESC
  110. LIMIT ?
  111. """, (scene_type, min_score, limit))
  112. return [self._row_to_sample(row) for row in cursor.fetchall()]
  113. def get_best_samples(self, limit: int = 10) -> List[StyleSample]:
  114. """获取最高分样本"""
  115. with self._get_conn() as conn:
  116. cursor = conn.cursor()
  117. cursor.execute("""
  118. SELECT id, chapter, scene_type, content, score, tags, created_at
  119. FROM samples
  120. ORDER BY score DESC
  121. LIMIT ?
  122. """, (limit,))
  123. return [self._row_to_sample(row) for row in cursor.fetchall()]
  124. def _safe_tags(self, raw) -> List[str]:
  125. if not raw:
  126. return []
  127. try:
  128. value = json.loads(raw)
  129. except (TypeError, json.JSONDecodeError):
  130. return []
  131. return value if isinstance(value, list) else []
  132. def _row_to_sample(self, row) -> StyleSample:
  133. """将数据库行转换为样本对象"""
  134. return StyleSample(
  135. id=row[0],
  136. chapter=row[1],
  137. scene_type=row[2],
  138. content=row[3],
  139. score=row[4],
  140. tags=self._safe_tags(row[5]),
  141. created_at=row[6]
  142. )
  143. # ==================== 样本提取 ====================
  144. def extract_candidates(
  145. self,
  146. chapter: int,
  147. content: str,
  148. review_score: float,
  149. scenes: List[Dict]
  150. ) -> List[StyleSample]:
  151. """
  152. 从章节中提取风格样本候选
  153. 只有高分章节 (review_score >= 80) 才提取样本
  154. """
  155. if review_score < 80:
  156. return []
  157. candidates = []
  158. for scene in scenes:
  159. scene_type = self._classify_scene_type(scene)
  160. scene_content = scene.get("content", "")
  161. # 跳过过短的场景
  162. if len(scene_content) < 200:
  163. continue
  164. # 创建样本
  165. sample = StyleSample(
  166. id=f"ch{chapter}_s{scene.get('index', 0)}",
  167. chapter=chapter,
  168. scene_type=scene_type,
  169. content=scene_content[:2000], # 限制长度
  170. score=review_score / 100.0,
  171. tags=self._extract_tags(scene_content)
  172. )
  173. candidates.append(sample)
  174. return candidates
  175. def _classify_scene_type(self, scene: Dict) -> str:
  176. """分类场景类型"""
  177. summary = scene.get("summary", "").lower()
  178. content = scene.get("content", "").lower()
  179. # 简单关键词分类
  180. battle_keywords = ["战斗", "攻击", "出手", "拳", "剑", "杀", "打", "斗"]
  181. dialogue_keywords = ["说道", "问道", "笑道", "冷声", "对话"]
  182. emotion_keywords = ["心中", "感觉", "情", "泪", "痛", "喜"]
  183. tension_keywords = ["危险", "紧张", "恐惧", "压力"]
  184. text = summary + content
  185. if any(kw in text for kw in battle_keywords):
  186. return SceneType.BATTLE.value
  187. elif any(kw in text for kw in tension_keywords):
  188. return SceneType.TENSION.value
  189. elif any(kw in text for kw in dialogue_keywords):
  190. return SceneType.DIALOGUE.value
  191. elif any(kw in text for kw in emotion_keywords):
  192. return SceneType.EMOTION.value
  193. else:
  194. return SceneType.DESCRIPTION.value
  195. def _extract_tags(self, content: str) -> List[str]:
  196. """提取内容标签"""
  197. tags = []
  198. # 简单标签提取
  199. if "战斗" in content or "攻击" in content:
  200. tags.append("战斗")
  201. if "修炼" in content or "突破" in content:
  202. tags.append("修炼")
  203. if "对话" in content or "说道" in content:
  204. tags.append("对话")
  205. if "描写" in content or "景色" in content:
  206. tags.append("描写")
  207. return tags[:5]
  208. # ==================== 样本选择 ====================
  209. def select_samples_for_chapter(
  210. self,
  211. chapter_outline: str,
  212. target_types: List[str] = None,
  213. max_samples: int = 3
  214. ) -> List[StyleSample]:
  215. """
  216. 为章节写作选择合适的风格样本
  217. 基于大纲分析需要什么类型的样本
  218. """
  219. if target_types is None:
  220. # 根据大纲推断需要的场景类型
  221. target_types = self._infer_scene_types(chapter_outline)
  222. samples = []
  223. per_type = max(1, max_samples // len(target_types)) if target_types else max_samples
  224. for scene_type in target_types:
  225. type_samples = self.get_samples_by_type(scene_type, limit=per_type, min_score=0.8)
  226. samples.extend(type_samples)
  227. return samples[:max_samples]
  228. def _infer_scene_types(self, outline: str) -> List[str]:
  229. """从大纲推断需要的场景类型"""
  230. types = []
  231. if any(kw in outline for kw in ["战斗", "对决", "比试", "交手"]):
  232. types.append(SceneType.BATTLE.value)
  233. if any(kw in outline for kw in ["对话", "谈话", "商议", "讨论"]):
  234. types.append(SceneType.DIALOGUE.value)
  235. if any(kw in outline for kw in ["情感", "感情", "心理"]):
  236. types.append(SceneType.EMOTION.value)
  237. if not types:
  238. types = [SceneType.DESCRIPTION.value]
  239. return types
  240. # ==================== 统计 ====================
  241. def get_stats(self) -> Dict[str, Any]:
  242. """获取样本统计"""
  243. with self._get_conn() as conn:
  244. cursor = conn.cursor()
  245. cursor.execute("SELECT COUNT(*) FROM samples")
  246. total = cursor.fetchone()[0]
  247. cursor.execute("""
  248. SELECT scene_type, COUNT(*) as count
  249. FROM samples
  250. GROUP BY scene_type
  251. """)
  252. by_type = {row[0]: row[1] for row in cursor.fetchall()}
  253. cursor.execute("SELECT AVG(score) FROM samples")
  254. avg_score = cursor.fetchone()[0] or 0
  255. return {
  256. "total": total,
  257. "by_type": by_type,
  258. "avg_score": round(avg_score, 3)
  259. }
  260. # ==================== CLI 接口 ====================
  261. def main():
  262. import argparse
  263. import sys
  264. from .cli_output import print_success, print_error
  265. from .cli_args import normalize_global_project_root, load_json_arg
  266. from .index_manager import IndexManager
  267. parser = argparse.ArgumentParser(description="Style Sampler CLI")
  268. parser.add_argument("--project-root", type=str, help="项目根目录")
  269. subparsers = parser.add_subparsers(dest="command")
  270. # 获取统计
  271. subparsers.add_parser("stats")
  272. # 列出样本
  273. list_parser = subparsers.add_parser("list")
  274. list_parser.add_argument("--type", help="按类型过滤")
  275. list_parser.add_argument("--limit", type=int, default=10)
  276. # 提取样本
  277. extract_parser = subparsers.add_parser("extract")
  278. extract_parser.add_argument("--chapter", type=int, required=True)
  279. extract_parser.add_argument("--score", type=float, required=True)
  280. extract_parser.add_argument("--scenes", required=True, help="JSON 格式的场景列表")
  281. # 选择样本
  282. select_parser = subparsers.add_parser("select")
  283. select_parser.add_argument("--outline", required=True, help="章节大纲")
  284. select_parser.add_argument("--max", type=int, default=3)
  285. argv = normalize_global_project_root(sys.argv[1:])
  286. args = parser.parse_args(argv)
  287. command_started_at = time.perf_counter()
  288. # 初始化
  289. config = None
  290. if args.project_root:
  291. # 允许传入“工作区根目录”,统一解析到真正的 book project_root(必须包含 .webnovel/state.json)
  292. from project_locator import resolve_project_root
  293. from .config import DataModulesConfig
  294. resolved_root = resolve_project_root(args.project_root)
  295. config = DataModulesConfig.from_project_root(resolved_root)
  296. sampler = StyleSampler(config)
  297. logger = IndexManager(config)
  298. tool_name = f"style_sampler:{args.command or 'unknown'}"
  299. def _append_timing(success: bool, *, error_code: str | None = None, error_message: str | None = None, chapter: int | None = None):
  300. elapsed_ms = int((time.perf_counter() - command_started_at) * 1000)
  301. safe_append_perf_timing(
  302. sampler.config.project_root,
  303. tool_name=tool_name,
  304. success=success,
  305. elapsed_ms=elapsed_ms,
  306. chapter=chapter,
  307. error_code=error_code,
  308. error_message=error_message,
  309. )
  310. def emit_success(data=None, message: str = "ok", chapter: int | None = None):
  311. print_success(data, message=message)
  312. safe_log_tool_call(logger, tool_name=tool_name, success=True)
  313. _append_timing(True, chapter=chapter)
  314. def emit_error(code: str, message: str, suggestion: str | None = None, chapter: int | None = None):
  315. print_error(code, message, suggestion=suggestion)
  316. safe_log_tool_call(
  317. logger,
  318. tool_name=tool_name,
  319. success=False,
  320. error_code=code,
  321. error_message=message,
  322. )
  323. _append_timing(False, error_code=code, error_message=message, chapter=chapter)
  324. if args.command == "stats":
  325. stats = sampler.get_stats()
  326. emit_success(stats, message="stats")
  327. elif args.command == "list":
  328. if args.type:
  329. samples = sampler.get_samples_by_type(args.type, args.limit)
  330. else:
  331. samples = sampler.get_best_samples(args.limit)
  332. emit_success([s.__dict__ for s in samples], message="samples")
  333. elif args.command == "extract":
  334. scenes = load_json_arg(args.scenes, base_dir=sampler.config.project_root)
  335. candidates = sampler.extract_candidates(
  336. chapter=args.chapter,
  337. content="",
  338. review_score=args.score,
  339. scenes=scenes,
  340. )
  341. added = []
  342. skipped = []
  343. for c in candidates:
  344. if sampler.add_sample(c):
  345. added.append(c.id)
  346. else:
  347. skipped.append(c.id)
  348. emit_success({"added": added, "skipped": skipped}, message="extracted", chapter=args.chapter)
  349. elif args.command == "select":
  350. samples = sampler.select_samples_for_chapter(args.outline, max_samples=args.max)
  351. emit_success([s.__dict__ for s in samples], message="selected")
  352. else:
  353. emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")
  354. if __name__ == "__main__":
  355. main()