style_sampler.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Style Sampler - 风格样本管理模块
  5. 管理高质量章节片段作为风格参考:
  6. - 风格样本存储
  7. - 按场景类型分类
  8. - 样本选择策略
  9. """
  10. import json
  11. import sqlite3
  12. from pathlib import Path
  13. from typing import Dict, List, Optional, Any
  14. from dataclasses import dataclass, asdict
  15. from datetime import datetime
  16. from enum import Enum
  17. from contextlib import contextmanager
  18. from .config import get_config
  19. from .observability import safe_log_tool_call
  20. class SceneType(Enum):
  21. """场景类型"""
  22. BATTLE = "战斗"
  23. DIALOGUE = "对话"
  24. DESCRIPTION = "描写"
  25. TRANSITION = "过渡"
  26. EMOTION = "情感"
  27. TENSION = "紧张"
  28. COMEDY = "轻松"
  29. @dataclass
  30. class StyleSample:
  31. """风格样本"""
  32. id: str
  33. chapter: int
  34. scene_type: str
  35. content: str
  36. score: float
  37. tags: List[str]
  38. created_at: str = ""
  39. class StyleSampler:
  40. """风格样本管理器"""
  41. def __init__(self, config=None):
  42. self.config = config or get_config()
  43. self._init_db()
  44. def _init_db(self):
  45. """初始化数据库"""
  46. self.config.ensure_dirs()
  47. with self._get_conn() as conn:
  48. cursor = conn.cursor()
  49. cursor.execute("""
  50. CREATE TABLE IF NOT EXISTS samples (
  51. id TEXT PRIMARY KEY,
  52. chapter INTEGER,
  53. scene_type TEXT,
  54. content TEXT,
  55. score REAL,
  56. tags TEXT,
  57. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  58. )
  59. """)
  60. cursor.execute("CREATE INDEX IF NOT EXISTS idx_samples_type ON samples(scene_type)")
  61. cursor.execute("CREATE INDEX IF NOT EXISTS idx_samples_score ON samples(score DESC)")
  62. conn.commit()
  63. @contextmanager
  64. def _get_conn(self):
  65. """获取数据库连接(确保关闭,避免 Windows 下文件句柄泄漏导致无法清理临时目录)"""
  66. db_path = self.config.webnovel_dir / "style_samples.db"
  67. conn = sqlite3.connect(str(db_path))
  68. try:
  69. yield conn
  70. finally:
  71. conn.close()
  72. # ==================== 样本管理 ====================
  73. def add_sample(self, sample: StyleSample) -> bool:
  74. """添加风格样本"""
  75. with self._get_conn() as conn:
  76. cursor = conn.cursor()
  77. try:
  78. cursor.execute("""
  79. INSERT INTO samples
  80. (id, chapter, scene_type, content, score, tags, created_at)
  81. VALUES (?, ?, ?, ?, ?, ?, ?)
  82. """, (
  83. sample.id,
  84. sample.chapter,
  85. sample.scene_type,
  86. sample.content,
  87. sample.score,
  88. json.dumps(sample.tags, ensure_ascii=False),
  89. sample.created_at or datetime.now().isoformat()
  90. ))
  91. conn.commit()
  92. return True
  93. except sqlite3.IntegrityError:
  94. return False
  95. def get_samples_by_type(
  96. self,
  97. scene_type: str,
  98. limit: int = 5,
  99. min_score: float = 0.0
  100. ) -> List[StyleSample]:
  101. """按场景类型获取样本"""
  102. with self._get_conn() as conn:
  103. cursor = conn.cursor()
  104. cursor.execute("""
  105. SELECT id, chapter, scene_type, content, score, tags, created_at
  106. FROM samples
  107. WHERE scene_type = ? AND score >= ?
  108. ORDER BY score DESC
  109. LIMIT ?
  110. """, (scene_type, min_score, limit))
  111. return [self._row_to_sample(row) for row in cursor.fetchall()]
  112. def get_best_samples(self, limit: int = 10) -> List[StyleSample]:
  113. """获取最高分样本"""
  114. with self._get_conn() as conn:
  115. cursor = conn.cursor()
  116. cursor.execute("""
  117. SELECT id, chapter, scene_type, content, score, tags, created_at
  118. FROM samples
  119. ORDER BY score DESC
  120. LIMIT ?
  121. """, (limit,))
  122. return [self._row_to_sample(row) for row in cursor.fetchall()]
  123. def _row_to_sample(self, row) -> StyleSample:
  124. """将数据库行转换为样本对象"""
  125. return StyleSample(
  126. id=row[0],
  127. chapter=row[1],
  128. scene_type=row[2],
  129. content=row[3],
  130. score=row[4],
  131. tags=json.loads(row[5]) if row[5] else [],
  132. created_at=row[6]
  133. )
  134. # ==================== 样本提取 ====================
  135. def extract_candidates(
  136. self,
  137. chapter: int,
  138. content: str,
  139. review_score: float,
  140. scenes: List[Dict]
  141. ) -> List[StyleSample]:
  142. """
  143. 从章节中提取风格样本候选
  144. 只有高分章节 (review_score >= 80) 才提取样本
  145. """
  146. if review_score < 80:
  147. return []
  148. candidates = []
  149. for scene in scenes:
  150. scene_type = self._classify_scene_type(scene)
  151. scene_content = scene.get("content", "")
  152. # 跳过过短的场景
  153. if len(scene_content) < 200:
  154. continue
  155. # 创建样本
  156. sample = StyleSample(
  157. id=f"ch{chapter}_s{scene.get('index', 0)}",
  158. chapter=chapter,
  159. scene_type=scene_type,
  160. content=scene_content[:2000], # 限制长度
  161. score=review_score / 100.0,
  162. tags=self._extract_tags(scene_content)
  163. )
  164. candidates.append(sample)
  165. return candidates
  166. def _classify_scene_type(self, scene: Dict) -> str:
  167. """分类场景类型"""
  168. summary = scene.get("summary", "").lower()
  169. content = scene.get("content", "").lower()
  170. # 简单关键词分类
  171. battle_keywords = ["战斗", "攻击", "出手", "拳", "剑", "杀", "打", "斗"]
  172. dialogue_keywords = ["说道", "问道", "笑道", "冷声", "对话"]
  173. emotion_keywords = ["心中", "感觉", "情", "泪", "痛", "喜"]
  174. tension_keywords = ["危险", "紧张", "恐惧", "压力"]
  175. text = summary + content
  176. if any(kw in text for kw in battle_keywords):
  177. return SceneType.BATTLE.value
  178. elif any(kw in text for kw in tension_keywords):
  179. return SceneType.TENSION.value
  180. elif any(kw in text for kw in dialogue_keywords):
  181. return SceneType.DIALOGUE.value
  182. elif any(kw in text for kw in emotion_keywords):
  183. return SceneType.EMOTION.value
  184. else:
  185. return SceneType.DESCRIPTION.value
  186. def _extract_tags(self, content: str) -> List[str]:
  187. """提取内容标签"""
  188. tags = []
  189. # 简单标签提取
  190. if "战斗" in content or "攻击" in content:
  191. tags.append("战斗")
  192. if "修炼" in content or "突破" in content:
  193. tags.append("修炼")
  194. if "对话" in content or "说道" in content:
  195. tags.append("对话")
  196. if "描写" in content or "景色" in content:
  197. tags.append("描写")
  198. return tags[:5]
  199. # ==================== 样本选择 ====================
  200. def select_samples_for_chapter(
  201. self,
  202. chapter_outline: str,
  203. target_types: List[str] = None,
  204. max_samples: int = 3
  205. ) -> List[StyleSample]:
  206. """
  207. 为章节写作选择合适的风格样本
  208. 基于大纲分析需要什么类型的样本
  209. """
  210. if target_types is None:
  211. # 根据大纲推断需要的场景类型
  212. target_types = self._infer_scene_types(chapter_outline)
  213. samples = []
  214. per_type = max(1, max_samples // len(target_types)) if target_types else max_samples
  215. for scene_type in target_types:
  216. type_samples = self.get_samples_by_type(scene_type, limit=per_type, min_score=0.8)
  217. samples.extend(type_samples)
  218. return samples[:max_samples]
  219. def _infer_scene_types(self, outline: str) -> List[str]:
  220. """从大纲推断需要的场景类型"""
  221. types = []
  222. if any(kw in outline for kw in ["战斗", "对决", "比试", "交手"]):
  223. types.append(SceneType.BATTLE.value)
  224. if any(kw in outline for kw in ["对话", "谈话", "商议", "讨论"]):
  225. types.append(SceneType.DIALOGUE.value)
  226. if any(kw in outline for kw in ["情感", "感情", "心理"]):
  227. types.append(SceneType.EMOTION.value)
  228. if not types:
  229. types = [SceneType.DESCRIPTION.value]
  230. return types
  231. # ==================== 统计 ====================
  232. def get_stats(self) -> Dict[str, Any]:
  233. """获取样本统计"""
  234. with self._get_conn() as conn:
  235. cursor = conn.cursor()
  236. cursor.execute("SELECT COUNT(*) FROM samples")
  237. total = cursor.fetchone()[0]
  238. cursor.execute("""
  239. SELECT scene_type, COUNT(*) as count
  240. FROM samples
  241. GROUP BY scene_type
  242. """)
  243. by_type = {row[0]: row[1] for row in cursor.fetchall()}
  244. cursor.execute("SELECT AVG(score) FROM samples")
  245. avg_score = cursor.fetchone()[0] or 0
  246. return {
  247. "total": total,
  248. "by_type": by_type,
  249. "avg_score": round(avg_score, 3)
  250. }
  251. # ==================== CLI 接口 ====================
  252. def main():
  253. import argparse
  254. from .cli_output import print_success, print_error
  255. from .index_manager import IndexManager
  256. parser = argparse.ArgumentParser(description="Style Sampler CLI")
  257. parser.add_argument("--project-root", type=str, help="项目根目录")
  258. subparsers = parser.add_subparsers(dest="command")
  259. # 获取统计
  260. subparsers.add_parser("stats")
  261. # 列出样本
  262. list_parser = subparsers.add_parser("list")
  263. list_parser.add_argument("--type", help="按类型过滤")
  264. list_parser.add_argument("--limit", type=int, default=10)
  265. # 提取样本
  266. extract_parser = subparsers.add_parser("extract")
  267. extract_parser.add_argument("--chapter", type=int, required=True)
  268. extract_parser.add_argument("--score", type=float, required=True)
  269. extract_parser.add_argument("--scenes", required=True, help="JSON 格式的场景列表")
  270. # 选择样本
  271. select_parser = subparsers.add_parser("select")
  272. select_parser.add_argument("--outline", required=True, help="章节大纲")
  273. select_parser.add_argument("--max", type=int, default=3)
  274. args = parser.parse_args()
  275. # 初始化
  276. config = None
  277. if args.project_root:
  278. from .config import DataModulesConfig
  279. config = DataModulesConfig.from_project_root(args.project_root)
  280. sampler = StyleSampler(config)
  281. logger = IndexManager(config)
  282. tool_name = f"style_sampler:{args.command or 'unknown'}"
  283. def emit_success(data=None, message: str = "ok"):
  284. print_success(data, message=message)
  285. safe_log_tool_call(logger, tool_name=tool_name, success=True)
  286. def emit_error(code: str, message: str, suggestion: str | None = None):
  287. print_error(code, message, suggestion=suggestion)
  288. safe_log_tool_call(
  289. logger,
  290. tool_name=tool_name,
  291. success=False,
  292. error_code=code,
  293. error_message=message,
  294. )
  295. if args.command == "stats":
  296. stats = sampler.get_stats()
  297. emit_success(stats, message="stats")
  298. elif args.command == "list":
  299. if args.type:
  300. samples = sampler.get_samples_by_type(args.type, args.limit)
  301. else:
  302. samples = sampler.get_best_samples(args.limit)
  303. emit_success([s.__dict__ for s in samples], message="samples")
  304. elif args.command == "extract":
  305. scenes = json.loads(args.scenes)
  306. candidates = sampler.extract_candidates(
  307. chapter=args.chapter,
  308. content="",
  309. review_score=args.score,
  310. scenes=scenes,
  311. )
  312. added = []
  313. skipped = []
  314. for c in candidates:
  315. if sampler.add_sample(c):
  316. added.append(c.id)
  317. else:
  318. skipped.append(c.id)
  319. emit_success({"added": added, "skipped": skipped}, message="extracted")
  320. elif args.command == "select":
  321. samples = sampler.select_samples_for_chapter(args.outline, max_samples=args.max)
  322. emit_success([s.__dict__ for s in samples], message="selected")
  323. else:
  324. emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")
  325. if __name__ == "__main__":
  326. main()