entity_linker.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Entity Linker - 实体消歧辅助模块 (v5.4)
  5. 为 Data Agent 提供实体消歧的辅助功能:
  6. - 置信度判断
  7. - 别名索引管理 (通过 index.db aliases 表)
  8. - 消歧结果记录
  9. v5.1 变更(v5.4 沿用):
  10. - 别名存储从 state.json 迁移到 index.db aliases 表
  11. - 使用 IndexManager 进行别名读写
  12. - 移除对 state.json 的直接操作
  13. """
  14. from typing import Dict, List, Optional, Tuple
  15. from dataclasses import dataclass, field
  16. from .config import get_config
  17. from .index_manager import IndexManager
  18. @dataclass
  19. class DisambiguationResult:
  20. """消歧结果"""
  21. mention: str
  22. entity_id: Optional[str]
  23. confidence: float
  24. candidates: List[str] = field(default_factory=list)
  25. adopted: bool = False
  26. warning: Optional[str] = None
  27. class EntityLinker:
  28. """实体链接器 - 辅助 Data Agent 进行实体消歧 (v5.1 SQLite,v5.4 沿用)"""
  29. def __init__(self, config=None):
  30. self.config = config or get_config()
  31. self._index_manager = IndexManager(self.config)
  32. # ==================== 别名管理 (v5.1 SQLite,v5.4 沿用) ====================
  33. def register_alias(self, entity_id: str, alias: str, entity_type: str = "角色") -> bool:
  34. """注册新别名(v5.1 引入:写入 index.db aliases 表)"""
  35. if not alias or not entity_id:
  36. return False
  37. return self._index_manager.register_alias(alias, entity_id, entity_type)
  38. def lookup_alias(self, mention: str, entity_type: str = None) -> Optional[str]:
  39. """查找别名对应的实体ID(返回第一个匹配,可选按类型过滤)"""
  40. entries = self._index_manager.get_entities_by_alias(mention)
  41. if not entries:
  42. return None
  43. if entity_type:
  44. for entry in entries:
  45. if entry.get("type") == entity_type:
  46. return entry.get("id")
  47. return None
  48. else:
  49. return entries[0].get("id") if entries else None
  50. def lookup_alias_all(self, mention: str) -> List[Dict]:
  51. """查找别名对应的所有实体(一对多)"""
  52. entries = self._index_manager.get_entities_by_alias(mention)
  53. return [{"type": e.get("type"), "id": e.get("id")} for e in entries]
  54. def get_all_aliases(self, entity_id: str, entity_type: str = None) -> List[str]:
  55. """获取实体的所有别名"""
  56. return self._index_manager.get_entity_aliases(entity_id)
  57. # ==================== 置信度判断 ====================
  58. def evaluate_confidence(self, confidence: float) -> Tuple[str, bool, Optional[str]]:
  59. """
  60. 评估置信度,返回 (action, adopt, warning)
  61. - action: "auto" | "warn" | "manual"
  62. - adopt: 是否采用
  63. - warning: 警告信息
  64. """
  65. if confidence >= self.config.extraction_confidence_high:
  66. return ("auto", True, None)
  67. elif confidence >= self.config.extraction_confidence_medium:
  68. return ("warn", True, f"中置信度匹配 (confidence: {confidence:.2f})")
  69. else:
  70. return ("manual", False, f"需人工确认 (confidence: {confidence:.2f})")
  71. def process_uncertain(
  72. self,
  73. mention: str,
  74. candidates: List[str],
  75. suggested: str,
  76. confidence: float,
  77. context: str = ""
  78. ) -> DisambiguationResult:
  79. """
  80. 处理不确定的实体匹配
  81. 返回消歧结果,包含是否采用、警告信息等
  82. """
  83. action, adopt, warning = self.evaluate_confidence(confidence)
  84. result = DisambiguationResult(
  85. mention=mention,
  86. entity_id=suggested if adopt else None,
  87. confidence=confidence,
  88. candidates=candidates,
  89. adopted=adopt,
  90. warning=warning
  91. )
  92. return result
  93. # ==================== 批量处理 ====================
  94. def process_extraction_result(
  95. self,
  96. uncertain_items: List[Dict]
  97. ) -> Tuple[List[DisambiguationResult], List[str]]:
  98. """
  99. 处理 AI 提取结果中的 uncertain 项
  100. 返回 (results, warnings)
  101. """
  102. results = []
  103. warnings = []
  104. for item in uncertain_items:
  105. result = self.process_uncertain(
  106. mention=item.get("mention", ""),
  107. candidates=item.get("candidates", []),
  108. suggested=item.get("suggested", ""),
  109. confidence=item.get("confidence", 0.0),
  110. context=item.get("context", "")
  111. )
  112. results.append(result)
  113. if result.warning:
  114. warnings.append(f"{result.mention} → {result.entity_id}: {result.warning}")
  115. return results, warnings
  116. def register_new_entities(
  117. self,
  118. new_entities: List[Dict]
  119. ) -> List[str]:
  120. """
  121. 注册新实体的别名 (v5.1 引入,v5.4 沿用)
  122. 返回注册的实体ID列表
  123. """
  124. registered = []
  125. for entity in new_entities:
  126. entity_id = entity.get("suggested_id") or entity.get("id")
  127. if not entity_id or entity_id == "NEW":
  128. continue
  129. entity_type = entity.get("type", "角色")
  130. # 注册主名称
  131. name = entity.get("name", "")
  132. if name:
  133. self.register_alias(entity_id, name, entity_type)
  134. # 注册提及方式
  135. for mention in entity.get("mentions", []):
  136. if mention and mention != name:
  137. self.register_alias(entity_id, mention, entity_type)
  138. registered.append(entity_id)
  139. return registered
  140. # ==================== CLI 接口 ====================
  141. def main():
  142. import argparse
  143. from .cli_output import print_success, print_error
  144. from .index_manager import IndexManager
  145. parser = argparse.ArgumentParser(description="Entity Linker CLI (v5.4 SQLite)")
  146. parser.add_argument("--project-root", type=str, help="项目根目录")
  147. subparsers = parser.add_subparsers(dest="command")
  148. # 注册别名
  149. register_parser = subparsers.add_parser("register-alias")
  150. register_parser.add_argument("--entity", required=True, help="实体ID")
  151. register_parser.add_argument("--alias", required=True, help="别名")
  152. register_parser.add_argument("--type", default="角色", help="实体类型(默认:角色)")
  153. # 查找别名
  154. lookup_parser = subparsers.add_parser("lookup")
  155. lookup_parser.add_argument("--mention", required=True, help="提及文本")
  156. lookup_parser.add_argument("--type", help="按类型过滤")
  157. # 查找所有匹配(一对多)
  158. lookup_all_parser = subparsers.add_parser("lookup-all")
  159. lookup_all_parser.add_argument("--mention", required=True, help="提及文本")
  160. # 列出别名
  161. list_parser = subparsers.add_parser("list-aliases")
  162. list_parser.add_argument("--entity", required=True, help="实体ID")
  163. list_parser.add_argument("--type", help="实体类型")
  164. args = parser.parse_args()
  165. # 初始化
  166. config = None
  167. if args.project_root:
  168. from .config import DataModulesConfig
  169. config = DataModulesConfig.from_project_root(args.project_root)
  170. linker = EntityLinker(config)
  171. logger = IndexManager(config)
  172. tool_name = f"entity_linker:{args.command or 'unknown'}"
  173. def emit_success(data=None, message: str = "ok"):
  174. print_success(data, message=message)
  175. try:
  176. logger.log_tool_call(tool_name, True)
  177. except Exception:
  178. pass
  179. def emit_error(code: str, message: str, suggestion: str | None = None):
  180. print_error(code, message, suggestion=suggestion)
  181. try:
  182. logger.log_tool_call(tool_name, False, error_code=code, error_message=message)
  183. except Exception:
  184. pass
  185. if args.command == "register-alias":
  186. entity_type = getattr(args, "type", "角色")
  187. success = linker.register_alias(args.entity, args.alias, entity_type)
  188. if success:
  189. emit_success({"entity": args.entity, "alias": args.alias, "type": entity_type}, message="alias_registered")
  190. else:
  191. emit_error("ALIAS_EXISTS", "注册失败或已存在")
  192. elif args.command == "lookup":
  193. entity_type = getattr(args, "type", None)
  194. entity_id = linker.lookup_alias(args.mention, entity_type)
  195. if entity_id:
  196. emit_success({"mention": args.mention, "entity": entity_id}, message="lookup")
  197. else:
  198. emit_error("NOT_FOUND", f"未找到别名: {args.mention}")
  199. elif args.command == "lookup-all":
  200. matches = linker.lookup_alias_all(args.mention)
  201. emit_success(matches, message="lookup_all")
  202. elif args.command == "list-aliases":
  203. entity_type = getattr(args, "type", None)
  204. aliases = linker.get_all_aliases(args.entity, entity_type)
  205. emit_success(aliases, message="aliases")
  206. else:
  207. emit_error("UNKNOWN_COMMAND", "未指定有效命令", suggestion="请查看 --help")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()