snapshot_manager.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Context snapshot manager.
  5. """
  6. from __future__ import annotations
  7. import json
  8. from dataclasses import dataclass
  9. from datetime import datetime, timezone
  10. from pathlib import Path
  11. from typing import Any, Dict, Optional
  12. from .config import get_config
  13. SNAPSHOT_VERSION = "1.1"
  14. class SnapshotVersionMismatch(RuntimeError):
  15. def __init__(self, expected: str, actual: str) -> None:
  16. super().__init__(f"snapshot version mismatch: expected {expected}, got {actual}")
  17. self.expected = expected
  18. self.actual = actual
  19. @dataclass
  20. class SnapshotMeta:
  21. chapter: int
  22. version: str
  23. saved_at: str
  24. class SnapshotManager:
  25. def __init__(self, config=None, version: str = SNAPSHOT_VERSION):
  26. self.config = config or get_config()
  27. self.version = version
  28. self.snapshot_dir = self.config.webnovel_dir / "context_snapshots"
  29. self.snapshot_dir.mkdir(parents=True, exist_ok=True)
  30. def _snapshot_path(self, chapter: int) -> Path:
  31. return self.snapshot_dir / f"ch{chapter:04d}.json"
  32. def save_snapshot(self, chapter: int, payload: Dict[str, Any], meta: Optional[Dict[str, Any]] = None) -> Path:
  33. data: Dict[str, Any] = {
  34. "version": self.version,
  35. "chapter": chapter,
  36. "saved_at": datetime.now(timezone.utc).isoformat(),
  37. "payload": payload,
  38. }
  39. if meta:
  40. data["meta"] = meta
  41. path = self._snapshot_path(chapter)
  42. path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
  43. return path
  44. def load_snapshot(self, chapter: int) -> Optional[Dict[str, Any]]:
  45. path = self._snapshot_path(chapter)
  46. if not path.exists():
  47. return None
  48. data = json.loads(path.read_text(encoding="utf-8"))
  49. version = str(data.get("version", ""))
  50. if version != self.version:
  51. raise SnapshotVersionMismatch(self.version, version)
  52. return data
  53. def delete_snapshot(self, chapter: int) -> bool:
  54. path = self._snapshot_path(chapter)
  55. if path.exists():
  56. path.unlink()
  57. return True
  58. return False
  59. def list_snapshots(self) -> list[str]:
  60. return sorted(p.name for p in self.snapshot_dir.glob("ch*.json"))