| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Style Sampler - 风格样本管理模块
- 管理高质量章节片段作为风格参考:
- - 风格样本存储
- - 按场景类型分类
- - 样本选择策略
- """
- import json
- import sqlite3
- import time
- from pathlib import Path
- from typing import Dict, List, Optional, Any
- from dataclasses import dataclass, asdict
- from datetime import datetime
- from enum import Enum
- from contextlib import contextmanager
- from .config import get_config
- from .observability import safe_append_perf_timing, safe_log_tool_call
- class SceneType(Enum):
- """场景类型"""
- BATTLE = "战斗"
- DIALOGUE = "对话"
- DESCRIPTION = "描写"
- TRANSITION = "过渡"
- EMOTION = "情感"
- TENSION = "紧张"
- COMEDY = "轻松"
- @dataclass
- class StyleSample:
- """风格样本"""
- id: str
- chapter: int
- scene_type: str
- content: str
- score: float
- tags: List[str]
- created_at: str = ""
- class StyleSampler:
- """风格样本管理器"""
- def __init__(self, config=None):
- self.config = config or get_config()
- self._init_db()
- def _init_db(self):
- """初始化数据库"""
- self.config.ensure_dirs()
- with self._get_conn() as conn:
- cursor = conn.cursor()
- cursor.execute("""
- CREATE TABLE IF NOT EXISTS samples (
- id TEXT PRIMARY KEY,
- chapter INTEGER,
- scene_type TEXT,
- content TEXT,
- score REAL,
- tags TEXT,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- )
- """)
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_samples_type ON samples(scene_type)")
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_samples_score ON samples(score DESC)")
- conn.commit()
- @contextmanager
- def _get_conn(self):
- """获取数据库连接(确保关闭,避免 Windows 下文件句柄泄漏导致无法清理临时目录)"""
- db_path = self.config.webnovel_dir / "style_samples.db"
- conn = sqlite3.connect(str(db_path))
- try:
- yield conn
- finally:
- conn.close()
- # ==================== 样本管理 ====================
- def add_sample(self, sample: StyleSample) -> bool:
- """添加风格样本"""
- with self._get_conn() as conn:
- cursor = conn.cursor()
- try:
- cursor.execute("""
- INSERT INTO samples
- (id, chapter, scene_type, content, score, tags, created_at)
- VALUES (?, ?, ?, ?, ?, ?, ?)
- """, (
- sample.id,
- sample.chapter,
- sample.scene_type,
- sample.content,
- sample.score,
- json.dumps(sample.tags, ensure_ascii=False),
- sample.created_at or datetime.now().isoformat()
- ))
- conn.commit()
- return True
- except sqlite3.IntegrityError:
- return False
- def get_samples_by_type(
- self,
- scene_type: str,
- limit: int = 5,
- min_score: float = 0.0
- ) -> List[StyleSample]:
- """按场景类型获取样本"""
- with self._get_conn() as conn:
- cursor = conn.cursor()
- cursor.execute("""
- SELECT id, chapter, scene_type, content, score, tags, created_at
- FROM samples
- WHERE scene_type = ? AND score >= ?
- ORDER BY score DESC
- LIMIT ?
- """, (scene_type, min_score, limit))
- return [self._row_to_sample(row) for row in cursor.fetchall()]
- def get_best_samples(self, limit: int = 10) -> List[StyleSample]:
- """获取最高分样本"""
- with self._get_conn() as conn:
- cursor = conn.cursor()
- cursor.execute("""
- SELECT id, chapter, scene_type, content, score, tags, created_at
- FROM samples
- ORDER BY score DESC
- LIMIT ?
- """, (limit,))
- return [self._row_to_sample(row) for row in cursor.fetchall()]
- def _safe_tags(self, raw) -> List[str]:
- if not raw:
- return []
- try:
- value = json.loads(raw)
- except (TypeError, json.JSONDecodeError):
- return []
- return value if isinstance(value, list) else []
- def _row_to_sample(self, row) -> StyleSample:
- """将数据库行转换为样本对象"""
- return StyleSample(
- id=row[0],
- chapter=row[1],
- scene_type=row[2],
- content=row[3],
- score=row[4],
- tags=self._safe_tags(row[5]),
- created_at=row[6]
- )
- # ==================== 样本提取 ====================
- def extract_candidates(
- self,
- chapter: int,
- content: str,
- review_score: float,
- scenes: List[Dict]
- ) -> List[StyleSample]:
- """
- 从章节中提取风格样本候选
- 只有高分章节 (review_score >= 80) 才提取样本
- """
- if review_score < 80:
- return []
- candidates = []
- for scene in scenes:
- scene_type = self._classify_scene_type(scene)
- scene_content = scene.get("content", "")
- # 跳过过短的场景
- if len(scene_content) < 200:
- continue
- # 创建样本
- sample = StyleSample(
- id=f"ch{chapter}_s{scene.get('index', 0)}",
- chapter=chapter,
- scene_type=scene_type,
- content=scene_content[:2000], # 限制长度
- score=review_score / 100.0,
- tags=self._extract_tags(scene_content)
- )
- candidates.append(sample)
- return candidates
- def _classify_scene_type(self, scene: Dict) -> str:
- """分类场景类型"""
- summary = scene.get("summary", "").lower()
- content = scene.get("content", "").lower()
- # 简单关键词分类
- battle_keywords = ["战斗", "攻击", "出手", "拳", "剑", "杀", "打", "斗"]
- dialogue_keywords = ["说道", "问道", "笑道", "冷声", "对话"]
- emotion_keywords = ["心中", "感觉", "情", "泪", "痛", "喜"]
- tension_keywords = ["危险", "紧张", "恐惧", "压力"]
- text = summary + content
- if any(kw in text for kw in battle_keywords):
- return SceneType.BATTLE.value
- elif any(kw in text for kw in tension_keywords):
- return SceneType.TENSION.value
- elif any(kw in text for kw in dialogue_keywords):
- return SceneType.DIALOGUE.value
- elif any(kw in text for kw in emotion_keywords):
- return SceneType.EMOTION.value
- else:
- return SceneType.DESCRIPTION.value
- def _extract_tags(self, content: str) -> List[str]:
- """提取内容标签"""
- tags = []
- # 简单标签提取
- if "战斗" in content or "攻击" in content:
- tags.append("战斗")
- if "修炼" in content or "突破" in content:
- tags.append("修炼")
- if "对话" in content or "说道" in content:
- tags.append("对话")
- if "描写" in content or "景色" in content:
- tags.append("描写")
- return tags[:5]
- # ==================== 样本选择 ====================
- def select_samples_for_chapter(
- self,
- chapter_outline: str,
- target_types: List[str] = None,
- max_samples: int = 3
- ) -> List[StyleSample]:
- """
- 为章节写作选择合适的风格样本
- 基于大纲分析需要什么类型的样本
- """
- if target_types is None:
- # 根据大纲推断需要的场景类型
- target_types = self._infer_scene_types(chapter_outline)
- samples = []
- per_type = max(1, max_samples // len(target_types)) if target_types else max_samples
- for scene_type in target_types:
- type_samples = self.get_samples_by_type(scene_type, limit=per_type, min_score=0.8)
- samples.extend(type_samples)
- return samples[:max_samples]
- def _infer_scene_types(self, outline: str) -> List[str]:
- """从大纲推断需要的场景类型"""
- types = []
- if any(kw in outline for kw in ["战斗", "对决", "比试", "交手"]):
- types.append(SceneType.BATTLE.value)
- if any(kw in outline for kw in ["对话", "谈话", "商议", "讨论"]):
- types.append(SceneType.DIALOGUE.value)
- if any(kw in outline for kw in ["情感", "感情", "心理"]):
- types.append(SceneType.EMOTION.value)
- if not types:
- types = [SceneType.DESCRIPTION.value]
- return types
- # ==================== 统计 ====================
- def get_stats(self) -> Dict[str, Any]:
- """获取样本统计"""
- with self._get_conn() as conn:
- cursor = conn.cursor()
- cursor.execute("SELECT COUNT(*) FROM samples")
- total = cursor.fetchone()[0]
- cursor.execute("""
- SELECT scene_type, COUNT(*) as count
- FROM samples
- GROUP BY scene_type
- """)
- by_type = {row[0]: row[1] for row in cursor.fetchall()}
- cursor.execute("SELECT AVG(score) FROM samples")
- avg_score = cursor.fetchone()[0] or 0
- return {
- "total": total,
- "by_type": by_type,
- "avg_score": round(avg_score, 3)
- }
- # ==================== CLI 接口 ====================
- def main():
- import argparse
- import sys
- from .cli_output import print_success, print_error
- from .cli_args import normalize_global_project_root, load_json_arg
- from .index_manager import IndexManager
- parser = argparse.ArgumentParser(description="Style Sampler CLI")
- parser.add_argument("--project-root", type=str, help="项目根目录")
- subparsers = parser.add_subparsers(dest="command")
- # 获取统计
- subparsers.add_parser("stats")
- # 列出样本
- list_parser = subparsers.add_parser("list")
- list_parser.add_argument("--type", help="按类型过滤")
- list_parser.add_argument("--limit", type=int, default=10)
- # 提取样本
- extract_parser = subparsers.add_parser("extract")
- extract_parser.add_argument("--chapter", type=int, required=True)
- extract_parser.add_argument("--score", type=float, required=True)
- extract_parser.add_argument("--scenes", required=True, help="JSON 格式的场景列表")
- # 选择样本
- select_parser = subparsers.add_parser("select")
- select_parser.add_argument("--outline", required=True, help="章节大纲")
- select_parser.add_argument("--max", type=int, default=3)
- argv = normalize_global_project_root(sys.argv[1:])
- args = parser.parse_args(argv)
- command_started_at = time.perf_counter()
- # 初始化
- config = None
- if args.project_root:
- # 允许传入“工作区根目录”,统一解析到真正的 book project_root(必须包含 .webnovel/state.json)
- from project_locator import resolve_project_root
- from .config import DataModulesConfig
- resolved_root = resolve_project_root(args.project_root)
- config = DataModulesConfig.from_project_root(resolved_root)
- sampler = StyleSampler(config)
- logger = IndexManager(config)
- tool_name = f"style_sampler:{args.command or 'unknown'}"
- def _append_timing(success: bool, *, error_code: str | None = None, error_message: str | None = None, chapter: int | None = None):
- elapsed_ms = int((time.perf_counter() - command_started_at) * 1000)
- safe_append_perf_timing(
- sampler.config.project_root,
- tool_name=tool_name,
- success=success,
- elapsed_ms=elapsed_ms,
- chapter=chapter,
- error_code=error_code,
- error_message=error_message,
- )
- def emit_success(data=None, message: str = "ok", chapter: int | None = None):
- print_success(data, message=message)
- safe_log_tool_call(logger, tool_name=tool_name, success=True)
- _append_timing(True, chapter=chapter)
- def emit_error(code: str, message: str, suggestion: str | None = None, chapter: int | None = None):
- print_error(code, message, suggestion=suggestion)
- safe_log_tool_call(
- logger,
- tool_name=tool_name,
- success=False,
- error_code=code,
- error_message=message,
- )
- _append_timing(False, error_code=code, error_message=message, chapter=chapter)
- if args.command == "stats":
- stats = sampler.get_stats()
- emit_success(stats, message="stats")
- elif args.command == "list":
- if args.type:
- samples = sampler.get_samples_by_type(args.type, args.limit)
- else:
- samples = sampler.get_best_samples(args.limit)
- emit_success([s.__dict__ for s in samples], message="samples")
- elif args.command == "extract":
- scenes = load_json_arg(args.scenes, base_dir=sampler.config.project_root)
- candidates = sampler.extract_candidates(
- chapter=args.chapter,
- content="",
- review_score=args.score,
- scenes=scenes,
- )
- added = []
- skipped = []
- for c in candidates:
- if sampler.add_sample(c):
- added.append(c.id)
- else:
- skipped.append(c.id)
- emit_success({"added": added, "skipped": skipped}, message="extracted", chapter=args.chapter)
- elif args.command == "select":
- samples = sampler.select_samples_for_chapter(args.outline, max_samples=args.max)
- emit_success([s.__dict__ for s in samples], message="selected")
- else:
- emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")
- if __name__ == "__main__":
- main()
|