entity_linker.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Entity Linker - 实体消歧辅助模块
  5. 为 Data Agent 提供实体消歧的辅助功能:
  6. - 置信度判断
  7. - 别名索引管理
  8. - 消歧结果记录
  9. """
  10. import json
  11. from pathlib import Path
  12. from typing import Dict, List, Optional, Tuple
  13. from dataclasses import dataclass, field
  14. import filelock
  15. from .config import get_config
  16. try:
  17. # 常见:从 scripts/ 目录运行,security_utils 在 sys.path 顶层
  18. from security_utils import atomic_write_json, read_json_safe
  19. except ImportError: # pragma: no cover
  20. # 兼容:从仓库根目录以 `python -m scripts...` 运行
  21. from scripts.security_utils import atomic_write_json, read_json_safe
  22. @dataclass
  23. class DisambiguationResult:
  24. """消歧结果"""
  25. mention: str
  26. entity_id: Optional[str]
  27. confidence: float
  28. candidates: List[str] = field(default_factory=list)
  29. adopted: bool = False
  30. warning: Optional[str] = None
  31. class EntityLinker:
  32. """实体链接器 - 辅助 Data Agent 进行实体消歧 (v5.0 一对多别名)"""
  33. def __init__(self, config=None):
  34. self.config = config or get_config()
  35. # v5.0: alias_index 改为一对多格式 {alias: [{"type": ..., "id": ...}, ...]}
  36. self._alias_index: Dict[str, List[Dict]] = {}
  37. self._state_file = self.config.state_file
  38. self._load_alias_index()
  39. def _load_alias_index(self):
  40. """从 state.json 加载 alias_index"""
  41. if self._state_file.exists():
  42. try:
  43. with open(self._state_file, "r", encoding="utf-8") as f:
  44. state = json.load(f)
  45. self._alias_index = state.get("alias_index", {})
  46. except (json.JSONDecodeError, IOError):
  47. self._alias_index = {}
  48. else:
  49. self._alias_index = {}
  50. def save_alias_index(self):
  51. """保存 alias_index 到 state.json(v5.0 内嵌格式,锁内合并 + 原子写入)"""
  52. if not self._state_file.exists():
  53. return
  54. lock_path = self._state_file.with_suffix(self._state_file.suffix + ".lock")
  55. lock = filelock.FileLock(str(lock_path), timeout=10)
  56. try:
  57. with lock:
  58. state = read_json_safe(self._state_file, default={})
  59. disk_alias = state.get("alias_index", {})
  60. if not isinstance(disk_alias, dict):
  61. disk_alias = {}
  62. # 一对多:合并去重(避免覆盖其他进程刚写入的 state 字段/别名)
  63. for alias, entries in (self._alias_index or {}).items():
  64. if not alias or not isinstance(entries, list):
  65. continue
  66. existing = disk_alias.get(alias)
  67. if not isinstance(existing, list):
  68. existing = []
  69. disk_alias[alias] = existing
  70. for entry in entries:
  71. if not isinstance(entry, dict):
  72. continue
  73. et = entry.get("type")
  74. eid = entry.get("id")
  75. if not et or not eid:
  76. continue
  77. if any(
  78. isinstance(e, dict) and e.get("type") == et and e.get("id") == eid
  79. for e in existing
  80. ):
  81. continue
  82. existing.append({"type": et, "id": eid})
  83. state["alias_index"] = disk_alias
  84. self.config.ensure_dirs()
  85. atomic_write_json(self._state_file, state, use_lock=False, backup=True)
  86. # 同步内存到磁盘最新快照
  87. self._alias_index = disk_alias
  88. except filelock.Timeout:
  89. raise RuntimeError("无法获取 state.json 文件锁,请稍后重试")
  90. # ==================== 别名管理 (v5.0 一对多) ====================
  91. def register_alias(self, entity_id: str, alias: str, entity_type: str = "角色") -> bool:
  92. """注册新别名(v5.0 一对多:同一别名可映射多个实体)"""
  93. if not alias:
  94. return False
  95. if alias not in self._alias_index:
  96. self._alias_index[alias] = []
  97. # 检查是否已存在相同 (type, id) 组合
  98. for entry in self._alias_index[alias]:
  99. if entry.get("type") == entity_type and entry.get("id") == entity_id:
  100. return True # 已存在,视为成功
  101. self._alias_index[alias].append({
  102. "type": entity_type,
  103. "id": entity_id
  104. })
  105. return True
  106. def lookup_alias(self, mention: str, entity_type: str = None) -> Optional[str]:
  107. """查找别名对应的实体ID(返回第一个匹配,可选按类型过滤)"""
  108. entries = self._alias_index.get(mention, [])
  109. if not entries:
  110. return None
  111. if entity_type:
  112. for entry in entries:
  113. if entry.get("type") == entity_type:
  114. return entry.get("id")
  115. return None
  116. else:
  117. return entries[0].get("id") if entries else None
  118. def lookup_alias_all(self, mention: str) -> List[Dict]:
  119. """查找别名对应的所有实体(一对多)"""
  120. return self._alias_index.get(mention, [])
  121. def get_all_aliases(self, entity_id: str, entity_type: str = None) -> List[str]:
  122. """获取实体的所有别名"""
  123. aliases = []
  124. for alias, entries in self._alias_index.items():
  125. for entry in entries:
  126. if entry.get("id") == entity_id:
  127. if entity_type is None or entry.get("type") == entity_type:
  128. aliases.append(alias)
  129. break
  130. return aliases
  131. # ==================== 置信度判断 ====================
  132. def evaluate_confidence(self, confidence: float) -> Tuple[str, bool, Optional[str]]:
  133. """
  134. 评估置信度,返回 (action, adopt, warning)
  135. - action: "auto" | "warn" | "manual"
  136. - adopt: 是否采用
  137. - warning: 警告信息
  138. """
  139. if confidence >= self.config.extraction_confidence_high:
  140. return ("auto", True, None)
  141. elif confidence >= self.config.extraction_confidence_medium:
  142. return ("warn", True, f"中置信度匹配 (confidence: {confidence:.2f})")
  143. else:
  144. return ("manual", False, f"需人工确认 (confidence: {confidence:.2f})")
  145. def process_uncertain(
  146. self,
  147. mention: str,
  148. candidates: List[str],
  149. suggested: str,
  150. confidence: float,
  151. context: str = ""
  152. ) -> DisambiguationResult:
  153. """
  154. 处理不确定的实体匹配
  155. 返回消歧结果,包含是否采用、警告信息等
  156. """
  157. action, adopt, warning = self.evaluate_confidence(confidence)
  158. result = DisambiguationResult(
  159. mention=mention,
  160. entity_id=suggested if adopt else None,
  161. confidence=confidence,
  162. candidates=candidates,
  163. adopted=adopt,
  164. warning=warning
  165. )
  166. return result
  167. # ==================== 批量处理 ====================
  168. def process_extraction_result(
  169. self,
  170. uncertain_items: List[Dict]
  171. ) -> Tuple[List[DisambiguationResult], List[str]]:
  172. """
  173. 处理 AI 提取结果中的 uncertain 项
  174. 返回 (results, warnings)
  175. """
  176. results = []
  177. warnings = []
  178. for item in uncertain_items:
  179. result = self.process_uncertain(
  180. mention=item.get("mention", ""),
  181. candidates=item.get("candidates", []),
  182. suggested=item.get("suggested", ""),
  183. confidence=item.get("confidence", 0.0),
  184. context=item.get("context", "")
  185. )
  186. results.append(result)
  187. if result.warning:
  188. warnings.append(f"{result.mention} → {result.entity_id}: {result.warning}")
  189. return results, warnings
  190. def register_new_entities(
  191. self,
  192. new_entities: List[Dict]
  193. ) -> List[str]:
  194. """
  195. 注册新实体的别名 (v5.0)
  196. 返回注册的实体ID列表
  197. """
  198. registered = []
  199. for entity in new_entities:
  200. entity_id = entity.get("suggested_id") or entity.get("id")
  201. if not entity_id or entity_id == "NEW":
  202. continue
  203. entity_type = entity.get("type", "角色")
  204. # 注册主名称
  205. name = entity.get("name", "")
  206. if name:
  207. self.register_alias(entity_id, name, entity_type)
  208. # 注册提及方式
  209. for mention in entity.get("mentions", []):
  210. if mention and mention != name:
  211. self.register_alias(entity_id, mention, entity_type)
  212. registered.append(entity_id)
  213. return registered
  214. # ==================== CLI 接口 ====================
  215. def main():
  216. import argparse
  217. parser = argparse.ArgumentParser(description="Entity Linker CLI (v5.0 一对多别名)")
  218. parser.add_argument("--project-root", type=str, help="项目根目录")
  219. subparsers = parser.add_subparsers(dest="command")
  220. # 注册别名
  221. register_parser = subparsers.add_parser("register-alias")
  222. register_parser.add_argument("--entity", required=True, help="实体ID")
  223. register_parser.add_argument("--alias", required=True, help="别名")
  224. register_parser.add_argument("--type", default="角色", help="实体类型(默认:角色)")
  225. # 查找别名
  226. lookup_parser = subparsers.add_parser("lookup")
  227. lookup_parser.add_argument("--mention", required=True, help="提及文本")
  228. lookup_parser.add_argument("--type", help="按类型过滤")
  229. # 查找所有匹配(一对多)
  230. lookup_all_parser = subparsers.add_parser("lookup-all")
  231. lookup_all_parser.add_argument("--mention", required=True, help="提及文本")
  232. # 列出别名
  233. list_parser = subparsers.add_parser("list-aliases")
  234. list_parser.add_argument("--entity", required=True, help="实体ID")
  235. list_parser.add_argument("--type", help="实体类型")
  236. args = parser.parse_args()
  237. # 初始化
  238. config = None
  239. if args.project_root:
  240. from .config import DataModulesConfig
  241. config = DataModulesConfig.from_project_root(args.project_root)
  242. linker = EntityLinker(config)
  243. if args.command == "register-alias":
  244. entity_type = getattr(args, "type", "角色")
  245. success = linker.register_alias(args.entity, args.alias, entity_type)
  246. if success:
  247. linker.save_alias_index()
  248. print(f"✓ 已注册: {args.alias} → {args.entity} (类型: {entity_type})")
  249. else:
  250. print(f"✗ 注册失败")
  251. elif args.command == "lookup":
  252. entity_type = getattr(args, "type", None)
  253. entity_id = linker.lookup_alias(args.mention, entity_type)
  254. if entity_id:
  255. print(f"{args.mention} → {entity_id}")
  256. else:
  257. print(f"未找到: {args.mention}")
  258. elif args.command == "lookup-all":
  259. entries = linker.lookup_alias_all(args.mention)
  260. if entries:
  261. print(f"{args.mention} 的所有匹配:")
  262. for entry in entries:
  263. print(f" - {entry.get('id')} (类型: {entry.get('type')})")
  264. else:
  265. print(f"未找到: {args.mention}")
  266. elif args.command == "list-aliases":
  267. entity_type = getattr(args, "type", None)
  268. aliases = linker.get_all_aliases(args.entity, entity_type)
  269. if aliases:
  270. print(f"{args.entity} 的别名:")
  271. for alias in aliases:
  272. print(f" - {alias}")
  273. else:
  274. print(f"未找到 {args.entity} 的别名")
  275. if __name__ == "__main__":
  276. main()