test_context_manager.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. ContextManager and SnapshotManager tests
  5. """
  6. import json
  7. import pytest
  8. from data_modules.config import DataModulesConfig
  9. from data_modules.index_manager import IndexManager, EntityMeta
  10. from data_modules.context_manager import ContextManager
  11. from data_modules.snapshot_manager import SnapshotManager, SnapshotVersionMismatch
  12. from data_modules.query_router import QueryRouter
  13. @pytest.fixture
  14. def temp_project(tmp_path):
  15. cfg = DataModulesConfig.from_project_root(tmp_path)
  16. cfg.ensure_dirs()
  17. return cfg
  18. def test_snapshot_manager_roundtrip(temp_project):
  19. manager = SnapshotManager(temp_project)
  20. payload = {"hello": "world"}
  21. manager.save_snapshot(1, payload)
  22. loaded = manager.load_snapshot(1)
  23. assert loaded["payload"] == payload
  24. def test_snapshot_version_mismatch(temp_project):
  25. manager = SnapshotManager(temp_project, version="1.0")
  26. manager.save_snapshot(1, {"a": 1})
  27. other = SnapshotManager(temp_project, version="2.0")
  28. with pytest.raises(SnapshotVersionMismatch):
  29. other.load_snapshot(1)
  30. def test_context_manager_build_and_filter(temp_project):
  31. state = {
  32. "protagonist_state": {"name": "萧炎", "location": {"current": "天云宗"}},
  33. "chapter_meta": {"0001": {"hook": "测试"}},
  34. }
  35. temp_project.state_file.write_text(json.dumps(state, ensure_ascii=False), encoding="utf-8")
  36. # preferences and memory
  37. (temp_project.webnovel_dir / "preferences.json").write_text(json.dumps({"tone": "热血"}, ensure_ascii=False), encoding="utf-8")
  38. (temp_project.webnovel_dir / "project_memory.json").write_text(json.dumps({"patterns": []}, ensure_ascii=False), encoding="utf-8")
  39. idx = IndexManager(temp_project)
  40. idx.upsert_entity(
  41. EntityMeta(
  42. id="xiaoyan",
  43. type="角色",
  44. canonical_name="萧炎",
  45. current={},
  46. first_appearance=1,
  47. last_appearance=1,
  48. )
  49. )
  50. idx.upsert_entity(
  51. EntityMeta(
  52. id="bad",
  53. type="角色",
  54. canonical_name="坏人",
  55. current={},
  56. first_appearance=1,
  57. last_appearance=1,
  58. )
  59. )
  60. idx.record_appearance("xiaoyan", 1, ["萧炎"], 1.0)
  61. idx.record_appearance("bad", 1, ["坏人"], 1.0)
  62. invalid_id = idx.mark_invalid_fact("entity", "bad", "错误")
  63. idx.resolve_invalid_fact(invalid_id, "confirm")
  64. manager = ContextManager(temp_project)
  65. payload = manager.build_context(1, use_snapshot=False, save_snapshot=False)
  66. characters = payload["sections"]["scene"]["content"]["appearing_characters"]
  67. assert any(c.get("entity_id") == "xiaoyan" for c in characters)
  68. assert not any(c.get("entity_id") == "bad" for c in characters)
  69. assert payload["sections"]["preferences"]["content"].get("tone") == "热血"
  70. def test_query_router():
  71. router = QueryRouter()
  72. assert router.route("角色是谁") == "entity"
  73. assert router.route("发生了什么剧情") == "plot"
  74. assert "A" in router.split("A, B;C")
  75. def test_context_snapshot_respects_template(temp_project):
  76. state = {
  77. "protagonist_state": {"name": "萧炎"},
  78. "chapter_meta": {},
  79. "disambiguation_warnings": [],
  80. "disambiguation_pending": [],
  81. }
  82. temp_project.state_file.write_text(json.dumps(state, ensure_ascii=False), encoding="utf-8")
  83. manager = ContextManager(temp_project)
  84. plot_payload = manager.build_context(1, template="plot", use_snapshot=True, save_snapshot=True)
  85. battle_payload = manager.build_context(1, template="battle", use_snapshot=True, save_snapshot=True)
  86. assert plot_payload.get("template") == "plot"
  87. assert battle_payload.get("template") == "battle"