Просмотр исходного кода

feat: MemoryContractAdapter——薄适配器包装现有模块

lingfengQAQ 2 месяцев назад
Родитель
Сommit
085c223188

+ 222 - 0
webnovel-writer/scripts/data_modules/memory_contract_adapter.py

@@ -0,0 +1,222 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+MemoryContractAdapter——薄适配器,包装现有模块满足 MemoryContract Protocol。
+
+不做存储重构,仅委托给 StateManager / IndexManager / ScratchpadManager 等。
+"""
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from .config import DataModulesConfig, get_config
+from .memory_contract import (
+    CommitResult,
+    ContextPack,
+    EntitySnapshot,
+    OpenLoop,
+    Rule,
+    TimelineEvent,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryContractAdapter:
+    """满足 MemoryContract Protocol 的具体实现。"""
+
+    def __init__(self, config: DataModulesConfig | None = None):
+        self.config = config or get_config()
+
+    # ------------------------------------------------------------------
+    # 内部懒加载(避免在构造时就初始化所有重量级模块)
+    # ------------------------------------------------------------------
+
+    def _state_manager(self):
+        from .state_manager import StateManager
+        return StateManager(self.config)
+
+    def _index_manager(self):
+        from .index_manager import IndexManager
+        return IndexManager(self.config)
+
+    def _memory_writer(self):
+        from .memory.writer import MemoryWriter
+        return MemoryWriter(self.config)
+
+    def _memory_store(self):
+        from .memory.store import ScratchpadManager
+        return ScratchpadManager(self.config)
+
+    def _memory_orchestrator(self):
+        from .memory.orchestrator import MemoryOrchestrator
+        return MemoryOrchestrator(self.config)
+
+    # ------------------------------------------------------------------
+    # 契约方法
+    # ------------------------------------------------------------------
+
+    def commit_chapter(self, chapter: int, result: dict) -> CommitResult:
+        warnings: List[str] = []
+        entities_added = 0
+        entities_updated = 0
+        state_changes_recorded = 0
+        relationships_added = 0
+        memory_items_added = 0
+        summary_path = ""
+
+        # 1. StateManager: process_chapter_result
+        try:
+            sm = self._state_manager()
+            sm._load_state()
+            sm_warnings = sm.process_chapter_result(chapter, result)
+            warnings.extend(sm_warnings or [])
+            entities_added = len(result.get("entities_new", []) or [])
+            entities_updated = len(result.get("entities_appeared", []) or [])
+            state_changes_recorded = len(result.get("state_changes", []) or [])
+            relationships_added = len(result.get("relationships_new", []) or [])
+        except Exception as e:
+            logger.warning("commit_chapter: StateManager failed: %s", e)
+            warnings.append(f"StateManager error: {e}")
+
+        # 2. MemoryWriter: update_from_chapter_result
+        try:
+            mw = self._memory_writer()
+            mem_stats = mw.update_from_chapter_result(chapter, result)
+            memory_items_added = mem_stats.get("items_added", 0)
+            if mem_stats.get("warnings"):
+                warnings.extend(mem_stats["warnings"])
+        except Exception as e:
+            logger.warning("commit_chapter: MemoryWriter failed: %s", e)
+            warnings.append(f"MemoryWriter error: {e}")
+
+        # 3. 摘要路径
+        padded = f"{chapter:04d}"
+        summary_file = self.config.webnovel_dir / "summaries" / f"ch{padded}.md"
+        if summary_file.exists():
+            summary_path = str(summary_file)
+
+        return CommitResult(
+            chapter=chapter,
+            entities_added=entities_added,
+            entities_updated=entities_updated,
+            state_changes_recorded=state_changes_recorded,
+            relationships_added=relationships_added,
+            memory_items_added=memory_items_added,
+            summary_path=summary_path,
+            warnings=warnings,
+        )
+
+    def load_context(self, chapter: int, budget_tokens: int = 4000) -> ContextPack:
+        try:
+            orch = self._memory_orchestrator()
+            pack = orch.build_memory_pack(chapter)
+            return ContextPack(
+                chapter=chapter,
+                sections=pack,
+                budget_used_tokens=0,  # orchestrator 不计 token,由调用者按需裁剪
+            )
+        except Exception as e:
+            logger.warning("load_context failed: %s", e)
+            return ContextPack(chapter=chapter)
+
+    def query_entity(self, entity_id: str) -> Optional[EntitySnapshot]:
+        try:
+            sm = self._state_manager()
+            sm._load_state()
+            entity = sm.get_entity(entity_id)
+            if not entity:
+                return None
+
+            entity_type = sm.get_entity_type(entity_id) or "角色"
+            state_changes = sm.get_state_changes(entity_id)
+            recent_changes = state_changes[-5:] if state_changes else []
+
+            return EntitySnapshot(
+                id=entity_id,
+                name=entity.get("name", entity_id),
+                type=entity_type,
+                tier=entity.get("tier", "核心"),
+                aliases=entity.get("aliases", []),
+                attributes={k: v for k, v in entity.items()
+                            if k not in ("name", "tier", "aliases", "first_appearance", "last_appearance")},
+                first_appearance=entity.get("first_appearance", 0),
+                last_appearance=entity.get("last_appearance", 0),
+                recent_state_changes=recent_changes,
+            )
+        except Exception as e:
+            logger.warning("query_entity(%s) failed: %s", entity_id, e)
+            return None
+
+    def query_rules(self, domain: str = "") -> List[Rule]:
+        try:
+            store = self._memory_store()
+            items = store.query(category="world_rule", status="active")
+            rules = []
+            for item in items:
+                if domain and item.subject != domain and domain not in item.value:
+                    continue
+                rules.append(Rule(
+                    id=item.id,
+                    subject=item.subject,
+                    field=item.field,
+                    value=item.value,
+                    domain=item.subject,
+                    source_chapter=item.source_chapter,
+                ))
+            return rules
+        except Exception as e:
+            logger.warning("query_rules failed: %s", e)
+            return []
+
+    def read_summary(self, chapter: int) -> str:
+        padded = f"{chapter:04d}"
+        summary_file = self.config.webnovel_dir / "summaries" / f"ch{padded}.md"
+        try:
+            if summary_file.exists():
+                return summary_file.read_text(encoding="utf-8")
+            return ""
+        except Exception as e:
+            logger.warning("read_summary(%d) failed: %s", chapter, e)
+            return ""
+
+    def get_open_loops(self, status: str = "active") -> List[OpenLoop]:
+        try:
+            store = self._memory_store()
+            items = store.query(category="open_loop", status=status)
+            return [
+                OpenLoop(
+                    id=item.id,
+                    content=item.value,
+                    status=item.status,
+                    planted_chapter=item.source_chapter,
+                    expected_payoff=item.payload.get("expected_payoff", ""),
+                    urgency=float(item.payload.get("urgency", 0.0)),
+                )
+                for item in items
+            ]
+        except Exception as e:
+            logger.warning("get_open_loops failed: %s", e)
+            return []
+
+    def get_timeline(self, from_ch: int, to_ch: int) -> List[TimelineEvent]:
+        try:
+            store = self._memory_store()
+            items = store.query(category="timeline", status="active")
+            events = []
+            for item in items:
+                ch = item.source_chapter
+                if from_ch <= ch <= to_ch:
+                    events.append(TimelineEvent(
+                        event=item.value,
+                        chapter=ch,
+                        time_hint=item.field,
+                        event_type=item.subject,
+                    ))
+            events.sort(key=lambda e: e.chapter)
+            return events
+        except Exception as e:
+            logger.warning("get_timeline failed: %s", e)
+            return []

+ 222 - 0
webnovel-writer/scripts/data_modules/tests/test_memory_contract_adapter.py

@@ -0,0 +1,222 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""MemoryContractAdapter 集成测试。"""
+from __future__ import annotations
+
+import json
+import sys
+from pathlib import Path
+
+import pytest
+
+# 确保 scripts/ 在 sys.path 中
+_scripts_dir = str(Path(__file__).resolve().parent.parent.parent)
+if _scripts_dir not in sys.path:
+    sys.path.insert(0, _scripts_dir)
+
+from data_modules.config import DataModulesConfig
+from data_modules.memory_contract import (
+    CommitResult,
+    ContextPack,
+    EntitySnapshot,
+    MemoryContract,
+    OpenLoop,
+    Rule,
+    TimelineEvent,
+)
+from data_modules.memory_contract_adapter import MemoryContractAdapter
+
+
+def _make_project(tmp_path: Path) -> DataModulesConfig:
+    """创建最小项目结构并返回配置。"""
+    webnovel_dir = tmp_path / ".webnovel"
+    webnovel_dir.mkdir(parents=True, exist_ok=True)
+    (webnovel_dir / "state.json").write_text("{}", encoding="utf-8")
+    (webnovel_dir / "summaries").mkdir(exist_ok=True)
+    return DataModulesConfig.from_project_root(tmp_path)
+
+
+class TestAdapterSatisfiesProtocol:
+    def test_isinstance_check(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        adapter = MemoryContractAdapter(cfg)
+        assert isinstance(adapter, MemoryContract)
+
+
+class TestReadSummary:
+    def test_read_existing_summary(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        summary_dir = cfg.webnovel_dir / "summaries"
+        summary_dir.mkdir(parents=True, exist_ok=True)
+        (summary_dir / "ch0010.md").write_text("第10章摘要", encoding="utf-8")
+
+        adapter = MemoryContractAdapter(cfg)
+        text = adapter.read_summary(10)
+        assert text == "第10章摘要"
+
+    def test_read_missing_summary(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        adapter = MemoryContractAdapter(cfg)
+        assert adapter.read_summary(999) == ""
+
+
+class TestQueryEntity:
+    def test_query_nonexistent_entity(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        adapter = MemoryContractAdapter(cfg)
+        assert adapter.query_entity("nobody") is None
+
+    def test_query_existing_entity(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        # 写入包含实体的 state.json
+        state = {
+            "entities_v3": {
+                "角色": {
+                    "xiaoyan": {
+                        "name": "萧炎",
+                        "tier": "核心",
+                        "aliases": ["他"],
+                        "realm": "斗帝",
+                        "first_appearance": 1,
+                        "last_appearance": 100,
+                    }
+                }
+            },
+            "state_changes": [
+                {"entity_id": "xiaoyan", "field": "realm", "old": "斗圣", "new": "斗帝", "chapter": 100}
+            ],
+        }
+        (cfg.state_file).write_text(json.dumps(state, ensure_ascii=False), encoding="utf-8")
+
+        adapter = MemoryContractAdapter(cfg)
+        snap = adapter.query_entity("xiaoyan")
+        assert snap is not None
+        assert snap.name == "萧炎"
+        assert snap.type == "角色"
+        assert snap.tier == "核心"
+        assert "他" in snap.aliases
+        assert len(snap.recent_state_changes) == 1
+
+
+class TestQueryRules:
+    def test_query_rules_empty(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        adapter = MemoryContractAdapter(cfg)
+        assert adapter.query_rules() == []
+
+    def test_query_rules_with_data(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        # 写入 scratchpad 数据
+        from data_modules.memory.schema import MemoryItem
+        from data_modules.memory.store import ScratchpadManager
+
+        store = ScratchpadManager(cfg)
+        store.upsert_item(MemoryItem(
+            id="rule-1", layer="semantic", category="world_rule",
+            subject="力量体系", field="异火数量", value="23种",
+            status="active", source_chapter=1,
+        ))
+
+        adapter = MemoryContractAdapter(cfg)
+        rules = adapter.query_rules()
+        assert len(rules) == 1
+        assert rules[0].value == "23种"
+        assert rules[0].domain == "力量体系"
+
+    def test_query_rules_filter_by_domain(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        from data_modules.memory.schema import MemoryItem
+        from data_modules.memory.store import ScratchpadManager
+
+        store = ScratchpadManager(cfg)
+        store.upsert_item(MemoryItem(
+            id="rule-1", layer="semantic", category="world_rule",
+            subject="力量体系", field="异火数量", value="23种",
+            status="active", source_chapter=1,
+        ))
+        store.upsert_item(MemoryItem(
+            id="rule-2", layer="semantic", category="world_rule",
+            subject="社会结构", field="帝国数量", value="4个",
+            status="active", source_chapter=2,
+        ))
+
+        adapter = MemoryContractAdapter(cfg)
+        rules = adapter.query_rules(domain="力量体系")
+        assert len(rules) == 1
+        assert rules[0].field == "异火数量"
+
+
+class TestGetOpenLoops:
+    def test_get_open_loops_empty(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        adapter = MemoryContractAdapter(cfg)
+        assert adapter.get_open_loops() == []
+
+    def test_get_open_loops_with_data(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        from data_modules.memory.schema import MemoryItem
+        from data_modules.memory.store import ScratchpadManager
+
+        store = ScratchpadManager(cfg)
+        store.upsert_item(MemoryItem(
+            id="ol-1", layer="semantic", category="open_loop",
+            subject="三年之约", field="", value="萧炎与纳兰嫣然三年之约",
+            status="active", source_chapter=1,
+            payload={"expected_payoff": "大比", "urgency": 0.9},
+        ))
+
+        adapter = MemoryContractAdapter(cfg)
+        loops = adapter.get_open_loops()
+        assert len(loops) == 1
+        assert loops[0].content == "萧炎与纳兰嫣然三年之约"
+        assert loops[0].urgency == 0.9
+
+
+class TestGetTimeline:
+    def test_get_timeline_empty(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        adapter = MemoryContractAdapter(cfg)
+        assert adapter.get_timeline(1, 100) == []
+
+    def test_get_timeline_filters_by_range(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        from data_modules.memory.schema import MemoryItem
+        from data_modules.memory.store import ScratchpadManager
+
+        store = ScratchpadManager(cfg)
+        for ch in [5, 10, 50, 100]:
+            store.upsert_item(MemoryItem(
+                id=f"tl-{ch}", layer="semantic", category="timeline",
+                subject="事件", field=f"第{ch}章时", value=f"事件{ch}",
+                status="active", source_chapter=ch,
+            ))
+
+        adapter = MemoryContractAdapter(cfg)
+        events = adapter.get_timeline(8, 55)
+        assert len(events) == 2
+        assert events[0].chapter == 10
+        assert events[1].chapter == 50
+
+
+class TestLoadContext:
+    def test_load_context_returns_context_pack(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        adapter = MemoryContractAdapter(cfg)
+        pack = adapter.load_context(10)
+        assert isinstance(pack, ContextPack)
+        assert pack.chapter == 10
+
+
+class TestCommitChapter:
+    def test_commit_chapter_basic(self, tmp_path):
+        cfg = _make_project(tmp_path)
+        adapter = MemoryContractAdapter(cfg)
+        result = adapter.commit_chapter(1, {
+            "entities_appeared": [{"id": "xiaoyan", "type": "角色"}],
+            "entities_new": [],
+            "state_changes": [],
+            "relationships_new": [],
+        })
+        assert isinstance(result, CommitResult)
+        assert result.chapter == 1
+        assert result.entities_updated == 1