style_sampler.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Style Sampler - 风格样本管理模块
  5. 管理高质量章节片段作为风格参考:
  6. - 风格样本存储
  7. - 按场景类型分类
  8. - 样本选择策略
  9. """
  10. import json
  11. import sqlite3
  12. from pathlib import Path
  13. from typing import Dict, List, Optional, Any
  14. from dataclasses import dataclass, asdict
  15. from datetime import datetime
  16. from enum import Enum
  17. from contextlib import contextmanager
  18. from .config import get_config
  19. class SceneType(Enum):
  20. """场景类型"""
  21. BATTLE = "战斗"
  22. DIALOGUE = "对话"
  23. DESCRIPTION = "描写"
  24. TRANSITION = "过渡"
  25. EMOTION = "情感"
  26. TENSION = "紧张"
  27. COMEDY = "轻松"
  28. @dataclass
  29. class StyleSample:
  30. """风格样本"""
  31. id: str
  32. chapter: int
  33. scene_type: str
  34. content: str
  35. score: float
  36. tags: List[str]
  37. created_at: str = ""
  38. class StyleSampler:
  39. """风格样本管理器"""
  40. def __init__(self, config=None):
  41. self.config = config or get_config()
  42. self._init_db()
  43. def _init_db(self):
  44. """初始化数据库"""
  45. self.config.ensure_dirs()
  46. with self._get_conn() as conn:
  47. cursor = conn.cursor()
  48. cursor.execute("""
  49. CREATE TABLE IF NOT EXISTS samples (
  50. id TEXT PRIMARY KEY,
  51. chapter INTEGER,
  52. scene_type TEXT,
  53. content TEXT,
  54. score REAL,
  55. tags TEXT,
  56. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  57. )
  58. """)
  59. cursor.execute("CREATE INDEX IF NOT EXISTS idx_samples_type ON samples(scene_type)")
  60. cursor.execute("CREATE INDEX IF NOT EXISTS idx_samples_score ON samples(score DESC)")
  61. conn.commit()
  62. @contextmanager
  63. def _get_conn(self):
  64. """获取数据库连接(确保关闭,避免 Windows 下文件句柄泄漏导致无法清理临时目录)"""
  65. db_path = self.config.webnovel_dir / "style_samples.db"
  66. conn = sqlite3.connect(str(db_path))
  67. try:
  68. yield conn
  69. finally:
  70. conn.close()
  71. # ==================== 样本管理 ====================
  72. def add_sample(self, sample: StyleSample) -> bool:
  73. """添加风格样本"""
  74. with self._get_conn() as conn:
  75. cursor = conn.cursor()
  76. try:
  77. cursor.execute("""
  78. INSERT INTO samples
  79. (id, chapter, scene_type, content, score, tags, created_at)
  80. VALUES (?, ?, ?, ?, ?, ?, ?)
  81. """, (
  82. sample.id,
  83. sample.chapter,
  84. sample.scene_type,
  85. sample.content,
  86. sample.score,
  87. json.dumps(sample.tags, ensure_ascii=False),
  88. sample.created_at or datetime.now().isoformat()
  89. ))
  90. conn.commit()
  91. return True
  92. except sqlite3.IntegrityError:
  93. return False
  94. def get_samples_by_type(
  95. self,
  96. scene_type: str,
  97. limit: int = 5,
  98. min_score: float = 0.0
  99. ) -> List[StyleSample]:
  100. """按场景类型获取样本"""
  101. with self._get_conn() as conn:
  102. cursor = conn.cursor()
  103. cursor.execute("""
  104. SELECT id, chapter, scene_type, content, score, tags, created_at
  105. FROM samples
  106. WHERE scene_type = ? AND score >= ?
  107. ORDER BY score DESC
  108. LIMIT ?
  109. """, (scene_type, min_score, limit))
  110. return [self._row_to_sample(row) for row in cursor.fetchall()]
  111. def get_best_samples(self, limit: int = 10) -> List[StyleSample]:
  112. """获取最高分样本"""
  113. with self._get_conn() as conn:
  114. cursor = conn.cursor()
  115. cursor.execute("""
  116. SELECT id, chapter, scene_type, content, score, tags, created_at
  117. FROM samples
  118. ORDER BY score DESC
  119. LIMIT ?
  120. """, (limit,))
  121. return [self._row_to_sample(row) for row in cursor.fetchall()]
  122. def _row_to_sample(self, row) -> StyleSample:
  123. """将数据库行转换为样本对象"""
  124. return StyleSample(
  125. id=row[0],
  126. chapter=row[1],
  127. scene_type=row[2],
  128. content=row[3],
  129. score=row[4],
  130. tags=json.loads(row[5]) if row[5] else [],
  131. created_at=row[6]
  132. )
  133. # ==================== 样本提取 ====================
  134. def extract_candidates(
  135. self,
  136. chapter: int,
  137. content: str,
  138. review_score: float,
  139. scenes: List[Dict]
  140. ) -> List[StyleSample]:
  141. """
  142. 从章节中提取风格样本候选
  143. 只有高分章节 (review_score >= 80) 才提取样本
  144. """
  145. if review_score < 80:
  146. return []
  147. candidates = []
  148. for scene in scenes:
  149. scene_type = self._classify_scene_type(scene)
  150. scene_content = scene.get("content", "")
  151. # 跳过过短的场景
  152. if len(scene_content) < 200:
  153. continue
  154. # 创建样本
  155. sample = StyleSample(
  156. id=f"ch{chapter}_s{scene.get('index', 0)}",
  157. chapter=chapter,
  158. scene_type=scene_type,
  159. content=scene_content[:2000], # 限制长度
  160. score=review_score / 100.0,
  161. tags=self._extract_tags(scene_content)
  162. )
  163. candidates.append(sample)
  164. return candidates
  165. def _classify_scene_type(self, scene: Dict) -> str:
  166. """分类场景类型"""
  167. summary = scene.get("summary", "").lower()
  168. content = scene.get("content", "").lower()
  169. # 简单关键词分类
  170. battle_keywords = ["战斗", "攻击", "出手", "拳", "剑", "杀", "打", "斗"]
  171. dialogue_keywords = ["说道", "问道", "笑道", "冷声", "对话"]
  172. emotion_keywords = ["心中", "感觉", "情", "泪", "痛", "喜"]
  173. tension_keywords = ["危险", "紧张", "恐惧", "压力"]
  174. text = summary + content
  175. if any(kw in text for kw in battle_keywords):
  176. return SceneType.BATTLE.value
  177. elif any(kw in text for kw in tension_keywords):
  178. return SceneType.TENSION.value
  179. elif any(kw in text for kw in dialogue_keywords):
  180. return SceneType.DIALOGUE.value
  181. elif any(kw in text for kw in emotion_keywords):
  182. return SceneType.EMOTION.value
  183. else:
  184. return SceneType.DESCRIPTION.value
  185. def _extract_tags(self, content: str) -> List[str]:
  186. """提取内容标签"""
  187. tags = []
  188. # 简单标签提取
  189. if "战斗" in content or "攻击" in content:
  190. tags.append("战斗")
  191. if "修炼" in content or "突破" in content:
  192. tags.append("修炼")
  193. if "对话" in content or "说道" in content:
  194. tags.append("对话")
  195. if "描写" in content or "景色" in content:
  196. tags.append("描写")
  197. return tags[:5]
  198. # ==================== 样本选择 ====================
  199. def select_samples_for_chapter(
  200. self,
  201. chapter_outline: str,
  202. target_types: List[str] = None,
  203. max_samples: int = 3
  204. ) -> List[StyleSample]:
  205. """
  206. 为章节写作选择合适的风格样本
  207. 基于大纲分析需要什么类型的样本
  208. """
  209. if target_types is None:
  210. # 根据大纲推断需要的场景类型
  211. target_types = self._infer_scene_types(chapter_outline)
  212. samples = []
  213. per_type = max(1, max_samples // len(target_types)) if target_types else max_samples
  214. for scene_type in target_types:
  215. type_samples = self.get_samples_by_type(scene_type, limit=per_type, min_score=0.8)
  216. samples.extend(type_samples)
  217. return samples[:max_samples]
  218. def _infer_scene_types(self, outline: str) -> List[str]:
  219. """从大纲推断需要的场景类型"""
  220. types = []
  221. if any(kw in outline for kw in ["战斗", "对决", "比试", "交手"]):
  222. types.append(SceneType.BATTLE.value)
  223. if any(kw in outline for kw in ["对话", "谈话", "商议", "讨论"]):
  224. types.append(SceneType.DIALOGUE.value)
  225. if any(kw in outline for kw in ["情感", "感情", "心理"]):
  226. types.append(SceneType.EMOTION.value)
  227. if not types:
  228. types = [SceneType.DESCRIPTION.value]
  229. return types
  230. # ==================== 统计 ====================
  231. def get_stats(self) -> Dict[str, Any]:
  232. """获取样本统计"""
  233. with self._get_conn() as conn:
  234. cursor = conn.cursor()
  235. cursor.execute("SELECT COUNT(*) FROM samples")
  236. total = cursor.fetchone()[0]
  237. cursor.execute("""
  238. SELECT scene_type, COUNT(*) as count
  239. FROM samples
  240. GROUP BY scene_type
  241. """)
  242. by_type = {row[0]: row[1] for row in cursor.fetchall()}
  243. cursor.execute("SELECT AVG(score) FROM samples")
  244. avg_score = cursor.fetchone()[0] or 0
  245. return {
  246. "total": total,
  247. "by_type": by_type,
  248. "avg_score": round(avg_score, 3)
  249. }
  250. # ==================== CLI 接口 ====================
  251. def main():
  252. import argparse
  253. from .cli_output import print_success, print_error
  254. from .index_manager import IndexManager
  255. parser = argparse.ArgumentParser(description="Style Sampler CLI")
  256. parser.add_argument("--project-root", type=str, help="项目根目录")
  257. subparsers = parser.add_subparsers(dest="command")
  258. # 获取统计
  259. subparsers.add_parser("stats")
  260. # 列出样本
  261. list_parser = subparsers.add_parser("list")
  262. list_parser.add_argument("--type", help="按类型过滤")
  263. list_parser.add_argument("--limit", type=int, default=10)
  264. # 提取样本
  265. extract_parser = subparsers.add_parser("extract")
  266. extract_parser.add_argument("--chapter", type=int, required=True)
  267. extract_parser.add_argument("--score", type=float, required=True)
  268. extract_parser.add_argument("--scenes", required=True, help="JSON 格式的场景列表")
  269. # 选择样本
  270. select_parser = subparsers.add_parser("select")
  271. select_parser.add_argument("--outline", required=True, help="章节大纲")
  272. select_parser.add_argument("--max", type=int, default=3)
  273. args = parser.parse_args()
  274. # 初始化
  275. config = None
  276. if args.project_root:
  277. from .config import DataModulesConfig
  278. config = DataModulesConfig.from_project_root(args.project_root)
  279. sampler = StyleSampler(config)
  280. logger = IndexManager(config)
  281. tool_name = f"style_sampler:{args.command or 'unknown'}"
  282. def emit_success(data=None, message: str = "ok"):
  283. print_success(data, message=message)
  284. try:
  285. logger.log_tool_call(tool_name, True)
  286. except Exception:
  287. pass
  288. def emit_error(code: str, message: str, suggestion: str | None = None):
  289. print_error(code, message, suggestion=suggestion)
  290. try:
  291. logger.log_tool_call(tool_name, False, error_code=code, error_message=message)
  292. except Exception:
  293. pass
  294. if args.command == "stats":
  295. stats = sampler.get_stats()
  296. emit_success(stats, message="stats")
  297. elif args.command == "list":
  298. if args.type:
  299. samples = sampler.get_samples_by_type(args.type, args.limit)
  300. else:
  301. samples = sampler.get_best_samples(args.limit)
  302. emit_success([s.__dict__ for s in samples], message="samples")
  303. elif args.command == "extract":
  304. scenes = json.loads(args.scenes)
  305. candidates = sampler.extract_candidates(
  306. chapter=args.chapter,
  307. content="",
  308. review_score=args.score,
  309. scenes=scenes,
  310. )
  311. added = []
  312. skipped = []
  313. for c in candidates:
  314. if sampler.add_sample(c):
  315. added.append(c.id)
  316. else:
  317. skipped.append(c.id)
  318. emit_success({"added": added, "skipped": skipped}, message="extracted")
  319. elif args.command == "select":
  320. samples = sampler.select_samples_for_chapter(args.outline, max_samples=args.max)
  321. emit_success([s.__dict__ for s in samples], message="selected")
  322. else:
  323. emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")
  324. if __name__ == "__main__":
  325. main()