|
|
@@ -56,6 +56,25 @@ def test_style_sampler_more(temp_project):
|
|
|
assert "战斗" in tags
|
|
|
|
|
|
|
|
|
+def test_style_sampler_ignores_corrupt_tag_json(temp_project):
|
|
|
+ sampler = StyleSampler(temp_project)
|
|
|
+ with sampler._get_conn() as conn:
|
|
|
+ conn.execute(
|
|
|
+ """
|
|
|
+ INSERT INTO samples
|
|
|
+ (id, chapter, scene_type, content, score, tags, created_at)
|
|
|
+ VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
|
|
|
+ """,
|
|
|
+ ("bad-tags", 1, SceneType.BATTLE.value, "战斗描写" * 50, 0.8, "[bad-json"),
|
|
|
+ )
|
|
|
+ conn.commit()
|
|
|
+
|
|
|
+ samples = sampler.get_best_samples(limit=5)
|
|
|
+
|
|
|
+ assert samples[0].id == "bad-tags"
|
|
|
+ assert samples[0].tags == []
|
|
|
+
|
|
|
+
|
|
|
def test_style_sampler_cli(temp_project, monkeypatch, capsys):
|
|
|
root = str(temp_project.project_root)
|
|
|
|