test_style_sampler_cli.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. StyleSampler extra tests + CLI
  5. """
  6. import sys
  7. import json
  8. import pytest
  9. import data_modules.style_sampler as sampler_module
  10. from data_modules.style_sampler import StyleSampler, StyleSample, SceneType
  11. from data_modules.config import DataModulesConfig
  12. @pytest.fixture
  13. def temp_project(tmp_path):
  14. cfg = DataModulesConfig.from_project_root(tmp_path)
  15. cfg.ensure_dirs()
  16. return cfg
  17. def test_style_sampler_more(temp_project):
  18. sampler = StyleSampler(temp_project)
  19. sample = StyleSample(
  20. id="ch1_s1",
  21. chapter=1,
  22. scene_type=SceneType.BATTLE.value,
  23. content="战斗描写很精彩",
  24. score=0.9,
  25. tags=["战斗"],
  26. )
  27. assert sampler.add_sample(sample) is True
  28. assert sampler.add_sample(sample) is False
  29. best = sampler.get_best_samples(limit=5)
  30. assert len(best) == 1
  31. stats = sampler.get_stats()
  32. assert stats["total"] == 1
  33. # scene type inference
  34. assert sampler._infer_scene_types("一场战斗") == [SceneType.BATTLE.value]
  35. assert sampler._infer_scene_types("对话和谈话") == [SceneType.DIALOGUE.value]
  36. assert sampler._infer_scene_types("心理情感描写") == [SceneType.EMOTION.value]
  37. # classify and tags
  38. scene_type = sampler._classify_scene_type({"summary": "紧张", "content": ""})
  39. assert scene_type == SceneType.TENSION.value
  40. tags = sampler._extract_tags("战斗 修炼 对话 描写")
  41. assert "战斗" in tags
  42. def test_style_sampler_cli(temp_project, monkeypatch, capsys):
  43. root = str(temp_project.project_root)
  44. def run_cli(args):
  45. monkeypatch.setattr(sys, "argv", ["style_sampler"] + args)
  46. sampler_module.main()
  47. run_cli(["--project-root", root, "stats"])
  48. run_cli(["--project-root", root, "list", "--limit", "5"])
  49. run_cli(
  50. [
  51. "--project-root",
  52. root,
  53. "extract",
  54. "--chapter",
  55. "1",
  56. "--score",
  57. "90",
  58. "--scenes",
  59. json.dumps(
  60. [
  61. {
  62. "index": 1,
  63. "summary": "战斗场景",
  64. "content": "战斗" + "a" * 300,
  65. }
  66. ],
  67. ensure_ascii=False,
  68. ),
  69. ]
  70. )
  71. run_cli(["--project-root", root, "list", "--type", "战斗", "--limit", "5"])
  72. run_cli(["--project-root", root, "select", "--outline", "本章有一场战斗", "--max", "2"])
  73. capsys.readouterr()