query_router.py 902 B

12345678910111213141516171819202122232425262728
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """Query router for RAG requests."""
  4. from __future__ import annotations
  5. import re
  6. from typing import List
  7. class QueryRouter:
  8. def __init__(self):
  9. self.patterns = {
  10. "entity": [r"人物", r"角色", r"谁", r"身份", r"别名"],
  11. "scene": [r"地点", r"场景", r"哪里", r"位置"],
  12. "setting": [r"设定", r"规则", r"体系", r"世界观"],
  13. "plot": [r"剧情", r"发生", r"事件", r"经过"],
  14. }
  15. def route(self, query: str) -> str:
  16. for qtype, patterns in self.patterns.items():
  17. for pat in patterns:
  18. if re.search(pat, query):
  19. return qtype
  20. return "plot"
  21. def split(self, query: str) -> List[str]:
  22. parts = re.split(r"[,,;;以及和]\s*", query)
  23. return [p.strip() for p in parts if p.strip()]