#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Entity Linker - 实体消歧辅助模块 (v5.1)

为 Data Agent 提供实体消歧的辅助功能:
- 置信度判断
- 别名索引管理 (通过 index.db aliases 表)
- 消歧结果记录

v5.1 变更:
- 别名存储从 state.json 迁移到 index.db aliases 表
- 使用 IndexManager 进行别名读写
- 移除对 state.json 的直接操作
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Tuple

from .config import get_config
from .index_manager import IndexManager


  18. @dataclass
  19. class DisambiguationResult:
  20. """消歧结果"""
  21. mention: str
  22. entity_id: Optional[str]
  23. confidence: float
  24. candidates: List[str] = field(default_factory=list)
  25. adopted: bool = False
  26. warning: Optional[str] = None
  27. class EntityLinker:
  28. """实体链接器 - 辅助 Data Agent 进行实体消歧 (v5.1 SQLite)"""
  29. def __init__(self, config=None):
  30. self.config = config or get_config()
  31. self._index_manager = IndexManager(self.config)
  32. # ==================== 别名管理 (v5.1 SQLite) ====================
  33. def register_alias(self, entity_id: str, alias: str, entity_type: str = "角色") -> bool:
  34. """注册新别名(v5.1: 写入 index.db aliases 表)"""
  35. if not alias or not entity_id:
  36. return False
  37. return self._index_manager.register_alias(alias, entity_id, entity_type)
  38. def lookup_alias(self, mention: str, entity_type: str = None) -> Optional[str]:
  39. """查找别名对应的实体ID(返回第一个匹配,可选按类型过滤)"""
  40. entries = self._index_manager.get_entities_by_alias(mention)
  41. if not entries:
  42. return None
  43. if entity_type:
  44. for entry in entries:
  45. if entry.get("type") == entity_type:
  46. return entry.get("id")
  47. return None
  48. else:
  49. return entries[0].get("id") if entries else None
  50. def lookup_alias_all(self, mention: str) -> List[Dict]:
  51. """查找别名对应的所有实体(一对多)"""
  52. entries = self._index_manager.get_entities_by_alias(mention)
  53. return [{"type": e.get("type"), "id": e.get("id")} for e in entries]
  54. def get_all_aliases(self, entity_id: str, entity_type: str = None) -> List[str]:
  55. """获取实体的所有别名"""
  56. return self._index_manager.get_entity_aliases(entity_id)
  57. # ==================== 置信度判断 ====================
  58. def evaluate_confidence(self, confidence: float) -> Tuple[str, bool, Optional[str]]:
  59. """
  60. 评估置信度,返回 (action, adopt, warning)
  61. - action: "auto" | "warn" | "manual"
  62. - adopt: 是否采用
  63. - warning: 警告信息
  64. """
  65. if confidence >= self.config.extraction_confidence_high:
  66. return ("auto", True, None)
  67. elif confidence >= self.config.extraction_confidence_medium:
  68. return ("warn", True, f"中置信度匹配 (confidence: {confidence:.2f})")
  69. else:
  70. return ("manual", False, f"需人工确认 (confidence: {confidence:.2f})")
  71. def process_uncertain(
  72. self,
  73. mention: str,
  74. candidates: List[str],
  75. suggested: str,
  76. confidence: float,
  77. context: str = ""
  78. ) -> DisambiguationResult:
  79. """
  80. 处理不确定的实体匹配
  81. 返回消歧结果,包含是否采用、警告信息等
  82. """
  83. action, adopt, warning = self.evaluate_confidence(confidence)
  84. result = DisambiguationResult(
  85. mention=mention,
  86. entity_id=suggested if adopt else None,
  87. confidence=confidence,
  88. candidates=candidates,
  89. adopted=adopt,
  90. warning=warning
  91. )
  92. return result
  93. # ==================== 批量处理 ====================
  94. def process_extraction_result(
  95. self,
  96. uncertain_items: List[Dict]
  97. ) -> Tuple[List[DisambiguationResult], List[str]]:
  98. """
  99. 处理 AI 提取结果中的 uncertain 项
  100. 返回 (results, warnings)
  101. """
  102. results = []
  103. warnings = []
  104. for item in uncertain_items:
  105. result = self.process_uncertain(
  106. mention=item.get("mention", ""),
  107. candidates=item.get("candidates", []),
  108. suggested=item.get("suggested", ""),
  109. confidence=item.get("confidence", 0.0),
  110. context=item.get("context", "")
  111. )
  112. results.append(result)
  113. if result.warning:
  114. warnings.append(f"{result.mention} → {result.entity_id}: {result.warning}")
  115. return results, warnings
  116. def register_new_entities(
  117. self,
  118. new_entities: List[Dict]
  119. ) -> List[str]:
  120. """
  121. 注册新实体的别名 (v5.1)
  122. 返回注册的实体ID列表
  123. """
  124. registered = []
  125. for entity in new_entities:
  126. entity_id = entity.get("suggested_id") or entity.get("id")
  127. if not entity_id or entity_id == "NEW":
  128. continue
  129. entity_type = entity.get("type", "角色")
  130. # 注册主名称
  131. name = entity.get("name", "")
  132. if name:
  133. self.register_alias(entity_id, name, entity_type)
  134. # 注册提及方式
  135. for mention in entity.get("mentions", []):
  136. if mention and mention != name:
  137. self.register_alias(entity_id, mention, entity_type)
  138. registered.append(entity_id)
  139. return registered
  140. # ==================== CLI 接口 ====================
  141. def main():
  142. import argparse
  143. parser = argparse.ArgumentParser(description="Entity Linker CLI (v5.1 SQLite)")
  144. parser.add_argument("--project-root", type=str, help="项目根目录")
  145. subparsers = parser.add_subparsers(dest="command")
  146. # 注册别名
  147. register_parser = subparsers.add_parser("register-alias")
  148. register_parser.add_argument("--entity", required=True, help="实体ID")
  149. register_parser.add_argument("--alias", required=True, help="别名")
  150. register_parser.add_argument("--type", default="角色", help="实体类型(默认:角色)")
  151. # 查找别名
  152. lookup_parser = subparsers.add_parser("lookup")
  153. lookup_parser.add_argument("--mention", required=True, help="提及文本")
  154. lookup_parser.add_argument("--type", help="按类型过滤")
  155. # 查找所有匹配(一对多)
  156. lookup_all_parser = subparsers.add_parser("lookup-all")
  157. lookup_all_parser.add_argument("--mention", required=True, help="提及文本")
  158. # 列出别名
  159. list_parser = subparsers.add_parser("list-aliases")
  160. list_parser.add_argument("--entity", required=True, help="实体ID")
  161. list_parser.add_argument("--type", help="实体类型")
  162. args = parser.parse_args()
  163. # 初始化
  164. config = None
  165. if args.project_root:
  166. from .config import DataModulesConfig
  167. config = DataModulesConfig.from_project_root(args.project_root)
  168. linker = EntityLinker(config)
  169. if args.command == "register-alias":
  170. entity_type = getattr(args, "type", "角色")
  171. success = linker.register_alias(args.entity, args.alias, entity_type)
  172. if success:
  173. print(f"✓ 已注册: {args.alias} → {args.entity} (类型: {entity_type})")
  174. else:
  175. print(f"✗ 注册失败或已存在")
  176. elif args.command == "lookup":
  177. entity_type = getattr(args, "type", None)
  178. entity_id = linker.lookup_alias(args.mention, entity_type)
  179. if entity_id:
  180. print(f"{args.mention} → {entity_id}")
  181. else:
  182. print(f"未找到: {args.mention}")
  183. elif args.command == "lookup-all":
  184. entries = linker.lookup_alias_all(args.mention)
  185. if entries:
  186. print(f"{args.mention} 的所有匹配:")
  187. for entry in entries:
  188. print(f" - {entry.get('id')} (类型: {entry.get('type')})")
  189. else:
  190. print(f"未找到: {args.mention}")
  191. elif args.command == "list-aliases":
  192. entity_type = getattr(args, "type", None)
  193. aliases = linker.get_all_aliases(args.entity, entity_type)
  194. if aliases:
  195. print(f"{args.entity} 的别名:")
  196. for alias in aliases:
  197. print(f" - {alias}")
  198. else:
  199. print(f"未找到 {args.entity} 的别名")
  200. if __name__ == "__main__":
  201. main()