query_router.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """Query router for RAG requests."""
  4. from __future__ import annotations
  5. import re
  6. from typing import Any, Dict, List
  7. class QueryRouter:
  8. def __init__(self):
  9. self.intent_patterns = {
  10. "relationship": [r"关系", r"图谱", r"时间线", r"谁和谁", r"敌对", r"盟友"],
  11. "entity": [r"人物", r"角色", r"谁", r"身份", r"别名"],
  12. "scene": [r"地点", r"场景", r"哪里", r"位置"],
  13. "setting": [r"设定", r"规则", r"体系", r"世界观"],
  14. "plot": [r"剧情", r"发生", r"事件", r"经过"],
  15. }
  16. self.patterns = {
  17. "entity": list(self.intent_patterns["entity"]),
  18. "scene": list(self.intent_patterns["scene"]),
  19. "setting": list(self.intent_patterns["setting"]),
  20. "plot": list(self.intent_patterns["plot"]),
  21. }
  22. def _extract_entities(self, query: str) -> List[str]:
  23. # 轻量启发式提取:提取长度 2-6 的中文短语,过滤常见查询词
  24. candidates = re.findall(r"[\u4e00-\u9fff]{2,6}", query)
  25. stopwords = {
  26. "关系",
  27. "图谱",
  28. "时间线",
  29. "剧情",
  30. "发生",
  31. "事件",
  32. "角色",
  33. "人物",
  34. "设定",
  35. "世界观",
  36. "地点",
  37. "场景",
  38. }
  39. entities: List[str] = []
  40. for c in candidates:
  41. if c in stopwords:
  42. continue
  43. if c not in entities:
  44. entities.append(c)
  45. return entities[:4]
  46. def _extract_time_scope(self, query: str) -> Dict[str, Any]:
  47. m_range = re.search(r"第?\s*(\d+)\s*[-~到]\s*(\d+)\s*章", query)
  48. if m_range:
  49. start = int(m_range.group(1))
  50. end = int(m_range.group(2))
  51. if start > end:
  52. start, end = end, start
  53. return {"from_chapter": start, "to_chapter": end}
  54. m_single = re.search(r"第?\s*(\d+)\s*章", query)
  55. if m_single:
  56. chapter = int(m_single.group(1))
  57. return {"from_chapter": chapter, "to_chapter": chapter}
  58. return {}
  59. def route_intent(self, query: str) -> Dict[str, Any]:
  60. query = str(query or "")
  61. intent = "plot"
  62. for intent_name, patterns in self.intent_patterns.items():
  63. if any(re.search(pat, query) for pat in patterns):
  64. intent = intent_name
  65. break
  66. time_scope = self._extract_time_scope(query)
  67. entities = self._extract_entities(query)
  68. needs_graph = intent == "relationship" or "关系" in query or "图谱" in query
  69. return {
  70. "intent": intent,
  71. "entities": entities,
  72. "time_scope": time_scope,
  73. "needs_graph": needs_graph,
  74. "raw_query": query,
  75. }
  76. def plan_subqueries(self, intent_payload: Dict[str, Any]) -> List[Dict[str, Any]]:
  77. intent = str((intent_payload or {}).get("intent") or "plot")
  78. entities = list((intent_payload or {}).get("entities") or [])
  79. time_scope = dict((intent_payload or {}).get("time_scope") or {})
  80. needs_graph = bool((intent_payload or {}).get("needs_graph"))
  81. steps: List[Dict[str, Any]] = []
  82. if intent == "relationship":
  83. steps.append(
  84. {
  85. "name": "relationship_graph",
  86. "strategy": "graph_lookup",
  87. "entities": entities,
  88. "time_scope": time_scope,
  89. }
  90. )
  91. steps.append(
  92. {
  93. "name": "relationship_evidence",
  94. "strategy": "graph_hybrid",
  95. "entities": entities,
  96. "time_scope": time_scope,
  97. }
  98. )
  99. return steps
  100. if needs_graph and entities:
  101. steps.append(
  102. {
  103. "name": "graph_enhanced_retrieval",
  104. "strategy": "graph_hybrid",
  105. "entities": entities,
  106. "time_scope": time_scope,
  107. }
  108. )
  109. return steps
  110. strategy_map = {
  111. "entity": "hybrid",
  112. "scene": "bm25",
  113. "setting": "bm25",
  114. "plot": "hybrid",
  115. }
  116. steps.append(
  117. {
  118. "name": "default_retrieval",
  119. "strategy": strategy_map.get(intent, "hybrid"),
  120. "entities": entities,
  121. "time_scope": time_scope,
  122. }
  123. )
  124. return steps
  125. def route(self, query: str) -> str:
  126. return str(self.route_intent(query).get("intent") or "plot")
  127. def split(self, query: str) -> List[str]:
  128. parts = re.split(r"[,,;;以及和]\s*", query)
  129. return [p.strip() for p in parts if p.strip()]