#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Style Sampler - style sample management module.

Manages high-quality chapter excerpts as style references:
- style sample storage
- classification by scene type
- sample selection strategy
"""
  10. import json
  11. import sqlite3
  12. import time
  13. from pathlib import Path
  14. from typing import Dict, List, Optional, Any
  15. from dataclasses import dataclass, asdict
  16. from datetime import datetime
  17. from enum import Enum
  18. from contextlib import contextmanager
  19. from .config import get_config
  20. from .observability import safe_append_perf_timing, safe_log_tool_call
  21. class SceneType(Enum):
  22. """场景类型"""
  23. BATTLE = "战斗"
  24. DIALOGUE = "对话"
  25. DESCRIPTION = "描写"
  26. TRANSITION = "过渡"
  27. EMOTION = "情感"
  28. TENSION = "紧张"
  29. COMEDY = "轻松"
  30. @dataclass
  31. class StyleSample:
  32. """风格样本"""
  33. id: str
  34. chapter: int
  35. scene_type: str
  36. content: str
  37. score: float
  38. tags: List[str]
  39. created_at: str = ""
  40. class StyleSampler:
  41. """风格样本管理器"""
  42. def __init__(self, config=None):
  43. self.config = config or get_config()
  44. self._init_db()
  45. def _init_db(self):
  46. """初始化数据库"""
  47. self.config.ensure_dirs()
  48. with self._get_conn() as conn:
  49. cursor = conn.cursor()
  50. cursor.execute("""
  51. CREATE TABLE IF NOT EXISTS samples (
  52. id TEXT PRIMARY KEY,
  53. chapter INTEGER,
  54. scene_type TEXT,
  55. content TEXT,
  56. score REAL,
  57. tags TEXT,
  58. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  59. )
  60. """)
  61. cursor.execute("CREATE INDEX IF NOT EXISTS idx_samples_type ON samples(scene_type)")
  62. cursor.execute("CREATE INDEX IF NOT EXISTS idx_samples_score ON samples(score DESC)")
  63. conn.commit()
  64. @contextmanager
  65. def _get_conn(self):
  66. """获取数据库连接(确保关闭,避免 Windows 下文件句柄泄漏导致无法清理临时目录)"""
  67. db_path = self.config.webnovel_dir / "style_samples.db"
  68. conn = sqlite3.connect(str(db_path))
  69. try:
  70. yield conn
  71. finally:
  72. conn.close()
  73. # ==================== 样本管理 ====================
  74. def add_sample(self, sample: StyleSample) -> bool:
  75. """添加风格样本"""
  76. with self._get_conn() as conn:
  77. cursor = conn.cursor()
  78. try:
  79. cursor.execute("""
  80. INSERT INTO samples
  81. (id, chapter, scene_type, content, score, tags, created_at)
  82. VALUES (?, ?, ?, ?, ?, ?, ?)
  83. """, (
  84. sample.id,
  85. sample.chapter,
  86. sample.scene_type,
  87. sample.content,
  88. sample.score,
  89. json.dumps(sample.tags, ensure_ascii=False),
  90. sample.created_at or datetime.now().isoformat()
  91. ))
  92. conn.commit()
  93. return True
  94. except sqlite3.IntegrityError:
  95. return False
  96. def get_samples_by_type(
  97. self,
  98. scene_type: str,
  99. limit: int = 5,
  100. min_score: float = 0.0
  101. ) -> List[StyleSample]:
  102. """按场景类型获取样本"""
  103. with self._get_conn() as conn:
  104. cursor = conn.cursor()
  105. cursor.execute("""
  106. SELECT id, chapter, scene_type, content, score, tags, created_at
  107. FROM samples
  108. WHERE scene_type = ? AND score >= ?
  109. ORDER BY score DESC
  110. LIMIT ?
  111. """, (scene_type, min_score, limit))
  112. return [self._row_to_sample(row) for row in cursor.fetchall()]
  113. def get_best_samples(self, limit: int = 10) -> List[StyleSample]:
  114. """获取最高分样本"""
  115. with self._get_conn() as conn:
  116. cursor = conn.cursor()
  117. cursor.execute("""
  118. SELECT id, chapter, scene_type, content, score, tags, created_at
  119. FROM samples
  120. ORDER BY score DESC
  121. LIMIT ?
  122. """, (limit,))
  123. return [self._row_to_sample(row) for row in cursor.fetchall()]
  124. def _row_to_sample(self, row) -> StyleSample:
  125. """将数据库行转换为样本对象"""
  126. return StyleSample(
  127. id=row[0],
  128. chapter=row[1],
  129. scene_type=row[2],
  130. content=row[3],
  131. score=row[4],
  132. tags=json.loads(row[5]) if row[5] else [],
  133. created_at=row[6]
  134. )
  135. # ==================== 样本提取 ====================
  136. def extract_candidates(
  137. self,
  138. chapter: int,
  139. content: str,
  140. review_score: float,
  141. scenes: List[Dict]
  142. ) -> List[StyleSample]:
  143. """
  144. 从章节中提取风格样本候选
  145. 只有高分章节 (review_score >= 80) 才提取样本
  146. """
  147. if review_score < 80:
  148. return []
  149. candidates = []
  150. for scene in scenes:
  151. scene_type = self._classify_scene_type(scene)
  152. scene_content = scene.get("content", "")
  153. # 跳过过短的场景
  154. if len(scene_content) < 200:
  155. continue
  156. # 创建样本
  157. sample = StyleSample(
  158. id=f"ch{chapter}_s{scene.get('index', 0)}",
  159. chapter=chapter,
  160. scene_type=scene_type,
  161. content=scene_content[:2000], # 限制长度
  162. score=review_score / 100.0,
  163. tags=self._extract_tags(scene_content)
  164. )
  165. candidates.append(sample)
  166. return candidates
  167. def _classify_scene_type(self, scene: Dict) -> str:
  168. """分类场景类型"""
  169. summary = scene.get("summary", "").lower()
  170. content = scene.get("content", "").lower()
  171. # 简单关键词分类
  172. battle_keywords = ["战斗", "攻击", "出手", "拳", "剑", "杀", "打", "斗"]
  173. dialogue_keywords = ["说道", "问道", "笑道", "冷声", "对话"]
  174. emotion_keywords = ["心中", "感觉", "情", "泪", "痛", "喜"]
  175. tension_keywords = ["危险", "紧张", "恐惧", "压力"]
  176. text = summary + content
  177. if any(kw in text for kw in battle_keywords):
  178. return SceneType.BATTLE.value
  179. elif any(kw in text for kw in tension_keywords):
  180. return SceneType.TENSION.value
  181. elif any(kw in text for kw in dialogue_keywords):
  182. return SceneType.DIALOGUE.value
  183. elif any(kw in text for kw in emotion_keywords):
  184. return SceneType.EMOTION.value
  185. else:
  186. return SceneType.DESCRIPTION.value
  187. def _extract_tags(self, content: str) -> List[str]:
  188. """提取内容标签"""
  189. tags = []
  190. # 简单标签提取
  191. if "战斗" in content or "攻击" in content:
  192. tags.append("战斗")
  193. if "修炼" in content or "突破" in content:
  194. tags.append("修炼")
  195. if "对话" in content or "说道" in content:
  196. tags.append("对话")
  197. if "描写" in content or "景色" in content:
  198. tags.append("描写")
  199. return tags[:5]
  200. # ==================== 样本选择 ====================
  201. def select_samples_for_chapter(
  202. self,
  203. chapter_outline: str,
  204. target_types: List[str] = None,
  205. max_samples: int = 3
  206. ) -> List[StyleSample]:
  207. """
  208. 为章节写作选择合适的风格样本
  209. 基于大纲分析需要什么类型的样本
  210. """
  211. if target_types is None:
  212. # 根据大纲推断需要的场景类型
  213. target_types = self._infer_scene_types(chapter_outline)
  214. samples = []
  215. per_type = max(1, max_samples // len(target_types)) if target_types else max_samples
  216. for scene_type in target_types:
  217. type_samples = self.get_samples_by_type(scene_type, limit=per_type, min_score=0.8)
  218. samples.extend(type_samples)
  219. return samples[:max_samples]
  220. def _infer_scene_types(self, outline: str) -> List[str]:
  221. """从大纲推断需要的场景类型"""
  222. types = []
  223. if any(kw in outline for kw in ["战斗", "对决", "比试", "交手"]):
  224. types.append(SceneType.BATTLE.value)
  225. if any(kw in outline for kw in ["对话", "谈话", "商议", "讨论"]):
  226. types.append(SceneType.DIALOGUE.value)
  227. if any(kw in outline for kw in ["情感", "感情", "心理"]):
  228. types.append(SceneType.EMOTION.value)
  229. if not types:
  230. types = [SceneType.DESCRIPTION.value]
  231. return types
  232. # ==================== 统计 ====================
  233. def get_stats(self) -> Dict[str, Any]:
  234. """获取样本统计"""
  235. with self._get_conn() as conn:
  236. cursor = conn.cursor()
  237. cursor.execute("SELECT COUNT(*) FROM samples")
  238. total = cursor.fetchone()[0]
  239. cursor.execute("""
  240. SELECT scene_type, COUNT(*) as count
  241. FROM samples
  242. GROUP BY scene_type
  243. """)
  244. by_type = {row[0]: row[1] for row in cursor.fetchall()}
  245. cursor.execute("SELECT AVG(score) FROM samples")
  246. avg_score = cursor.fetchone()[0] or 0
  247. return {
  248. "total": total,
  249. "by_type": by_type,
  250. "avg_score": round(avg_score, 3)
  251. }
# ==================== CLI interface ====================
  253. def main():
  254. import argparse
  255. from .cli_output import print_success, print_error
  256. from .index_manager import IndexManager
  257. parser = argparse.ArgumentParser(description="Style Sampler CLI")
  258. parser.add_argument("--project-root", type=str, help="项目根目录")
  259. subparsers = parser.add_subparsers(dest="command")
  260. # 获取统计
  261. subparsers.add_parser("stats")
  262. # 列出样本
  263. list_parser = subparsers.add_parser("list")
  264. list_parser.add_argument("--type", help="按类型过滤")
  265. list_parser.add_argument("--limit", type=int, default=10)
  266. # 提取样本
  267. extract_parser = subparsers.add_parser("extract")
  268. extract_parser.add_argument("--chapter", type=int, required=True)
  269. extract_parser.add_argument("--score", type=float, required=True)
  270. extract_parser.add_argument("--scenes", required=True, help="JSON 格式的场景列表")
  271. # 选择样本
  272. select_parser = subparsers.add_parser("select")
  273. select_parser.add_argument("--outline", required=True, help="章节大纲")
  274. select_parser.add_argument("--max", type=int, default=3)
  275. args = parser.parse_args()
  276. command_started_at = time.perf_counter()
  277. # 初始化
  278. config = None
  279. if args.project_root:
  280. from .config import DataModulesConfig
  281. config = DataModulesConfig.from_project_root(args.project_root)
  282. sampler = StyleSampler(config)
  283. logger = IndexManager(config)
  284. tool_name = f"style_sampler:{args.command or 'unknown'}"
  285. def _append_timing(success: bool, *, error_code: str | None = None, error_message: str | None = None, chapter: int | None = None):
  286. elapsed_ms = int((time.perf_counter() - command_started_at) * 1000)
  287. safe_append_perf_timing(
  288. sampler.config.project_root,
  289. tool_name=tool_name,
  290. success=success,
  291. elapsed_ms=elapsed_ms,
  292. chapter=chapter,
  293. error_code=error_code,
  294. error_message=error_message,
  295. )
  296. def emit_success(data=None, message: str = "ok", chapter: int | None = None):
  297. print_success(data, message=message)
  298. safe_log_tool_call(logger, tool_name=tool_name, success=True)
  299. _append_timing(True, chapter=chapter)
  300. def emit_error(code: str, message: str, suggestion: str | None = None, chapter: int | None = None):
  301. print_error(code, message, suggestion=suggestion)
  302. safe_log_tool_call(
  303. logger,
  304. tool_name=tool_name,
  305. success=False,
  306. error_code=code,
  307. error_message=message,
  308. )
  309. _append_timing(False, error_code=code, error_message=message, chapter=chapter)
  310. if args.command == "stats":
  311. stats = sampler.get_stats()
  312. emit_success(stats, message="stats")
  313. elif args.command == "list":
  314. if args.type:
  315. samples = sampler.get_samples_by_type(args.type, args.limit)
  316. else:
  317. samples = sampler.get_best_samples(args.limit)
  318. emit_success([s.__dict__ for s in samples], message="samples")
  319. elif args.command == "extract":
  320. scenes = json.loads(args.scenes)
  321. candidates = sampler.extract_candidates(
  322. chapter=args.chapter,
  323. content="",
  324. review_score=args.score,
  325. scenes=scenes,
  326. )
  327. added = []
  328. skipped = []
  329. for c in candidates:
  330. if sampler.add_sample(c):
  331. added.append(c.id)
  332. else:
  333. skipped.append(c.id)
  334. emit_success({"added": added, "skipped": skipped}, message="extracted", chapter=args.chapter)
  335. elif args.command == "select":
  336. samples = sampler.select_samples_for_chapter(args.outline, max_samples=args.max)
  337. emit_success([s.__dict__ for s in samples], message="selected")
  338. else:
  339. emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")
  340. if __name__ == "__main__":
  341. main()