test_data_modules.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Unit tests for the Data Modules package.
"""
  6. import pytest
  7. import asyncio
  8. import json
  9. import tempfile
  10. from pathlib import Path
  11. from data_modules import (
  12. DataModulesConfig,
  13. EntityLinker,
  14. StateManager,
  15. IndexManager,
  16. RAGAdapter,
  17. StyleSampler,
  18. EntityState,
  19. ChapterMeta,
  20. SceneMeta,
  21. StyleSample,
  22. )
  23. @pytest.fixture
  24. def temp_project():
  25. """创建临时项目目录"""
  26. with tempfile.TemporaryDirectory() as tmpdir:
  27. config = DataModulesConfig.from_project_root(tmpdir)
  28. config.ensure_dirs()
  29. yield config
  30. class TestEntityLinker:
  31. """实体链接器测试"""
  32. def test_register_and_lookup_alias(self, temp_project):
  33. linker = EntityLinker(temp_project)
  34. # 注册别名
  35. assert linker.register_alias("xiaoyan", "萧炎")
  36. assert linker.register_alias("xiaoyan", "小炎子")
  37. # 查找
  38. assert linker.lookup_alias("萧炎") == "xiaoyan"
  39. assert linker.lookup_alias("小炎子") == "xiaoyan"
  40. assert linker.lookup_alias("不存在") is None
  41. def test_alias_one_to_many(self, temp_project):
  42. """v5.0: 同一别名可映射多个实体(一对多)"""
  43. linker = EntityLinker(temp_project)
  44. linker.register_alias("xiaoyan", "萧炎", "角色")
  45. # v5.0: 同一别名可绑定不同实体(一对多)
  46. assert linker.register_alias("other_person", "萧炎", "角色")
  47. # 查找所有匹配
  48. entries = linker.lookup_alias_all("萧炎")
  49. assert len(entries) == 2
  50. def test_get_all_aliases(self, temp_project):
  51. linker = EntityLinker(temp_project)
  52. linker.register_alias("xiaoyan", "萧炎")
  53. linker.register_alias("xiaoyan", "小炎子")
  54. linker.register_alias("xiaoyan", "炎哥")
  55. aliases = linker.get_all_aliases("xiaoyan")
  56. assert len(aliases) == 3
  57. assert "萧炎" in aliases
  58. def test_confidence_evaluation(self, temp_project):
  59. linker = EntityLinker(temp_project)
  60. # 高置信度
  61. action, adopt, warning = linker.evaluate_confidence(0.9)
  62. assert action == "auto"
  63. assert adopt is True
  64. assert warning is None
  65. # 中置信度
  66. action, adopt, warning = linker.evaluate_confidence(0.6)
  67. assert action == "warn"
  68. assert adopt is True
  69. assert warning is not None
  70. # 低置信度
  71. action, adopt, warning = linker.evaluate_confidence(0.3)
  72. assert action == "manual"
  73. assert adopt is False
  74. def test_process_uncertain(self, temp_project):
  75. linker = EntityLinker(temp_project)
  76. result = linker.process_uncertain(
  77. mention="那位前辈",
  78. candidates=["yaolao", "elder_zhang"],
  79. suggested="yaolao",
  80. confidence=0.7
  81. )
  82. assert result.mention == "那位前辈"
  83. assert result.entity_id == "yaolao"
  84. assert result.adopted is True
  85. assert result.warning is not None
  86. class TestStateManager:
  87. """状态管理器测试"""
  88. def test_add_and_get_entity(self, temp_project):
  89. manager = StateManager(temp_project)
  90. entity = EntityState(
  91. id="xiaoyan",
  92. name="萧炎",
  93. type="角色",
  94. tier="核心"
  95. )
  96. assert manager.add_entity(entity)
  97. # 获取实体
  98. result = manager.get_entity("xiaoyan")
  99. assert result is not None
  100. assert result["canonical_name"] == "萧炎"
  101. def test_update_entity(self, temp_project):
  102. manager = StateManager(temp_project)
  103. entity = EntityState(id="xiaoyan", name="萧炎", type="角色")
  104. manager.add_entity(entity)
  105. # 更新属性 (v5.0: attributes 存在 current 字段)
  106. manager.update_entity("xiaoyan", {"current": {"realm": "斗师"}})
  107. result = manager.get_entity("xiaoyan")
  108. assert result["current"]["realm"] == "斗师"
  109. def test_record_state_change(self, temp_project):
  110. manager = StateManager(temp_project)
  111. entity = EntityState(id="xiaoyan", name="萧炎", type="角色")
  112. manager.add_entity(entity)
  113. manager.record_state_change(
  114. entity_id="xiaoyan",
  115. field="realm",
  116. old_value="斗者",
  117. new_value="斗师",
  118. reason="突破",
  119. chapter=100
  120. )
  121. changes = manager.get_state_changes("xiaoyan")
  122. assert len(changes) == 1
  123. assert changes[0]["new_value"] == "斗师"
  124. def test_add_relationship(self, temp_project):
  125. manager = StateManager(temp_project)
  126. manager.add_relationship(
  127. from_entity="xiaoyan",
  128. to_entity="yaolao",
  129. rel_type="师徒",
  130. description="药老收萧炎为徒",
  131. chapter=10
  132. )
  133. rels = manager.get_relationships("xiaoyan")
  134. assert len(rels) == 1
  135. assert rels[0]["type"] == "师徒"
  136. def test_process_chapter_result(self, temp_project):
  137. manager = StateManager(temp_project)
  138. result = {
  139. "entities_appeared": [
  140. {"id": "xiaoyan", "mentions": ["萧炎", "他"]}
  141. ],
  142. "entities_new": [
  143. {"suggested_id": "hongyi_girl", "name": "红衣女子", "type": "角色", "tier": "装饰"}
  144. ],
  145. "state_changes": [
  146. {"entity_id": "xiaoyan", "field": "realm", "old": "斗者", "new": "斗师", "reason": "突破"}
  147. ],
  148. "relationships_new": [
  149. {"from": "xiaoyan", "to": "hongyi_girl", "type": "相识", "description": "初次见面"}
  150. ]
  151. }
  152. # 先添加萧炎
  153. manager.add_entity(EntityState(id="xiaoyan", name="萧炎", type="角色"))
  154. warnings = manager.process_chapter_result(100, result)
  155. # 验证新实体被添加
  156. assert manager.get_entity("hongyi_girl") is not None
  157. # 验证状态变化
  158. changes = manager.get_state_changes("xiaoyan")
  159. assert len(changes) == 1
  160. # 验证进度更新
  161. assert manager.get_current_chapter() == 100
  162. def test_save_state_with_init_project_schema(self, temp_project):
  163. """回归:init_project 生成的 state.json,StateManager 仍应可写入。(v5.1 SQLite-only)"""
  164. # v5.1: state.json 不再包含 entities_v3/alias_index,实体数据在 SQLite
  165. init_state = {
  166. "project_info": {"title": "测试书名", "genre": "修仙/玄幻", "created_at": "2026-01-01"},
  167. "progress": {"current_chapter": 0, "total_words": 0, "last_updated": "2026-01-01 00:00:00"},
  168. "protagonist_state": {"name": "测试主角"},
  169. "relationships": {},
  170. "world_settings": {"power_system": [], "factions": [], "locations": []},
  171. "plot_threads": {"active_threads": [], "foreshadowing": []},
  172. "review_checkpoints": [],
  173. "strand_tracker": {"current_dominant": "quest", "history": []},
  174. }
  175. temp_project.state_file.write_text(json.dumps(init_state, ensure_ascii=False, indent=2), encoding="utf-8")
  176. manager = StateManager(temp_project)
  177. manager.update_progress(5, words=100)
  178. manager.save_state()
  179. saved = json.loads(temp_project.state_file.read_text(encoding="utf-8"))
  180. assert "meta" not in saved
  181. assert saved["progress"]["current_chapter"] == 5
  182. assert saved["progress"]["total_words"] == 100
  183. # v5.1: entities_v3/alias_index 不再在 state.json 中
  184. def test_save_state_preserves_unrelated_fields(self, temp_project):
  185. """回归:仅写入增量,不应覆盖/丢失其他模块维护的字段。(v5.1 SQLite-only)"""
  186. init_state = {
  187. "project_info": {"title": "测试书名", "genre": "修仙/玄幻", "created_at": "2026-01-01"},
  188. "progress": {"current_chapter": 10, "total_words": 1000, "last_updated": "2026-01-01 00:00:00"},
  189. "protagonist_state": {"name": "测试主角"},
  190. "relationships": {"allies": ["药老"], "enemies": []},
  191. "world_settings": {"power_system": [], "factions": [], "locations": []},
  192. "plot_threads": {"active_threads": [{"id": "t1", "title": "主线"}], "foreshadowing": []},
  193. "review_checkpoints": [],
  194. "strand_tracker": {"current_dominant": "quest", "history": []},
  195. "custom_field": {"keep": True},
  196. }
  197. temp_project.state_file.write_text(json.dumps(init_state, ensure_ascii=False, indent=2), encoding="utf-8")
  198. manager = StateManager(temp_project)
  199. manager.add_entity(EntityState(id="xiaoyan", name="萧炎", type="角色", tier="核心"))
  200. manager.save_state()
  201. saved = json.loads(temp_project.state_file.read_text(encoding="utf-8"))
  202. assert saved.get("custom_field", {}).get("keep") is True
  203. assert saved.get("plot_threads", {}).get("active_threads", [])[0].get("id") == "t1"
  204. assert isinstance(saved.get("relationships"), dict)
  205. def test_disambiguation_feedback_persisted(self, temp_project):
  206. """回归:中/低置信度消歧必须对 Writer 可见(写入 state.json)。"""
  207. manager = StateManager(temp_project)
  208. result = {
  209. "entities_appeared": [],
  210. "entities_new": [],
  211. "state_changes": [],
  212. "relationships_new": [],
  213. "uncertain": [
  214. {
  215. "mention": "那位前辈",
  216. "context": "那位前辈看了他一眼",
  217. "candidates": [{"type": "角色", "id": "yaolao"}, {"type": "角色", "id": "elder_zhang"}],
  218. "suggested": "yaolao",
  219. "confidence": 0.6,
  220. },
  221. {
  222. "mention": "宗主",
  223. "context": "宗主出现在血煞秘境",
  224. "candidates": ["xueshazonzhu", "lintian"],
  225. "suggested": "xueshazonzhu",
  226. "confidence": 0.4,
  227. },
  228. ],
  229. }
  230. warnings = manager.process_chapter_result(100, result)
  231. manager.save_state()
  232. state = json.loads(temp_project.state_file.read_text(encoding="utf-8"))
  233. assert isinstance(state.get("disambiguation_warnings"), list)
  234. assert isinstance(state.get("disambiguation_pending"), list)
  235. assert len(state["disambiguation_warnings"]) == 1
  236. assert len(state["disambiguation_pending"]) == 1
  237. warn = state["disambiguation_warnings"][0]
  238. assert warn.get("chapter") == 100
  239. assert warn.get("mention") == "那位前辈"
  240. assert warn.get("chosen_id") == "yaolao"
  241. pending = state["disambiguation_pending"][0]
  242. assert pending.get("chapter") == 100
  243. assert pending.get("mention") == "宗主"
  244. # 返回值也应包含可见警告,便于 CLI/日志透出
  245. assert any("消歧警告" in w for w in warnings)
  246. assert any("需人工确认" in w for w in warnings)
  247. class TestIndexManager:
  248. """索引管理器测试"""
  249. def test_add_and_get_chapter(self, temp_project):
  250. manager = IndexManager(temp_project)
  251. meta = ChapterMeta(
  252. chapter=100,
  253. title="突破",
  254. location="天云宗",
  255. word_count=3500,
  256. characters=["xiaoyan", "yaolao"]
  257. )
  258. manager.add_chapter(meta)
  259. result = manager.get_chapter(100)
  260. assert result is not None
  261. assert result["title"] == "突破"
  262. assert "xiaoyan" in result["characters"]
  263. def test_add_scenes(self, temp_project):
  264. manager = IndexManager(temp_project)
  265. scenes = [
  266. SceneMeta(chapter=100, scene_index=1, start_line=1, end_line=50,
  267. location="天云宗·闭关室", summary="萧炎闭关突破", characters=["xiaoyan"]),
  268. SceneMeta(chapter=100, scene_index=2, start_line=51, end_line=100,
  269. location="天云宗·演武场", summary="展示实力", characters=["xiaoyan", "lintian"])
  270. ]
  271. manager.add_scenes(100, scenes)
  272. result = manager.get_scenes(100)
  273. assert len(result) == 2
  274. assert result[0]["location"] == "天云宗·闭关室"
  275. def test_record_appearance(self, temp_project):
  276. manager = IndexManager(temp_project)
  277. manager.record_appearance("xiaoyan", 100, ["萧炎", "他"], 0.95)
  278. manager.record_appearance("yaolao", 100, ["药老"], 0.92)
  279. appearances = manager.get_chapter_appearances(100)
  280. assert len(appearances) == 2
  281. entity_history = manager.get_entity_appearances("xiaoyan")
  282. assert len(entity_history) == 1
  283. def test_search_scenes_by_location(self, temp_project):
  284. manager = IndexManager(temp_project)
  285. scenes = [
  286. SceneMeta(chapter=100, scene_index=1, start_line=1, end_line=50,
  287. location="天云宗·闭关室", summary="闭关", characters=[]),
  288. SceneMeta(chapter=101, scene_index=1, start_line=1, end_line=50,
  289. location="天云宗·大殿", summary="议事", characters=[])
  290. ]
  291. manager.add_scenes(100, scenes[:1])
  292. manager.add_scenes(101, scenes[1:])
  293. results = manager.search_scenes_by_location("天云宗")
  294. assert len(results) == 2
  295. def test_get_stats(self, temp_project):
  296. manager = IndexManager(temp_project)
  297. manager.add_chapter(ChapterMeta(chapter=1, title="", location="", word_count=1000, characters=[]))
  298. manager.add_scenes(1, [SceneMeta(chapter=1, scene_index=1, start_line=1, end_line=50,
  299. location="", summary="", characters=[])])
  300. manager.record_appearance("xiaoyan", 1, [], 1.0)
  301. stats = manager.get_stats()
  302. assert stats["chapters"] == 1
  303. assert stats["scenes"] == 1
  304. assert stats["entities"] == 1
  305. class TestStyleSampler:
  306. """风格样本测试"""
  307. def test_add_and_get_sample(self, temp_project):
  308. sampler = StyleSampler(temp_project)
  309. sample = StyleSample(
  310. id="ch100_s1",
  311. chapter=100,
  312. scene_type="战斗",
  313. content="萧炎一拳轰出...",
  314. score=0.85,
  315. tags=["战斗", "激烈"]
  316. )
  317. assert sampler.add_sample(sample)
  318. results = sampler.get_samples_by_type("战斗")
  319. assert len(results) == 1
  320. assert results[0].id == "ch100_s1"
  321. def test_extract_candidates(self, temp_project):
  322. sampler = StyleSampler(temp_project)
  323. scenes = [
  324. {"index": 1, "summary": "战斗场景", "content": "萧炎一拳轰出,斗气如虹,直接将对手击退三丈,周围的空气都被震得嗡嗡作响..." + "a" * 200}
  325. ]
  326. # 低分不提取
  327. candidates = sampler.extract_candidates(100, "", 70, scenes)
  328. assert len(candidates) == 0
  329. # 高分提取
  330. candidates = sampler.extract_candidates(100, "", 85, scenes)
  331. assert len(candidates) == 1
  332. assert candidates[0].scene_type == "战斗"
  333. def test_select_samples_for_chapter(self, temp_project):
  334. sampler = StyleSampler(temp_project)
  335. # 添加一些样本
  336. for i in range(3):
  337. sampler.add_sample(StyleSample(
  338. id=f"battle_{i}",
  339. chapter=i,
  340. scene_type="战斗",
  341. content=f"战斗内容 {i}",
  342. score=0.9,
  343. tags=[]
  344. ))
  345. samples = sampler.select_samples_for_chapter("本章有一场激烈的战斗")
  346. assert len(samples) <= 3
  347. assert all(s.scene_type == "战斗" for s in samples)
  348. class TestRAGAdapter:
  349. """RAG 适配器测试(不包含 API 调用)"""
  350. def test_bm25_search(self, temp_project):
  351. adapter = RAGAdapter(temp_project)
  352. # 手动插入一些测试数据
  353. with adapter._get_conn() as conn:
  354. cursor = conn.cursor()
  355. # 插入向量记录(空向量,只测试 BM25)
  356. cursor.execute("""
  357. INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding)
  358. VALUES (?, ?, ?, ?, ?)
  359. """, ("ch1_s1", 1, 1, "萧炎在天云宗修炼斗气", b""))
  360. cursor.execute("""
  361. INSERT INTO vectors (chunk_id, chapter, scene_index, content, embedding)
  362. VALUES (?, ?, ?, ?, ?)
  363. """, ("ch1_s2", 1, 2, "药老传授炼药技巧", b""))
  364. conn.commit()
  365. # 更新 BM25 索引
  366. adapter._update_bm25_index(cursor, "ch1_s1", "萧炎在天云宗修炼斗气")
  367. adapter._update_bm25_index(cursor, "ch1_s2", "药老传授炼药技巧")
  368. conn.commit()
  369. # BM25 搜索
  370. results = adapter.bm25_search("萧炎修炼", top_k=5)
  371. assert len(results) >= 1
  372. assert results[0].chunk_id == "ch1_s1"
  373. def test_tokenize(self, temp_project):
  374. adapter = RAGAdapter(temp_project)
  375. tokens = adapter._tokenize("萧炎hello世界world")
  376. assert "萧" in tokens
  377. assert "炎" in tokens
  378. assert "hello" in tokens
  379. assert "world" in tokens
  380. if __name__ == "__main__":
  381. pytest.main([__file__, "-v"])