active_task.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. #!/usr/bin/env python3
"""Session-scoped active task resolution.

The user-facing concept is a single "active task". Trellis stores that pointer
per AI session/window under `.trellis/.runtime/sessions/`. Without a stable
session key, the active task can only be inferred when exactly one session
file exists (the single-session fallback); otherwise there is no active task.
"""
  7. from __future__ import annotations
  8. import hashlib
  9. import json
  10. import os
  11. import re
  12. import sys
  13. import time
  14. from dataclasses import dataclass
  15. from datetime import datetime, timezone
  16. from pathlib import Path
  17. from typing import Any
  18. DIR_WORKFLOW = ".trellis"
  19. DIR_TASKS = "tasks"
  20. DIR_RUNTIME = ".runtime"
  21. DIR_SESSIONS = "sessions"
  22. DIR_CURSOR_SHELL = "cursor-shell"
  23. CURSOR_SHELL_TICKET_TTL_SECONDS = 30
  24. TASK_SESSION_COMMANDS = {"start", "current", "finish"}
  25. _SESSION_KEYS = ("session_id", "sessionId", "sessionID")
  26. _CONVERSATION_KEYS = ("conversation_id", "conversationId", "conversationID")
  27. _TRANSCRIPT_KEYS = ("transcript_path", "transcriptPath", "transcript")
  28. _NESTED_KEYS = ("input", "properties", "event", "hook_input", "hookInput")
  29. _KNOWN_PLATFORMS = {
  30. "claude",
  31. "codex",
  32. "cursor",
  33. "opencode",
  34. "gemini",
  35. "droid",
  36. "qoder",
  37. "codebuddy",
  38. "kiro",
  39. "copilot",
  40. "pi",
  41. "trae",
  42. }
  43. _ENV_SESSION_KEYS: tuple[tuple[str, tuple[str, ...]], ...] = (
  44. ("claude", ("CLAUDE_SESSION_ID", "CLAUDE_CODE_SESSION_ID")),
  45. ("codex", ("CODEX_SESSION_ID", "CODEX_THREAD_ID")),
  46. ("cursor", ("CURSOR_SESSION_ID",)),
  47. ("opencode", ("OPENCODE_SESSION_ID", "OPENCODE_SESSIONID", "OPENCODE_RUN_ID")),
  48. ("gemini", ("GEMINI_SESSION_ID",)),
  49. ("droid", ("FACTORY_SESSION_ID", "DROID_SESSION_ID")),
  50. ("qoder", ("QODER_SESSION_ID",)),
  51. ("codebuddy", ("CODEBUDDY_SESSION_ID",)),
  52. ("kiro", ("KIRO_SESSION_ID",)),
  53. ("copilot", ("COPILOT_SESSION_ID", "COPILOT_SESSIONID")),
  54. ("pi", ("PI_SESSION_ID", "PI_SESSIONID")),
  55. ("trae", ("TRAE_SESSION_ID",)),
  56. )
  57. _ENV_CONVERSATION_KEYS: tuple[tuple[str, tuple[str, ...]], ...] = (
  58. ("cursor", ("CURSOR_CONVERSATION_ID", "CURSOR_CONVERSATIONID")),
  59. )
  60. _ENV_TRANSCRIPT_KEYS: tuple[tuple[str, tuple[str, ...]], ...] = (
  61. ("claude", ("CLAUDE_TRANSCRIPT_PATH",)),
  62. ("codex", ("CODEX_TRANSCRIPT_PATH",)),
  63. ("cursor", ("CURSOR_TRANSCRIPT_PATH",)),
  64. ("gemini", ("GEMINI_TRANSCRIPT_PATH",)),
  65. ("droid", ("FACTORY_TRANSCRIPT_PATH", "DROID_TRANSCRIPT_PATH")),
  66. ("qoder", ("QODER_TRANSCRIPT_PATH",)),
  67. ("codebuddy", ("CODEBUDDY_TRANSCRIPT_PATH",)),
  68. )
  69. _ENV_PLATFORM_ALIASES = {
  70. "claude-code": "claude",
  71. "factory": "droid",
  72. "factory-ai": "droid",
  73. "github-copilot": "copilot",
  74. }
  75. @dataclass(frozen=True)
  76. class ActiveTask:
  77. """Resolved active task state."""
  78. task_path: str | None
  79. source_type: str
  80. context_key: str | None = None
  81. stale: bool = False
  82. @property
  83. def source(self) -> str:
  84. """Human-readable source label."""
  85. if self.source_type == "session" and self.context_key:
  86. return f"session:{self.context_key}"
  87. if self.source_type == "session-fallback" and self.context_key:
  88. return f"session-fallback:{self.context_key}"
  89. return self.source_type
  90. def normalize_task_ref(task_ref: str) -> str:
  91. """Normalize a task ref for stable storage and comparison."""
  92. normalized = task_ref.strip()
  93. if not normalized:
  94. return ""
  95. path_obj = Path(normalized)
  96. if path_obj.is_absolute():
  97. return str(path_obj)
  98. normalized = normalized.replace("\\", "/")
  99. while normalized.startswith("./"):
  100. normalized = normalized[2:]
  101. if normalized.startswith(f"{DIR_TASKS}/"):
  102. return f"{DIR_WORKFLOW}/{normalized}"
  103. return normalized
  104. def resolve_task_ref(task_ref: str, repo_root: Path) -> Path | None:
  105. """Resolve a task ref to an absolute task directory."""
  106. normalized = normalize_task_ref(task_ref)
  107. if not normalized:
  108. return None
  109. path_obj = Path(normalized)
  110. if path_obj.is_absolute():
  111. return path_obj
  112. if normalized.startswith(f"{DIR_WORKFLOW}/"):
  113. return repo_root / path_obj
  114. return repo_root / DIR_WORKFLOW / DIR_TASKS / path_obj
  115. def _runtime_sessions_dir(repo_root: Path) -> Path:
  116. return repo_root / DIR_WORKFLOW / DIR_RUNTIME / DIR_SESSIONS
  117. def _sanitize_key(raw: str) -> str:
  118. safe = re.sub(r"[^A-Za-z0-9._-]+", "_", raw.strip())
  119. safe = safe.strip("._-")
  120. return safe[:160] if safe else ""
  121. def _hash_value(raw: str) -> str:
  122. return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:24]
  123. def _as_dict(value: Any) -> dict[str, Any] | None:
  124. return value if isinstance(value, dict) else None
  125. def _string_value(value: Any) -> str | None:
  126. if isinstance(value, str):
  127. stripped = value.strip()
  128. return stripped or None
  129. return None
  130. def _lookup_string(data: dict[str, Any], keys: tuple[str, ...]) -> str | None:
  131. for key in keys:
  132. value = _string_value(data.get(key))
  133. if value:
  134. return value
  135. for nested_key in _NESTED_KEYS:
  136. nested = _as_dict(data.get(nested_key))
  137. if not nested:
  138. continue
  139. value = _lookup_string(nested, keys)
  140. if value:
  141. return value
  142. return None
  143. def _detect_platform(platform_input: dict[str, Any] | None, platform: str | None) -> str:
  144. if platform:
  145. return _sanitize_key(platform) or "session"
  146. if platform_input:
  147. for key in ("_trellis_platform", "trellis_platform", "platform", "source"):
  148. value = _string_value(platform_input.get(key))
  149. if value:
  150. return _sanitize_key(value) or "session"
  151. if _string_value(platform_input.get("cursor_version")):
  152. return "cursor"
  153. return "session"
  154. def _context_key(platform_name: str, kind: str, value: str) -> str:
  155. if kind == "transcript":
  156. return f"{platform_name}_transcript_{_hash_value(value)}"
  157. safe_value = _sanitize_key(value)
  158. if safe_value:
  159. return f"{platform_name}_{safe_value}"
  160. return f"{platform_name}_{_hash_value(value)}"
  161. def _iter_env_keys(
  162. env_keys: tuple[tuple[str, tuple[str, ...]], ...],
  163. platform_name: str | None,
  164. ) -> tuple[tuple[str, tuple[str, ...]], ...]:
  165. if not platform_name:
  166. return env_keys
  167. matched = tuple((name, keys) for name, keys in env_keys if name == platform_name)
  168. return matched
  169. def _env_platform_name(platform_name: str | None) -> str | None:
  170. if not platform_name or platform_name == "session":
  171. return None
  172. return _ENV_PLATFORM_ALIASES.get(platform_name, platform_name)
  173. def _lookup_env_context_key(platform_name: str | None) -> str | None:
  174. """Resolve a context key from platform-provided environment variables.
  175. Hooks pass `TRELLIS_CONTEXT_ID` to subprocesses they launch, but an AI-run
  176. shell command can only see session identity if the host platform exports it
  177. in the command environment. These names are best-effort adapters; if none
  178. are present, there is no session-scoped active task.
  179. """
  180. env_platform_name = _env_platform_name(platform_name)
  181. for name, keys in _iter_env_keys(_ENV_SESSION_KEYS, env_platform_name):
  182. for key in keys:
  183. value = _string_value(os.environ.get(key))
  184. if value:
  185. return _context_key(name, "session", value)
  186. for name, keys in _iter_env_keys(_ENV_CONVERSATION_KEYS, env_platform_name):
  187. for key in keys:
  188. value = _string_value(os.environ.get(key))
  189. if value:
  190. return _context_key(name, "conversation", value)
  191. for name, keys in _iter_env_keys(_ENV_TRANSCRIPT_KEYS, env_platform_name):
  192. for key in keys:
  193. value = _string_value(os.environ.get(key))
  194. if value:
  195. return _context_key(name, "transcript", value)
  196. return None
  197. def _find_repo_root_from_cwd() -> Path | None:
  198. current = Path.cwd().resolve()
  199. while True:
  200. if (current / DIR_WORKFLOW).is_dir():
  201. return current
  202. if current == current.parent:
  203. return None
  204. current = current.parent
  205. def _cursor_shell_ticket_dir(repo_root: Path) -> Path:
  206. return repo_root / DIR_WORKFLOW / DIR_RUNTIME / DIR_CURSOR_SHELL
  207. def _remove_file(path: Path) -> bool:
  208. try:
  209. path.unlink()
  210. return True
  211. except OSError:
  212. return False
  213. def _task_refs_match(left: str | None, right: str | None, repo_root: Path) -> bool:
  214. if not left or not right:
  215. return False
  216. left_path = resolve_task_ref(left, repo_root)
  217. right_path = resolve_task_ref(right, repo_root)
  218. if left_path is not None and right_path is not None:
  219. return left_path == right_path
  220. return normalize_task_ref(left) == normalize_task_ref(right)
  221. def _pending_ticket_matches_args(ticket: dict[str, Any], repo_root: Path) -> bool:
  222. if Path(sys.argv[0]).name != "task.py":
  223. return False
  224. args = tuple(sys.argv[1:])
  225. if not args:
  226. return False
  227. command_name = args[0]
  228. if command_name not in TASK_SESSION_COMMANDS:
  229. return False
  230. subcommands = ticket.get("subcommands")
  231. if not isinstance(subcommands, list):
  232. return False
  233. for subcommand in subcommands:
  234. if not isinstance(subcommand, dict):
  235. continue
  236. if _string_value(subcommand.get("name")) != command_name:
  237. continue
  238. if command_name != "start":
  239. return True
  240. task_ref = args[1] if len(args) > 1 else None
  241. if _task_refs_match(_string_value(subcommand.get("task_ref")), task_ref, repo_root):
  242. return True
  243. return False
  244. def _ticket_is_fresh(ticket: dict[str, Any], ticket_path: Path, now: float) -> bool:
  245. expires_at = ticket.get("expires_at_epoch")
  246. if isinstance(expires_at, (int, float)) and expires_at < now:
  247. _remove_file(ticket_path)
  248. return False
  249. created_at = ticket.get("created_at_epoch")
  250. if isinstance(created_at, (int, float)):
  251. if now - created_at <= CURSOR_SHELL_TICKET_TTL_SECONDS:
  252. return True
  253. _remove_file(ticket_path)
  254. return False
  255. return True
  256. def _ticket_cwd_matches_repo(ticket: dict[str, Any], repo_root: Path) -> bool:
  257. cwd = _string_value(ticket.get("cwd"))
  258. if not cwd:
  259. return True
  260. try:
  261. Path(cwd).resolve().relative_to(repo_root)
  262. except ValueError:
  263. return False
  264. return True
  265. def _matching_cursor_ticket_context_key(
  266. ticket_path: Path,
  267. repo_root: Path,
  268. now: float,
  269. ) -> str | None:
  270. ticket = _read_json(ticket_path)
  271. if ticket is None or ticket.get("platform") != "cursor":
  272. return None
  273. if not _ticket_is_fresh(ticket, ticket_path, now):
  274. return None
  275. if not _ticket_cwd_matches_repo(ticket, repo_root):
  276. return None
  277. if not _pending_ticket_matches_args(ticket, repo_root):
  278. return None
  279. return _string_value(ticket.get("context_key"))
  280. def _lookup_cursor_shell_ticket_context_key() -> str | None:
  281. """Resolve Cursor conversation identity from a short-lived shell ticket.
  282. Cursor exposes `conversation_id` to `beforeShellExecution`, but does not
  283. export it into the shell command environment. The Cursor hook writes a
  284. short-lived ticket just before `task.py` runs. We accept a ticket only when
  285. the current `task.py` subcommand matches and exactly one fresh context key
  286. matches, which avoids cross-window pointer contamination.
  287. """
  288. repo_root = _find_repo_root_from_cwd()
  289. if repo_root is None:
  290. return None
  291. ticket_dir = _cursor_shell_ticket_dir(repo_root)
  292. if not ticket_dir.is_dir():
  293. return None
  294. now = time.time()
  295. candidates: set[str] = set()
  296. for ticket_path in ticket_dir.glob("*.json"):
  297. context_key = _matching_cursor_ticket_context_key(ticket_path, repo_root, now)
  298. if context_key:
  299. candidates.add(context_key)
  300. if len(candidates) == 1:
  301. return next(iter(candidates))
  302. return None
  303. def resolve_context_key(
  304. platform_input: dict[str, Any] | None = None,
  305. platform: str | None = None,
  306. ) -> str | None:
  307. """Resolve a stable session/window context key, if one is available.
  308. `TRELLIS_CONTEXT_ID` is an explicit context-key override used by CLI
  309. scripts and subprocesses. It does not store the task itself.
  310. """
  311. override = _string_value(os.environ.get("TRELLIS_CONTEXT_ID"))
  312. if override:
  313. return _sanitize_key(override) or _hash_value(override)
  314. data = _as_dict(platform_input)
  315. platform_name = _detect_platform(data, platform) if data or platform else None
  316. if data:
  317. session_id = _lookup_string(data, _SESSION_KEYS)
  318. if session_id:
  319. return _context_key(platform_name or "session", "session", session_id)
  320. conversation_id = _lookup_string(data, _CONVERSATION_KEYS)
  321. if conversation_id:
  322. return _context_key(platform_name or "session", "conversation", conversation_id)
  323. transcript_path = _lookup_string(data, _TRANSCRIPT_KEYS)
  324. if transcript_path:
  325. return _context_key(platform_name or "session", "transcript", transcript_path)
  326. env_context_key = _lookup_env_context_key(platform_name)
  327. if env_context_key:
  328. return env_context_key
  329. if platform_name in (None, "session", "cursor"):
  330. return _lookup_cursor_shell_ticket_context_key()
  331. return None
  332. def _read_json(path: Path) -> dict[str, Any] | None:
  333. try:
  334. data = json.loads(path.read_text(encoding="utf-8"))
  335. except (FileNotFoundError, json.JSONDecodeError, OSError):
  336. return None
  337. return data if isinstance(data, dict) else None
  338. def _write_json(path: Path, data: dict[str, Any]) -> bool:
  339. try:
  340. path.parent.mkdir(parents=True, exist_ok=True)
  341. path.write_text(
  342. json.dumps(data, indent=2, ensure_ascii=False) + "\n",
  343. encoding="utf-8",
  344. )
  345. return True
  346. except OSError:
  347. return False
  348. def _canonical_task_ref(task_path: str, repo_root: Path) -> str | None:
  349. normalized = normalize_task_ref(task_path)
  350. if not normalized:
  351. return None
  352. full_path = resolve_task_ref(normalized, repo_root)
  353. if full_path is None or not full_path.is_dir():
  354. return None
  355. try:
  356. return full_path.relative_to(repo_root).as_posix()
  357. except ValueError:
  358. return str(full_path)
  359. def _active_from_ref(
  360. task_ref: str | None,
  361. repo_root: Path,
  362. source_type: str,
  363. context_key: str | None = None,
  364. ) -> ActiveTask | None:
  365. if not task_ref:
  366. return None
  367. resolved = resolve_task_ref(task_ref, repo_root)
  368. stale = resolved is None or not resolved.is_dir()
  369. return ActiveTask(task_ref, source_type, context_key, stale)
  370. def _context_path(repo_root: Path, context_key: str) -> Path:
  371. return _runtime_sessions_dir(repo_root) / f"{context_key}.json"
  372. def resolve_active_task(
  373. repo_root: Path,
  374. platform_input: dict[str, Any] | None = None,
  375. platform: str | None = None,
  376. ) -> ActiveTask:
  377. """Resolve the active task from session runtime state only.
  378. A stale session task is returned as stale. Missing context identity or a
  379. missing/empty session context falls back to single-session inference: if
  380. exactly one session file exists in the runtime, return its task with
  381. source_type="session-fallback" — covers class-2 platform sub-agents (codex,
  382. copilot, gemini, qoder) that don't inherit the parent's session id. ≥2
  383. files or 0 files yield ActiveTask(None) — refuses to guess across windows.
  384. """
  385. context_key = resolve_context_key(platform_input, platform)
  386. if context_key:
  387. context = _read_json(_context_path(repo_root, context_key)) or {}
  388. task_ref = _string_value(context.get("current_task"))
  389. active = _active_from_ref(task_ref, repo_root, "session", context_key)
  390. if active:
  391. return active
  392. fallback = _resolve_single_session_fallback(repo_root)
  393. if fallback is not None:
  394. return fallback
  395. return ActiveTask(None, "none", context_key)
  396. def _resolve_single_session_fallback(repo_root: Path) -> ActiveTask | None:
  397. """Return the task pointed at by the sole session file, if exactly one exists.
  398. Used when context-key resolution fails (typical for class-2 platform
  399. sub-agents). Returns None if 0 or ≥2 session files are present — refuses
  400. to pick across windows so 04-21's multi-session isolation contract holds.
  401. """
  402. sessions_dir = _runtime_sessions_dir(repo_root)
  403. if not sessions_dir.is_dir():
  404. return None
  405. session_files = sorted(sessions_dir.glob("*.json"))
  406. if len(session_files) != 1:
  407. return None
  408. session_file = session_files[0]
  409. context = _read_json(session_file) or {}
  410. task_ref = _string_value(context.get("current_task"))
  411. if not task_ref:
  412. return None
  413. fallback_key = session_file.stem
  414. return _active_from_ref(task_ref, repo_root, "session-fallback", fallback_key)
  415. def _utc_now() -> str:
  416. return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
  417. def _context_metadata(
  418. platform_input: dict[str, Any] | None,
  419. platform: str | None,
  420. context_key: str | None = None,
  421. ) -> dict[str, Any]:
  422. data = _as_dict(platform_input) or {}
  423. platform_name = _detect_platform(data, platform)
  424. if platform_name == "session" and context_key:
  425. prefix = context_key.split("_", 1)[0]
  426. if prefix in _KNOWN_PLATFORMS:
  427. platform_name = prefix
  428. metadata: dict[str, Any] = {
  429. "platform": platform_name,
  430. "last_seen_at": _utc_now(),
  431. }
  432. for key in (*_SESSION_KEYS, *_CONVERSATION_KEYS, *_TRANSCRIPT_KEYS):
  433. value = _lookup_string(data, (key,))
  434. if value:
  435. metadata[key] = value
  436. return metadata
  437. def set_active_task(
  438. task_path: str,
  439. repo_root: Path,
  440. platform_input: dict[str, Any] | None = None,
  441. platform: str | None = None,
  442. ) -> ActiveTask | None:
  443. """Set the active task in session scope.
  444. Returns None when no context key is available; callers should surface a
  445. user-facing error that explains how to provide session identity.
  446. """
  447. canonical = _canonical_task_ref(task_path, repo_root)
  448. if canonical is None:
  449. return None
  450. context_key = resolve_context_key(platform_input, platform)
  451. if not context_key:
  452. return None
  453. context_path = _context_path(repo_root, context_key)
  454. context = _read_json(context_path) or {}
  455. context.update(_context_metadata(platform_input, platform, context_key))
  456. context["current_task"] = canonical
  457. context.setdefault("current_run", None)
  458. if not _write_json(context_path, context):
  459. return None
  460. return ActiveTask(canonical, "session", context_key)
  461. def clear_active_task(
  462. repo_root: Path,
  463. platform_input: dict[str, Any] | None = None,
  464. platform: str | None = None,
  465. ) -> ActiveTask:
  466. """Clear the active task by deleting the current session context file."""
  467. context_key = resolve_context_key(platform_input, platform)
  468. if not context_key:
  469. return ActiveTask(None, "none")
  470. previous = resolve_active_task(repo_root, platform_input, platform)
  471. context_path = _context_path(repo_root, context_key)
  472. if context_path.is_file():
  473. _remove_file(context_path)
  474. return previous
  475. def clear_task_from_sessions(task_path: str, repo_root: Path) -> int:
  476. """Delete all session runtime files that point at a task."""
  477. target = _canonical_task_ref(task_path, repo_root) or normalize_task_ref(task_path)
  478. if not target:
  479. return 0
  480. cleared = 0
  481. sessions_dir = _runtime_sessions_dir(repo_root)
  482. if not sessions_dir.is_dir():
  483. return cleared
  484. for session_path in sessions_dir.glob("*.json"):
  485. context = _read_json(session_path) or {}
  486. current = _string_value(context.get("current_task"))
  487. if not current:
  488. continue
  489. current_ref = _canonical_task_ref(current, repo_root) or normalize_task_ref(current)
  490. if current_ref != target:
  491. continue
  492. if session_path.is_file() and _remove_file(session_path):
  493. cleared += 1
  494. return cleared
  495. def get_current_task_source(
  496. repo_root: Path,
  497. platform_input: dict[str, Any] | None = None,
  498. platform: str | None = None,
  499. ) -> tuple[str, str | None, str | None]:
  500. """Return (`source_type`, `context_key`, `task_path`) for compatibility."""
  501. active = resolve_active_task(repo_root, platform_input, platform)
  502. return active.source_type, active.context_key, active.task_path