test_style_sampler_cli.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. StyleSampler extra tests + CLI
  5. """
  6. import sys
  7. import json
  8. import pytest
  9. import data_modules.style_sampler as sampler_module
  10. from data_modules.style_sampler import StyleSampler, StyleSample, SceneType
  11. from data_modules.config import DataModulesConfig
  12. @pytest.fixture
  13. def temp_project(tmp_path):
  14. cfg = DataModulesConfig.from_project_root(tmp_path)
  15. cfg.ensure_dirs()
  16. if not cfg.state_file.exists():
  17. cfg.state_file.write_text("{}", encoding="utf-8")
  18. return cfg
  19. def test_style_sampler_more(temp_project):
  20. sampler = StyleSampler(temp_project)
  21. sample = StyleSample(
  22. id="ch1_s1",
  23. chapter=1,
  24. scene_type=SceneType.BATTLE.value,
  25. content="战斗描写很精彩",
  26. score=0.9,
  27. tags=["战斗"],
  28. )
  29. assert sampler.add_sample(sample) is True
  30. assert sampler.add_sample(sample) is False
  31. best = sampler.get_best_samples(limit=5)
  32. assert len(best) == 1
  33. stats = sampler.get_stats()
  34. assert stats["total"] == 1
  35. # scene type inference
  36. assert sampler._infer_scene_types("一场战斗") == [SceneType.BATTLE.value]
  37. assert sampler._infer_scene_types("对话和谈话") == [SceneType.DIALOGUE.value]
  38. assert sampler._infer_scene_types("心理情感描写") == [SceneType.EMOTION.value]
  39. # classify and tags
  40. scene_type = sampler._classify_scene_type({"summary": "紧张", "content": ""})
  41. assert scene_type == SceneType.TENSION.value
  42. tags = sampler._extract_tags("战斗 修炼 对话 描写")
  43. assert "战斗" in tags
  44. def test_style_sampler_ignores_corrupt_tag_json(temp_project):
  45. sampler = StyleSampler(temp_project)
  46. with sampler._get_conn() as conn:
  47. conn.execute(
  48. """
  49. INSERT INTO samples
  50. (id, chapter, scene_type, content, score, tags, created_at)
  51. VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
  52. """,
  53. ("bad-tags", 1, SceneType.BATTLE.value, "战斗描写" * 50, 0.8, "[bad-json"),
  54. )
  55. conn.commit()
  56. samples = sampler.get_best_samples(limit=5)
  57. assert samples[0].id == "bad-tags"
  58. assert samples[0].tags == []
  59. def test_style_sampler_cli(temp_project, monkeypatch, capsys):
  60. root = str(temp_project.project_root)
  61. def run_cli(args):
  62. monkeypatch.setattr(sys, "argv", ["style_sampler"] + args)
  63. sampler_module.main()
  64. run_cli(["--project-root", root, "stats"])
  65. run_cli(["--project-root", root, "list", "--limit", "5"])
  66. run_cli(
  67. [
  68. "--project-root",
  69. root,
  70. "extract",
  71. "--chapter",
  72. "1",
  73. "--score",
  74. "90",
  75. "--scenes",
  76. json.dumps(
  77. [
  78. {
  79. "index": 1,
  80. "summary": "战斗场景",
  81. "content": "战斗" + "a" * 300,
  82. }
  83. ],
  84. ensure_ascii=False,
  85. ),
  86. ]
  87. )
  88. run_cli(["--project-root", root, "list", "--type", "战斗", "--limit", "5"])
  89. run_cli(["--project-root", root, "select", "--outline", "本章有一场战斗", "--max", "2"])
  90. capsys.readouterr()