chapter_commit_schema.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from __future__ import annotations
  4. import hashlib
  5. import json
  6. from typing import Any, ClassVar
  7. from pydantic import (
  8. BaseModel,
  9. ConfigDict,
  10. Field,
  11. ValidationInfo,
  12. field_validator,
  13. model_validator,
  14. )
  15. from .story_event_schema import StoryEvent
  16. EXTRACTION_CORE_FIELDS = ("accepted_events", "state_deltas", "entity_deltas")
  17. EXTRACTION_LIST_FIELDS = (
  18. "accepted_events",
  19. "state_deltas",
  20. "entity_deltas",
  21. "entities_appeared",
  22. "scenes",
  23. )
  24. FULFILLMENT_LIST_FIELDS = (
  25. "planned_nodes",
  26. "covered_nodes",
  27. "missed_nodes",
  28. "extra_nodes",
  29. )
  30. EVENT_TYPE_ALIASES = {
  31. "character_state": "character_state_changed",
  32. "character_state_change": "character_state_changed",
  33. "state_changed": "character_state_changed",
  34. "relationship_change": "relationship_changed",
  35. "relation_changed": "relationship_changed",
  36. "world_rule": "world_rule_revealed",
  37. "rule_revealed": "world_rule_revealed",
  38. "rule_broken": "world_rule_broken",
  39. "breakthrough": "power_breakthrough",
  40. "power_up": "power_breakthrough",
  41. "artifact": "artifact_obtained",
  42. "item_obtained": "artifact_obtained",
  43. "promise": "promise_created",
  44. "promise_resolved": "promise_paid_off",
  45. "promise_fulfilled": "promise_paid_off",
  46. "mystery_introduction": "open_loop_created",
  47. "mystery_introduced": "open_loop_created",
  48. "unresolved_thread": "open_loop_created",
  49. "scene_open": "open_loop_created",
  50. "open_loop": "open_loop_created",
  51. "loop_closed": "open_loop_closed",
  52. }
  53. class CommitArtifactModel(BaseModel):
  54. model_config = ConfigDict(extra="allow")
  55. artifact_name: ClassVar[str]
  56. wrapper_key: ClassVar[str | None] = None
  57. required_top_level_fields: ClassVar[tuple[str, ...]] = ()
  58. @model_validator(mode="before")
  59. @classmethod
  60. def validate_top_level_shape(cls, value: Any) -> Any:
  61. if not isinstance(value, dict):
  62. raise ValueError(f"{cls.artifact_name} must be a JSON object")
  63. wrapper_key = cls.wrapper_key
  64. if wrapper_key and wrapper_key in value:
  65. if cls.artifact_name == "extraction_result":
  66. raise ValueError(
  67. "extraction_result must expose accepted_events/state_deltas/entity_deltas "
  68. "as top-level fields, not nested under extraction"
  69. )
  70. raise ValueError(
  71. f"{cls.artifact_name} fields must be top-level, not nested under {wrapper_key}"
  72. )
  73. missing = [
  74. field for field in cls.required_top_level_fields if field not in value
  75. ]
  76. if missing:
  77. raise ValueError(
  78. f"{cls.artifact_name} missing required top-level fields: "
  79. + ", ".join(missing)
  80. )
  81. return value
  82. def _ensure_list(artifact_name: str, field_name: str, value: Any) -> Any:
  83. if not isinstance(value, list):
  84. raise ValueError(f"{artifact_name}.{field_name} must be a list")
  85. return value
  86. def _ensure_object_list(artifact_name: str, field_name: str, value: Any) -> Any:
  87. _ensure_list(artifact_name, field_name, value)
  88. for index, item in enumerate(value):
  89. if not isinstance(item, dict):
  90. raise ValueError(f"{artifact_name}.{field_name}[{index}] must be a JSON object")
  91. return value
  92. class ReviewResult(CommitArtifactModel):
  93. artifact_name: ClassVar[str] = "review_result"
  94. wrapper_key: ClassVar[str | None] = "review"
  95. required_top_level_fields: ClassVar[tuple[str, ...]] = ("blocking_count",)
  96. blocking_count: int = Field(ge=0, strict=True)
  97. class FulfillmentResult(CommitArtifactModel):
  98. artifact_name: ClassVar[str] = "fulfillment_result"
  99. wrapper_key: ClassVar[str | None] = "fulfillment"
  100. required_top_level_fields: ClassVar[tuple[str, ...]] = FULFILLMENT_LIST_FIELDS
  101. planned_nodes: list[Any]
  102. covered_nodes: list[Any]
  103. missed_nodes: list[Any]
  104. extra_nodes: list[Any]
  105. @field_validator(*FULFILLMENT_LIST_FIELDS, mode="before")
  106. @classmethod
  107. def validate_list_fields(cls, value: Any, info: ValidationInfo) -> Any:
  108. return _ensure_list(cls.artifact_name, info.field_name, value)
  109. class DisambiguationResult(CommitArtifactModel):
  110. artifact_name: ClassVar[str] = "disambiguation_result"
  111. wrapper_key: ClassVar[str | None] = "disambiguation"
  112. required_top_level_fields: ClassVar[tuple[str, ...]] = ("pending",)
  113. pending: list[Any]
  114. @field_validator("pending", mode="before")
  115. @classmethod
  116. def validate_pending(cls, value: Any, info: ValidationInfo) -> Any:
  117. return _ensure_list(cls.artifact_name, info.field_name, value)
  118. class ExtractionResult(CommitArtifactModel):
  119. artifact_name: ClassVar[str] = "extraction_result"
  120. wrapper_key: ClassVar[str | None] = "extraction"
  121. required_top_level_fields: ClassVar[tuple[str, ...]] = EXTRACTION_CORE_FIELDS
  122. accepted_events: list[dict[str, Any]]
  123. state_deltas: list[dict[str, Any]]
  124. entity_deltas: list[dict[str, Any]]
  125. entities_appeared: list[dict[str, Any]] = Field(default_factory=list)
  126. scenes: list[dict[str, Any]] = Field(default_factory=list)
  127. chapter_meta: Any = Field(default_factory=dict)
  128. dominant_strand: Any = ""
  129. summary_text: str = ""
  130. @field_validator(*EXTRACTION_LIST_FIELDS, mode="before")
  131. @classmethod
  132. def validate_object_list_fields(cls, value: Any, info: ValidationInfo) -> Any:
  133. return _ensure_object_list(cls.artifact_name, info.field_name, value)
  134. @field_validator("summary_text", mode="before")
  135. @classmethod
  136. def validate_summary_text(cls, value: Any) -> Any:
  137. if not isinstance(value, str):
  138. raise ValueError("extraction_result.summary_text must be a string")
  139. return value
  140. class AcceptedEventInput(BaseModel):
  141. model_config = ConfigDict(extra="allow")
  142. event_id: str
  143. chapter: int = Field(ge=1)
  144. event_type: str
  145. subject: str
  146. payload: dict[str, Any] = Field(default_factory=dict)
  147. @model_validator(mode="before")
  148. @classmethod
  149. def normalize_aliases(cls, value: Any, info: ValidationInfo) -> Any:
  150. if not isinstance(value, dict):
  151. index = _event_context_index(info)
  152. raise ValueError(f"accepted_events[{index}] must be a JSON object")
  153. payload = dict(value)
  154. context = info.context or {}
  155. chapter = int(payload.get("chapter") or context.get("chapter") or 0)
  156. payload["chapter"] = chapter
  157. event_type = str(payload.get("event_type") or payload.get("type") or "").strip()
  158. if event_type:
  159. normalized_type = event_type.lower().replace("-", "_")
  160. payload["event_type"] = EVENT_TYPE_ALIASES.get(normalized_type, normalized_type)
  161. subject = _event_subject(payload)
  162. if not subject:
  163. index = _event_context_index(info)
  164. raise ValueError(
  165. f"accepted_events[{index}].subject must be a non-empty string"
  166. )
  167. payload["subject"] = subject
  168. if not str(payload.get("event_id") or "").strip():
  169. index = _event_context_index(info)
  170. payload["event_id"] = _generated_event_id(chapter, index + 1, payload)
  171. return payload
  172. class AcceptedEventsInput(BaseModel):
  173. accepted_events: list[Any]
  174. @field_validator("accepted_events", mode="before")
  175. @classmethod
  176. def validate_events_list(cls, value: Any) -> Any:
  177. if not isinstance(value, list):
  178. raise ValueError("accepted_events must be a list")
  179. return value
  180. def normalize(self, chapter: int) -> list[dict[str, Any]]:
  181. normalized: list[dict[str, Any]] = []
  182. for index, event in enumerate(self.accepted_events):
  183. if not isinstance(event, dict):
  184. raise ValueError(f"accepted_events[{index}] must be a JSON object")
  185. payload = AcceptedEventInput.model_validate(
  186. event,
  187. context={"chapter": chapter, "index": index},
  188. ).model_dump()
  189. normalized.append(StoryEvent.model_validate(payload).model_dump())
  190. return normalized
  191. def normalize_accepted_events(chapter: int, events: Any) -> list[dict[str, Any]]:
  192. accepted_events = AcceptedEventsInput.model_validate({"accepted_events": events})
  193. return accepted_events.normalize(chapter)
  194. def _event_context_index(info: ValidationInfo) -> int:
  195. context = info.context or {}
  196. return int(context.get("index") or 0)
  197. def _event_subject(payload: dict[str, Any]) -> str:
  198. for key in ("subject", "entity_id", "from_entity", "to_entity"):
  199. value = payload.get(key)
  200. if isinstance(value, str) and value.strip():
  201. return value.strip()
  202. characters = payload.get("characters")
  203. if isinstance(characters, str) and characters.strip():
  204. return characters.strip()
  205. if isinstance(characters, list):
  206. for character in characters:
  207. if isinstance(character, str) and character.strip():
  208. return character.strip()
  209. event_payload = payload.get("payload") or {}
  210. if isinstance(event_payload, dict):
  211. for key in ("subject", "entity_id", "owner", "holder", "artifact_id", "name"):
  212. value = event_payload.get(key)
  213. if isinstance(value, str) and value.strip():
  214. return value.strip()
  215. return ""
  216. def _generated_event_id(chapter: int, index: int, payload: dict[str, Any]) -> str:
  217. stable_payload = {
  218. key: value
  219. for key, value in payload.items()
  220. if key not in {"event_id", "chapter"}
  221. }
  222. raw = json.dumps(stable_payload, ensure_ascii=False, sort_keys=True)
  223. digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:10]
  224. return f"evt-ch{chapter:03d}-{index:03d}-{digest}"