| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626 |
- #!/usr/bin/env python3
- """Session-scoped active task resolution.
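The user-facing concept is a single "active task". Trellis stores that pointer
per AI session/window under `.trellis/.runtime/sessions/`. Setting it requires a
stable session key; when reading without one, a task is only inferred if exactly
one session file exists (single-session fallback), otherwise there is none.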
- """
- from __future__ import annotations
- import hashlib
- import json
- import os
- import re
- import sys
- import time
- from dataclasses import dataclass
- from datetime import datetime, timezone
- from pathlib import Path
- from typing import Any
- DIR_WORKFLOW = ".trellis"
- DIR_TASKS = "tasks"
- DIR_RUNTIME = ".runtime"
- DIR_SESSIONS = "sessions"
- DIR_CURSOR_SHELL = "cursor-shell"
- CURSOR_SHELL_TICKET_TTL_SECONDS = 30
- TASK_SESSION_COMMANDS = {"start", "current", "finish"}
- _SESSION_KEYS = ("session_id", "sessionId", "sessionID")
- _CONVERSATION_KEYS = ("conversation_id", "conversationId", "conversationID")
- _TRANSCRIPT_KEYS = ("transcript_path", "transcriptPath", "transcript")
- _NESTED_KEYS = ("input", "properties", "event", "hook_input", "hookInput")
- _KNOWN_PLATFORMS = {
- "claude",
- "codex",
- "cursor",
- "opencode",
- "gemini",
- "droid",
- "qoder",
- "codebuddy",
- "kiro",
- "copilot",
- "pi",
- }
- _ENV_SESSION_KEYS: tuple[tuple[str, tuple[str, ...]], ...] = (
- ("claude", ("CLAUDE_SESSION_ID", "CLAUDE_CODE_SESSION_ID")),
- ("codex", ("CODEX_SESSION_ID", "CODEX_THREAD_ID")),
- ("cursor", ("CURSOR_SESSION_ID",)),
- ("opencode", ("OPENCODE_SESSION_ID", "OPENCODE_SESSIONID", "OPENCODE_RUN_ID")),
- ("gemini", ("GEMINI_SESSION_ID",)),
- ("droid", ("FACTORY_SESSION_ID", "DROID_SESSION_ID")),
- ("qoder", ("QODER_SESSION_ID",)),
- ("codebuddy", ("CODEBUDDY_SESSION_ID",)),
- ("kiro", ("KIRO_SESSION_ID",)),
- ("copilot", ("COPILOT_SESSION_ID", "COPILOT_SESSIONID")),
- ("pi", ("PI_SESSION_ID", "PI_SESSIONID")),
- )
- _ENV_CONVERSATION_KEYS: tuple[tuple[str, tuple[str, ...]], ...] = (
- ("cursor", ("CURSOR_CONVERSATION_ID", "CURSOR_CONVERSATIONID")),
- )
- _ENV_TRANSCRIPT_KEYS: tuple[tuple[str, tuple[str, ...]], ...] = (
- ("claude", ("CLAUDE_TRANSCRIPT_PATH",)),
- ("codex", ("CODEX_TRANSCRIPT_PATH",)),
- ("cursor", ("CURSOR_TRANSCRIPT_PATH",)),
- ("gemini", ("GEMINI_TRANSCRIPT_PATH",)),
- ("droid", ("FACTORY_TRANSCRIPT_PATH", "DROID_TRANSCRIPT_PATH")),
- ("qoder", ("QODER_TRANSCRIPT_PATH",)),
- ("codebuddy", ("CODEBUDDY_TRANSCRIPT_PATH",)),
- )
- _ENV_PLATFORM_ALIASES = {
- "claude-code": "claude",
- "factory": "droid",
- "factory-ai": "droid",
- "github-copilot": "copilot",
- }
@dataclass(frozen=True)
class ActiveTask:
    """Immutable outcome of resolving the active task."""

    # Task reference as stored in the session file (repo-relative POSIX path
    # or absolute path), or None when there is no active task.
    task_path: str | None
    # Origin of the reference: "session", "session-fallback" or "none".
    source_type: str
    # Context key of the session file the reference was read from, if any.
    context_key: str | None = None
    # True when the referenced task directory does not exist.
    stale: bool = False

    @property
    def source(self) -> str:
        """Display label: `<source_type>:<context_key>` for session sources."""
        keyed_source = self.source_type in ("session", "session-fallback")
        if keyed_source and self.context_key:
            return f"{self.source_type}:{self.context_key}"
        return self.source_type
def normalize_task_ref(task_ref: str) -> str:
    """Return *task_ref* in a canonical textual form for storage/comparison.

    Blank input yields "". Absolute paths are returned as `str(Path(...))`.
    Relative refs have backslashes turned into forward slashes and leading
    `./` segments dropped; a `tasks/...` ref is expanded to
    `.trellis/tasks/...`.
    """
    candidate = task_ref.strip()
    if not candidate:
        return ""
    if Path(candidate).is_absolute():
        return str(Path(candidate))
    candidate = candidate.replace("\\", "/")
    while candidate.startswith("./"):
        candidate = candidate[2:]
    tasks_prefix = DIR_TASKS + "/"
    if candidate.startswith(tasks_prefix):
        return f"{DIR_WORKFLOW}/{candidate}"
    return candidate
def resolve_task_ref(task_ref: str, repo_root: Path) -> Path | None:
    """Map a task ref onto a task directory path without touching the disk.

    Returns None for a blank ref. Absolute refs are returned as paths. Refs
    already starting with `.trellis/` are joined onto *repo_root*. Anything
    else is treated as a path inside `<repo_root>/.trellis/tasks/`.
    """
    normalized = normalize_task_ref(task_ref)
    if not normalized:
        return None
    ref_path = Path(normalized)
    if ref_path.is_absolute():
        return ref_path
    under_workflow = normalized.startswith(f"{DIR_WORKFLOW}/")
    base = repo_root if under_workflow else repo_root / DIR_WORKFLOW / DIR_TASKS
    return base / ref_path
def _runtime_sessions_dir(repo_root: Path) -> Path:
    """Directory of per-session context files: `<repo>/.trellis/.runtime/sessions`."""
    return repo_root.joinpath(DIR_WORKFLOW, DIR_RUNTIME, DIR_SESSIONS)
- def _sanitize_key(raw: str) -> str:
- safe = re.sub(r"[^A-Za-z0-9._-]+", "_", raw.strip())
- safe = safe.strip("._-")
- return safe[:160] if safe else ""
- def _hash_value(raw: str) -> str:
- return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:24]
- def _as_dict(value: Any) -> dict[str, Any] | None:
- return value if isinstance(value, dict) else None
- def _string_value(value: Any) -> str | None:
- if isinstance(value, str):
- stripped = value.strip()
- return stripped or None
- return None
def _lookup_string(data: dict[str, Any], keys: tuple[str, ...]) -> str | None:
    """Return the first non-blank string stored under one of *keys*.

    All of *keys* are tried on *data* itself, in order, before any nesting
    is considered. Only then are the sub-dicts named in `_NESTED_KEYS`
    searched, recursively and in that order. Found values are stripped;
    None means nothing was found.
    """
    top_level = (_string_value(data.get(key)) for key in keys)
    hit = next((value for value in top_level if value), None)
    if hit:
        return hit
    for container_key in _NESTED_KEYS:
        child = _as_dict(data.get(container_key))
        hit = _lookup_string(child, keys) if child else None
        if hit:
            return hit
    return None
- def _detect_platform(platform_input: dict[str, Any] | None, platform: str | None) -> str:
- if platform:
- return _sanitize_key(platform) or "session"
- if platform_input:
- for key in ("_trellis_platform", "trellis_platform", "platform", "source"):
- value = _string_value(platform_input.get(key))
- if value:
- return _sanitize_key(value) or "session"
- if _string_value(platform_input.get("cursor_version")):
- return "cursor"
- return "session"
def _context_key(platform_name: str, kind: str, value: str) -> str:
    """Build the `<platform>_<id>` context key for an identity value.

    Transcript paths are always hashed and tagged `_transcript_`. Other
    values are sanitized, and fall back to a hash when sanitization leaves
    nothing.
    """
    if kind != "transcript":
        readable = _sanitize_key(value)
        if readable:
            return f"{platform_name}_{readable}"
        return f"{platform_name}_{_hash_value(value)}"
    return f"{platform_name}_transcript_{_hash_value(value)}"
- def _iter_env_keys(
- env_keys: tuple[tuple[str, tuple[str, ...]], ...],
- platform_name: str | None,
- ) -> tuple[tuple[str, tuple[str, ...]], ...]:
- if not platform_name:
- return env_keys
- matched = tuple((name, keys) for name, keys in env_keys if name == platform_name)
- return matched
- def _env_platform_name(platform_name: str | None) -> str | None:
- if not platform_name or platform_name == "session":
- return None
- return _ENV_PLATFORM_ALIASES.get(platform_name, platform_name)
def _lookup_env_context_key(platform_name: str | None) -> str | None:
    """Derive a context key from platform-provided environment variables.

    Hooks pass `TRELLIS_CONTEXT_ID` to subprocesses they launch. A shell
    command run by the AI, however, only sees session identity if the host
    platform exports it into the command environment. The variable names
    consulted here are best-effort adapters. When none is set, None is
    returned and there is no session-scoped active task from this source.

    Priority: any session-id variable beats any conversation-id variable,
    which beats any transcript-path variable. Within a kind, table order
    decides.
    """
    env_platform = _env_platform_name(platform_name)
    kind_tables = (
        (_ENV_SESSION_KEYS, "session"),
        (_ENV_CONVERSATION_KEYS, "conversation"),
        (_ENV_TRANSCRIPT_KEYS, "transcript"),
    )
    for table, kind in kind_tables:
        for row_platform, var_names in _iter_env_keys(table, env_platform):
            for var_name in var_names:
                value = _string_value(os.environ.get(var_name))
                if value:
                    return _context_key(row_platform, kind, value)
    return None
def _find_repo_root_from_cwd() -> Path | None:
    """Return the nearest ancestor of the resolved cwd that has `.trellis/`.

    The cwd itself is checked first, then each parent up to the filesystem
    root. Returns None when no such directory exists.
    """
    start = Path.cwd().resolve()
    for candidate in (start, *start.parents):
        if (candidate / DIR_WORKFLOW).is_dir():
            return candidate
    return None
def _cursor_shell_ticket_dir(repo_root: Path) -> Path:
    """Directory of Cursor shell tickets: `<repo>/.trellis/.runtime/cursor-shell`."""
    return repo_root.joinpath(DIR_WORKFLOW, DIR_RUNTIME, DIR_CURSOR_SHELL)
- def _remove_file(path: Path) -> bool:
- try:
- path.unlink()
- return True
- except OSError:
- return False
- def _task_refs_match(left: str | None, right: str | None, repo_root: Path) -> bool:
- if not left or not right:
- return False
- left_path = resolve_task_ref(left, repo_root)
- right_path = resolve_task_ref(right, repo_root)
- if left_path is not None and right_path is not None:
- return left_path == right_path
- return normalize_task_ref(left) == normalize_task_ref(right)
- def _pending_ticket_matches_args(ticket: dict[str, Any], repo_root: Path) -> bool:
- if Path(sys.argv[0]).name != "task.py":
- return False
- args = tuple(sys.argv[1:])
- if not args:
- return False
- command_name = args[0]
- if command_name not in TASK_SESSION_COMMANDS:
- return False
- subcommands = ticket.get("subcommands")
- if not isinstance(subcommands, list):
- return False
- for subcommand in subcommands:
- if not isinstance(subcommand, dict):
- continue
- if _string_value(subcommand.get("name")) != command_name:
- continue
- if command_name != "start":
- return True
- task_ref = args[1] if len(args) > 1 else None
- if _task_refs_match(_string_value(subcommand.get("task_ref")), task_ref, repo_root):
- return True
- return False
- def _ticket_is_fresh(ticket: dict[str, Any], ticket_path: Path, now: float) -> bool:
- expires_at = ticket.get("expires_at_epoch")
- if isinstance(expires_at, (int, float)) and expires_at < now:
- _remove_file(ticket_path)
- return False
- created_at = ticket.get("created_at_epoch")
- if isinstance(created_at, (int, float)):
- if now - created_at <= CURSOR_SHELL_TICKET_TTL_SECONDS:
- return True
- _remove_file(ticket_path)
- return False
- return True
def _ticket_cwd_matches_repo(ticket: dict[str, Any], repo_root: Path) -> bool:
    """Accept a ticket with no `cwd`, or one whose resolved `cwd` is inside *repo_root*."""
    ticket_cwd = _string_value(ticket.get("cwd"))
    if ticket_cwd is None:
        return True
    try:
        Path(ticket_cwd).resolve().relative_to(repo_root)
    except ValueError:
        return False
    return True
def _matching_cursor_ticket_context_key(
    ticket_path: Path,
    repo_root: Path,
    now: float,
) -> str | None:
    """Return a Cursor ticket's `context_key` if the ticket is acceptable.

    The ticket must parse as a JSON object with `platform == "cursor"`. It
    must also be fresh (stale ones are deleted), have a `cwd` inside
    *repo_root*, and announce the current `task.py` invocation. The checks
    run in that order and stop at the first failure.
    """
    ticket = _read_json(ticket_path)
    if ticket is None or ticket.get("platform") != "cursor":
        return None
    accepted = (
        _ticket_is_fresh(ticket, ticket_path, now)
        and _ticket_cwd_matches_repo(ticket, repo_root)
        and _pending_ticket_matches_args(ticket, repo_root)
    )
    return _string_value(ticket.get("context_key")) if accepted else None
def _lookup_cursor_shell_ticket_context_key() -> str | None:
    """Resolve Cursor conversation identity from a short-lived shell ticket.

    Cursor exposes `conversation_id` to `beforeShellExecution` but does not
    export it into the shell command's environment. The Cursor hook
    therefore writes a short-lived ticket just before `task.py` runs. A
    ticket is used only when the current `task.py` subcommand matches it
    and exactly one distinct fresh context key is found. This avoids
    contaminating one window's pointer with another's.
    """
    repo_root = _find_repo_root_from_cwd()
    if repo_root is None:
        return None
    ticket_dir = _cursor_shell_ticket_dir(repo_root)
    if not ticket_dir.is_dir():
        return None
    now = time.time()
    per_ticket = (
        _matching_cursor_ticket_context_key(ticket_path, repo_root, now)
        for ticket_path in ticket_dir.glob("*.json")
    )
    candidates = {key for key in per_ticket if key}
    # Zero or several distinct keys are ambiguous: refuse rather than guess.
    return candidates.pop() if len(candidates) == 1 else None
def resolve_context_key(
    platform_input: dict[str, Any] | None = None,
    platform: str | None = None,
) -> str | None:
    """Resolve a stable session/window context key, or None if unavailable.

    Sources, in priority order:

    1. `TRELLIS_CONTEXT_ID`: an explicit context-key override used by CLI
       scripts and subprocesses. It names the context, not the task.
    2. A session id, then conversation id, then transcript path found in
       *platform_input* (top level or nested).
    3. Platform-specific environment variables.
    4. For an undetected, "session" or "cursor" platform, a Cursor shell
       ticket.
    """
    override = _string_value(os.environ.get("TRELLIS_CONTEXT_ID"))
    if override:
        return _sanitize_key(override) or _hash_value(override)

    data = _as_dict(platform_input)
    platform_name = None
    if data or platform:
        platform_name = _detect_platform(data, platform)

    if data:
        label = platform_name or "session"
        identity_sources = (
            (_SESSION_KEYS, "session"),
            (_CONVERSATION_KEYS, "conversation"),
            (_TRANSCRIPT_KEYS, "transcript"),
        )
        for field_names, kind in identity_sources:
            found = _lookup_string(data, field_names)
            if found:
                return _context_key(label, kind, found)

    from_env = _lookup_env_context_key(platform_name)
    if from_env:
        return from_env
    if platform_name in (None, "session", "cursor"):
        return _lookup_cursor_shell_ticket_context_key()
    return None
- def _read_json(path: Path) -> dict[str, Any] | None:
- try:
- data = json.loads(path.read_text(encoding="utf-8"))
- except (FileNotFoundError, json.JSONDecodeError, OSError):
- return None
- return data if isinstance(data, dict) else None
- def _write_json(path: Path, data: dict[str, Any]) -> bool:
- try:
- path.parent.mkdir(parents=True, exist_ok=True)
- path.write_text(
- json.dumps(data, indent=2, ensure_ascii=False) + "\n",
- encoding="utf-8",
- )
- return True
- except OSError:
- return False
def _canonical_task_ref(task_path: str, repo_root: Path) -> str | None:
    """Return the stored form of an existing task directory.

    The result is the repo-relative POSIX path when the directory lies under
    *repo_root*, and otherwise `str()` of the resolved path. Returns None
    for a blank ref or when no directory exists at the resolved location.
    """
    normalized = normalize_task_ref(task_path)
    full_path = resolve_task_ref(normalized, repo_root) if normalized else None
    if full_path is None or not full_path.is_dir():
        return None
    try:
        relative = full_path.relative_to(repo_root)
    except ValueError:
        return str(full_path)
    return relative.as_posix()
- def _active_from_ref(
- task_ref: str | None,
- repo_root: Path,
- source_type: str,
- context_key: str | None = None,
- ) -> ActiveTask | None:
- if not task_ref:
- return None
- resolved = resolve_task_ref(task_ref, repo_root)
- stale = resolved is None or not resolved.is_dir()
- return ActiveTask(task_ref, source_type, context_key, stale)
def _context_path(repo_root: Path, context_key: str) -> Path:
    """Path of the session context file `<sessions dir>/<context_key>.json`."""
    return _runtime_sessions_dir(repo_root).joinpath(context_key + ".json")
def resolve_active_task(
    repo_root: Path,
    platform_input: dict[str, Any] | None = None,
    platform: str | None = None,
) -> ActiveTask:
    """Resolve the active task from session runtime state only.

    When a context key resolves and its session file names a task, that task
    is returned with source_type "session", flagged stale if its directory
    is gone. Otherwise (no context key, or a missing/empty session file) the
    single-session fallback applies. If exactly one session file exists, its
    task is returned with source_type "session-fallback". This covers
    class-2 platform sub-agents (codex, copilot, gemini, qoder) that do not
    inherit the parent's session id. With zero or two-plus session files the
    result is ActiveTask(None, "none", context_key), so no task is guessed
    across windows.
    """
    context_key = resolve_context_key(platform_input, platform)
    session_task = None
    if context_key:
        stored = _read_json(_context_path(repo_root, context_key)) or {}
        session_task = _active_from_ref(
            _string_value(stored.get("current_task")),
            repo_root,
            "session",
            context_key,
        )
    return (
        session_task
        or _resolve_single_session_fallback(repo_root)
        or ActiveTask(None, "none", context_key)
    )
def _resolve_single_session_fallback(repo_root: Path) -> ActiveTask | None:
    """Return the task of the sole session file, if exactly one exists.

    Used when context-key resolution yields no task, typically for class-2
    platform sub-agents. Returns None when there are zero or several session
    files, or when the sole file names no task. This refuses to pick across
    windows and keeps multi-session isolation intact. The file's stem
    becomes the returned context key.
    """
    sessions_dir = _runtime_sessions_dir(repo_root)
    if not sessions_dir.is_dir():
        return None
    session_files = list(sessions_dir.glob("*.json"))
    # Zero files: nothing to infer. Two or more: ambiguous across windows.
    if len(session_files) != 1:
        return None
    (only_file,) = session_files
    stored = _read_json(only_file) or {}
    task_ref = _string_value(stored.get("current_task"))
    if not task_ref:
        return None
    return _active_from_ref(task_ref, repo_root, "session-fallback", only_file.stem)
- def _utc_now() -> str:
- return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
def _context_metadata(
    platform_input: dict[str, Any] | None,
    platform: str | None,
    context_key: str | None = None,
) -> dict[str, Any]:
    """Describe the calling platform for storage in a session context file.

    The result holds `platform`, a `last_seen_at` UTC timestamp, and every
    session/conversation/transcript identity field found in the payload
    (top level or nested). If the platform cannot be detected ("session"),
    the text before the first "_" of *context_key* is used instead when it
    is a known platform name.
    """
    payload = _as_dict(platform_input) or {}
    platform_name = _detect_platform(payload, platform)
    if platform_name == "session" and context_key:
        key_prefix = context_key.partition("_")[0]
        if key_prefix in _KNOWN_PLATFORMS:
            platform_name = key_prefix
    metadata: dict[str, Any] = {"platform": platform_name, "last_seen_at": _utc_now()}
    identity_fields = (*_SESSION_KEYS, *_CONVERSATION_KEYS, *_TRANSCRIPT_KEYS)
    found = {field: _lookup_string(payload, (field,)) for field in identity_fields}
    metadata.update({field: value for field, value in found.items() if value})
    return metadata
def set_active_task(
    task_path: str,
    repo_root: Path,
    platform_input: dict[str, Any] | None = None,
    platform: str | None = None,
) -> ActiveTask | None:
    """Point the current session's active task at *task_path*.

    The task is stored in canonical form. Refreshed platform metadata is
    merged into the session file, and `current_run` is initialised to None
    if absent. Returns None when the task directory does not exist, when no
    context key is available, or when the file cannot be written. Callers
    should surface a user-facing error that explains how to provide session
    identity.
    """
    canonical = _canonical_task_ref(task_path, repo_root)
    if canonical is None:
        return None
    context_key = resolve_context_key(platform_input, platform)
    if not context_key:
        return None
    context_path = _context_path(repo_root, context_key)
    updated = _read_json(context_path) or {}
    updated.update(_context_metadata(platform_input, platform, context_key))
    updated["current_task"] = canonical
    updated.setdefault("current_run", None)
    if _write_json(context_path, updated):
        return ActiveTask(canonical, "session", context_key)
    return None
def clear_active_task(
    repo_root: Path,
    platform_input: dict[str, Any] | None = None,
    platform: str | None = None,
) -> ActiveTask:
    """Clear the active task by deleting the current session context file.

    Returns the task this session pointed at before clearing (source_type
    "session", possibly stale). Returns ActiveTask(None, "none", context_key)
    if the session had no task. When no context key can be resolved,
    nothing is deleted and ActiveTask(None, "none") is returned.
    """
    context_key = resolve_context_key(platform_input, platform)
    if not context_key:
        return ActiveTask(None, "none")
    context_path = _context_path(repo_root, context_key)
    # Read this session's own file directly. Going through
    # resolve_active_task could fall back to another window's sole session
    # file and report a task that this call does not actually clear. It
    # would also resolve the context key a second time.
    context = _read_json(context_path) or {}
    task_ref = _string_value(context.get("current_task"))
    previous = _active_from_ref(task_ref, repo_root, "session", context_key)
    if context_path.is_file():
        _remove_file(context_path)
    return previous or ActiveTask(None, "none", context_key)
def clear_task_from_sessions(task_path: str, repo_root: Path) -> int:
    """Delete every session runtime file whose active task is *task_path*.

    Returns the number of session files removed. Refs are compared by
    canonical form first. Failing that, they are compared by resolved
    location, so "foo", "tasks/foo" and ".trellis/tasks/foo" still match each
    other after the task directory has been moved or deleted, when canonical
    resolution is no longer possible.
    """
    target = _canonical_task_ref(task_path, repo_root) or normalize_task_ref(task_path)
    if not target:
        return 0
    sessions_dir = _runtime_sessions_dir(repo_root)
    if not sessions_dir.is_dir():
        return 0
    cleared = 0
    for session_path in sessions_dir.glob("*.json"):
        context = _read_json(session_path) or {}
        current = _string_value(context.get("current_task"))
        if not current:
            continue
        current_ref = _canonical_task_ref(current, repo_root) or normalize_task_ref(current)
        # Text comparison alone misses short-name vs full-path refs when the
        # directory no longer exists; fall back to comparing resolved paths.
        if current_ref != target and not _task_refs_match(current, target, repo_root):
            continue
        if session_path.is_file() and _remove_file(session_path):
            cleared += 1
    return cleared
def get_current_task_source(
    repo_root: Path,
    platform_input: dict[str, Any] | None = None,
    platform: str | None = None,
) -> tuple[str, str | None, str | None]:
    """Compatibility accessor returning `(source_type, context_key, task_path)`."""
    resolved = resolve_active_task(repo_root, platform_input, platform)
    return (resolved.source_type, resolved.context_key, resolved.task_path)
|