test_data_modules.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Data Modules 单元测试
  5. """
  6. import pytest
  7. import asyncio
  8. import json
  9. import tempfile
  10. from pathlib import Path
  11. from data_modules import (
  12. DataModulesConfig,
  13. EntityLinker,
  14. StateManager,
  15. IndexManager,
  16. RAGAdapter,
  17. StyleSampler,
  18. EntityState,
  19. ChapterMeta,
  20. SceneMeta,
  21. StyleSample,
  22. )
  23. @pytest.fixture
  24. def temp_project():
  25. """创建临时项目目录"""
  26. with tempfile.TemporaryDirectory() as tmpdir:
  27. config = DataModulesConfig.from_project_root(tmpdir)
  28. config.ensure_dirs()
  29. yield config
  30. class TestEntityLinker:
  31. """实体链接器测试"""
  32. def test_register_and_lookup_alias(self, temp_project):
  33. linker = EntityLinker(temp_project)
  34. # 注册别名
  35. assert linker.register_alias("xiaoyan", "萧炎")
  36. assert linker.register_alias("xiaoyan", "小炎子")
  37. # 查找
  38. assert linker.lookup_alias("萧炎") == "xiaoyan"
  39. assert linker.lookup_alias("小炎子") == "xiaoyan"
  40. assert linker.lookup_alias("不存在") is None
  41. def test_alias_one_to_many(self, temp_project):
  42. """v5.0: 同一别名可映射多个实体(一对多)"""
  43. linker = EntityLinker(temp_project)
  44. linker.register_alias("xiaoyan", "萧炎", "角色")
  45. # v5.0: 同一别名可绑定不同实体(一对多)
  46. assert linker.register_alias("other_person", "萧炎", "角色")
  47. # 查找所有匹配
  48. entries = linker.lookup_alias_all("萧炎")
  49. assert len(entries) == 2
  50. def test_get_all_aliases(self, temp_project):
  51. linker = EntityLinker(temp_project)
  52. linker.register_alias("xiaoyan", "萧炎")
  53. linker.register_alias("xiaoyan", "小炎子")
  54. linker.register_alias("xiaoyan", "炎哥")
  55. aliases = linker.get_all_aliases("xiaoyan")
  56. assert len(aliases) == 3
  57. assert "萧炎" in aliases
  58. def test_confidence_evaluation(self, temp_project):
  59. linker = EntityLinker(temp_project)
  60. # 高置信度
  61. action, adopt, warning = linker.evaluate_confidence(0.9)
  62. assert action == "auto"
  63. assert adopt is True
  64. assert warning is None
  65. # 中置信度
  66. action, adopt, warning = linker.evaluate_confidence(0.6)
  67. assert action == "warn"
  68. assert adopt is True
  69. assert warning is not None
  70. # 低置信度
  71. action, adopt, warning = linker.evaluate_confidence(0.3)
  72. assert action == "manual"
  73. assert adopt is False
  74. def test_process_uncertain(self, temp_project):
  75. linker = EntityLinker(temp_project)
  76. result = linker.process_uncertain(
  77. mention="那位前辈",
  78. candidates=["yaolao", "elder_zhang"],
  79. suggested="yaolao",
  80. confidence=0.7
  81. )
  82. assert result.mention == "那位前辈"
  83. assert result.entity_id == "yaolao"
  84. assert result.adopted is True
  85. assert result.warning is not None
  86. class TestStateManager:
  87. """状态管理器测试"""
  88. def test_add_and_get_entity(self, temp_project):
  89. manager = StateManager(temp_project)
  90. entity = EntityState(
  91. id="xiaoyan",
  92. name="萧炎",
  93. type="角色",
  94. tier="核心"
  95. )
  96. assert manager.add_entity(entity)
  97. # 获取实体
  98. result = manager.get_entity("xiaoyan")
  99. assert result is not None
  100. assert result["canonical_name"] == "萧炎"
  101. def test_update_entity(self, temp_project):
  102. manager = StateManager(temp_project)
  103. entity = EntityState(id="xiaoyan", name="萧炎", type="角色")
  104. manager.add_entity(entity)
  105. # 更新属性 (v5.0: attributes 存在 current 字段)
  106. manager.update_entity("xiaoyan", {"current": {"realm": "斗师"}})
  107. result = manager.get_entity("xiaoyan")
  108. assert result["current"]["realm"] == "斗师"
  109. def test_record_state_change(self, temp_project):
  110. manager = StateManager(temp_project)
  111. entity = EntityState(id="xiaoyan", name="萧炎", type="角色")
  112. manager.add_entity(entity)
  113. manager.record_state_change(
  114. entity_id="xiaoyan",
  115. field="realm",
  116. old_value="斗者",
  117. new_value="斗师",
  118. reason="突破",
  119. chapter=100
  120. )
  121. changes = manager.get_state_changes("xiaoyan")
  122. assert len(changes) == 1
  123. assert changes[0]["new_value"] == "斗师"
  124. def test_add_relationship(self, temp_project):
  125. manager = StateManager(temp_project)
  126. manager.add_relationship(
  127. from_entity="xiaoyan",
  128. to_entity="yaolao",
  129. rel_type="师徒",
  130. description="药老收萧炎为徒",
  131. chapter=10
  132. )
  133. rels = manager.get_relationships("xiaoyan")
  134. assert len(rels) == 1
  135. assert rels[0]["type"] == "师徒"
  136. def test_process_chapter_result(self, temp_project):
  137. manager = StateManager(temp_project)
  138. result = {
  139. "entities_appeared": [
  140. {"id": "xiaoyan", "mentions": ["萧炎", "他"]}
  141. ],
  142. "entities_new": [
  143. {"suggested_id": "hongyi_girl", "name": "红衣女子", "type": "角色", "tier": "装饰"}
  144. ],
  145. "state_changes": [
  146. {"entity_id": "xiaoyan", "field": "realm", "old": "斗者", "new": "斗师", "reason": "突破"}
  147. ],
  148. "relationships_new": [
  149. {"from": "xiaoyan", "to": "hongyi_girl", "type": "相识", "description": "初次见面"}
  150. ]
  151. }
  152. # 先添加萧炎
  153. manager.add_entity(EntityState(id="xiaoyan", name="萧炎", type="角色"))
  154. warnings = manager.process_chapter_result(100, result)
  155. # 验证新实体被添加
  156. assert manager.get_entity("hongyi_girl") is not None
  157. # 验证状态变化
  158. changes = manager.get_state_changes("xiaoyan")
  159. assert len(changes) == 1
  160. # 验证进度更新
  161. assert manager.get_current_chapter() == 100
  162. def test_save_state_with_init_project_schema(self, temp_project):
  163. """回归:init_project 生成的 state.json 无 meta 字段,StateManager 仍应可写入。"""
  164. # 模拟 init_project.py 生成的 v5.0 state.json 形状(包含 entities_v3/alias_index)
  165. init_state = {
  166. "project_info": {"title": "测试书名", "genre": "修仙/玄幻", "created_at": "2026-01-01"},
  167. "progress": {"current_chapter": 0, "total_words": 0, "last_updated": "2026-01-01 00:00:00"},
  168. "protagonist_state": {"name": "测试主角"},
  169. "relationships": {},
  170. "world_settings": {"power_system": [], "factions": [], "locations": []},
  171. "plot_threads": {"active_threads": [], "foreshadowing": []},
  172. "review_checkpoints": [],
  173. "strand_tracker": {"current_dominant": "quest", "history": []},
  174. "entities_v3": {"角色": {}, "地点": {}, "物品": {}, "势力": {}, "招式": {}},
  175. "alias_index": {},
  176. }
  177. temp_project.state_file.write_text(json.dumps(init_state, ensure_ascii=False, indent=2), encoding="utf-8")
  178. manager = StateManager(temp_project)
  179. manager.update_progress(5, words=100)
  180. manager.save_state()
  181. saved = json.loads(temp_project.state_file.read_text(encoding="utf-8"))
  182. assert "meta" not in saved
  183. assert saved["progress"]["current_chapter"] == 5
  184. assert saved["progress"]["total_words"] == 100
  185. assert isinstance(saved.get("entities_v3"), dict)
  186. assert isinstance(saved.get("alias_index"), dict)
  187. def test_save_state_preserves_unrelated_fields(self, temp_project):
  188. """回归:仅写入增量,不应覆盖/丢失其他模块维护的字段。"""
  189. init_state = {
  190. "project_info": {"title": "测试书名", "genre": "修仙/玄幻", "created_at": "2026-01-01"},
  191. "progress": {"current_chapter": 10, "total_words": 1000, "last_updated": "2026-01-01 00:00:00"},
  192. "protagonist_state": {"name": "测试主角"},
  193. "relationships": {"allies": ["药老"], "enemies": []},
  194. "world_settings": {"power_system": [], "factions": [], "locations": []},
  195. "plot_threads": {"active_threads": [{"id": "t1", "title": "主线"}], "foreshadowing": []},
  196. "review_checkpoints": [],
  197. "strand_tracker": {"current_dominant": "quest", "history": []},
  198. "entities_v3": {"角色": {}, "地点": {}, "物品": {}, "势力": {}, "招式": {}},
  199. "alias_index": {},
  200. "custom_field": {"keep": True},
  201. }
  202. temp_project.state_file.write_text(json.dumps(init_state, ensure_ascii=False, indent=2), encoding="utf-8")
  203. manager = StateManager(temp_project)
  204. manager.add_entity(EntityState(id="xiaoyan", name="萧炎", type="角色", tier="核心"))
  205. manager.save_state()
  206. saved = json.loads(temp_project.state_file.read_text(encoding="utf-8"))
  207. assert saved.get("custom_field", {}).get("keep") is True
  208. assert saved.get("plot_threads", {}).get("active_threads", [])[0].get("id") == "t1"
  209. assert isinstance(saved.get("relationships"), dict)
  210. def test_disambiguation_feedback_persisted(self, temp_project):
  211. """回归:中/低置信度消歧必须对 Writer 可见(写入 state.json)。"""
  212. manager = StateManager(temp_project)
  213. result = {
  214. "entities_appeared": [],
  215. "entities_new": [],
  216. "state_changes": [],
  217. "relationships_new": [],
  218. "uncertain": [
  219. {
  220. "mention": "那位前辈",
  221. "context": "那位前辈看了他一眼",
  222. "candidates": [{"type": "角色", "id": "yaolao"}, {"type": "角色", "id": "elder_zhang"}],
  223. "suggested": "yaolao",
  224. "confidence": 0.6,
  225. },
  226. {
  227. "mention": "宗主",
  228. "context": "宗主出现在血煞秘境",
  229. "candidates": ["xueshazonzhu", "lintian"],
  230. "suggested": "xueshazonzhu",
  231. "confidence": 0.4,
  232. },
  233. ],
  234. }
  235. warnings = manager.process_chapter_result(100, result)
  236. manager.save_state()
  237. state = json.loads(temp_project.state_file.read_text(encoding="utf-8"))
  238. assert isinstance(state.get("disambiguation_warnings"), list)
  239. assert isinstance(state.get("disambiguation_pending"), list)
  240. assert len(state["disambiguation_warnings"]) == 1
  241. assert len(state["disambiguation_pending"]) == 1
  242. warn = state["disambiguation_warnings"][0]
  243. assert warn.get("chapter") == 100
  244. assert warn.get("mention") == "那位前辈"
  245. assert warn.get("chosen_id") == "yaolao"
  246. pending = state["disambiguation_pending"][0]
  247. assert pending.get("chapter") == 100
  248. assert pending.get("mention") == "宗主"
  249. # 返回值也应包含可见警告,便于 CLI/日志透出
  250. assert any("消歧警告" in w for w in warnings)
  251. assert any("需人工确认" in w for w in warnings)
  252. class TestIndexManager:
  253. """索引管理器测试"""
  254. def test_add_and_get_chapter(self, temp_project):
  255. manager = IndexManager(temp_project)
  256. meta = ChapterMeta(
  257. chapter=100,
  258. title="突破",
  259. location="天云宗",
  260. word_count=3500,
  261. characters=["xiaoyan", "yaolao"]
  262. )
  263. manager.add_chapter(meta)
  264. result = manager.get_chapter(100)
  265. assert result is not None
  266. assert result["title"] == "突破"
  267. assert "xiaoyan" in result["characters"]
  268. def test_add_scenes(self, temp_project):
  269. manager = IndexManager(temp_project)
  270. scenes = [
  271. SceneMeta(chapter=100, scene_index=1, start_line=1, end_line=50,
  272. location="天云宗·闭关室", summary="萧炎闭关突破", characters=["xiaoyan"]),
  273. SceneMeta(chapter=100, scene_index=2, start_line=51, end_line=100,
  274. location="天云宗·演武场", summary="展示实力", characters=["xiaoyan", "lintian"])
  275. ]
  276. manager.add_scenes(100, scenes)
  277. result = manager.get_scenes(100)
  278. assert len(result) == 2
  279. assert result[0]["location"] == "天云宗·闭关室"
  280. def test_record_appearance(self, temp_project):
  281. manager = IndexManager(temp_project)
  282. manager.record_appearance("xiaoyan", 100, ["萧炎", "他"], 0.95)
  283. manager.record_appearance("yaolao", 100, ["药老"], 0.92)
  284. appearances = manager.get_chapter_appearances(100)
  285. assert len(appearances) == 2
  286. entity_history = manager.get_entity_appearances("xiaoyan")
  287. assert len(entity_history) == 1
  288. def test_search_scenes_by_location(self, temp_project):
  289. manager = IndexManager(temp_project)
  290. scenes = [
  291. SceneMeta(chapter=100, scene_index=1, start_line=1, end_line=50,
  292. location="天云宗·闭关室", summary="闭关", characters=[]),
  293. SceneMeta(chapter=101, scene_index=1, start_line=1, end_line=50,
  294. location="天云宗·大殿", summary="议事", characters=[])
  295. ]
  296. manager.add_scenes(100, scenes[:1])
  297. manager.add_scenes(101, scenes[1:])
  298. results = manager.search_scenes_by_location("天云宗")
  299. assert len(results) == 2
  300. def test_get_stats(self, temp_project):
  301. manager = IndexManager(temp_project)
  302. manager.add_chapter(ChapterMeta(chapter=1, title="", location="", word_count=1000, characters=[]))
  303. manager.add_scenes(1, [SceneMeta(chapter=1, scene_index=1, start_line=1, end_line=50,
  304. location="", summary="", characters=[])])
  305. manager.record_appearance("xiaoyan", 1, [], 1.0)
  306. stats = manager.get_stats()
  307. assert stats["chapters"] == 1
  308. assert stats["scenes"] == 1
  309. assert stats["entities"] == 1
  310. class TestStyleSampler:
  311. """风格样本测试"""
  312. def test_add_and_get_sample(self, temp_project):
  313. sampler = StyleSampler(temp_project)
  314. sample = StyleSample(
  315. id="ch100_s1",
  316. chapter=100,
  317. scene_type="战斗",
  318. content="萧炎一拳轰出...",
  319. score=0.85,
  320. tags=["战斗", "激烈"]
  321. )
  322. assert sampler.add_sample(sample)
  323. results = sampler.get_samples_by_type("战斗")
  324. assert len(results) == 1
  325. assert results[0].id == "ch100_s1"
  326. def test_extract_candidates(self, temp_project):
  327. sampler = StyleSampler(temp_project)
  328. scenes = [
  329. {"index": 1, "summary": "战斗场景", "content": "萧炎一拳轰出,斗气如虹,直接将对手击退三丈,周围的空气都被震得嗡嗡作响..." + "a" * 200}
  330. ]
  331. # 低分不提取
  332. candidates = sampler.extract_candidates(100, "", 70, scenes)
  333. assert len(candidates) == 0
  334. # 高分提取
  335. candidates = sampler.extract_candidates(100, "", 85, scenes)
  336. assert len(candidates) == 1
  337. assert candidates[0].scene_type == "战斗"
  338. def test_select_samples_for_chapter(self, temp_project):
  339. sampler = StyleSampler(temp_project)
  340. # 添加一些样本
  341. for i in range(3):
  342. sampler.add_sample(StyleSample(
  343. id=f"battle_{i}",
  344. chapter=i,
  345. scene_type="战斗",
  346. content=f"战斗内容 {i}",
  347. score=0.9,
  348. tags=[]
  349. ))
  350. samples = sampler.select_samples_for_chapter("本章有一场激烈的战斗")
  351. assert len(samples) <= 3
  352. assert all(s.scene_type == "战斗" for s in samples)
  353. class TestRAGAdapter:
  354. """RAG 适配器测试(不包含 API 调用)"""
  355. def test_bm25_search(self, temp_project):
  356. adapter = RAGAdapter(temp_project)
  357. # 手动插入一些测试数据
  358. with adapter._get_conn() as conn:
  359. cursor = conn.cursor()
  360. # 插入向量记录(空向量,只测试 BM25)
  361. cursor.execute("""
  362. INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding)
  363. VALUES (?, ?, ?, ?, ?)
  364. """, ("ch1_s1", 1, 1, "萧炎在天云宗修炼斗气", b""))
  365. cursor.execute("""
  366. INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding)
  367. VALUES (?, ?, ?, ?, ?)
  368. """, ("ch1_s2", 1, 2, "药老传授炼药技巧", b""))
  369. conn.commit()
  370. # 更新 BM25 索引
  371. adapter._update_bm25_index(cursor, "ch1_s1", "萧炎在天云宗修炼斗气")
  372. adapter._update_bm25_index(cursor, "ch1_s2", "药老传授炼药技巧")
  373. conn.commit()
  374. # BM25 搜索
  375. results = adapter.bm25_search("萧炎修炼", top_k=5)
  376. assert len(results) >= 1
  377. assert results[0].chunk_id == "ch1_s1"
  378. def test_tokenize(self, temp_project):
  379. adapter = RAGAdapter(temp_project)
  380. tokens = adapter._tokenize("萧炎hello世界world")
  381. assert "萧" in tokens
  382. assert "炎" in tokens
  383. assert "hello" in tokens
  384. assert "world" in tokens
  385. if __name__ == "__main__":
  386. pytest.main([__file__, "-v"])