# worker_manager.py
  1. from pdb import run
  2. import sys
  3. from langchain.tools import tool
  4. from langchain.agents import create_agent
  5. from langchain_openai import ChatOpenAI
  6. from typing import Dict, Any, Optional
  7. from openai import chat
  8. from config.settings import settings
  9. from langchain_core.messages import HumanMessage, SystemMessage
  10. import datetime
  11. from utils.context_helper import safe_context_param
  12. from utils.logger import chat_logger
  13. """
  14. 增加worker需要修改:
  15. worker_type_mapping
  16. """
  17. # WorkerAgent缓存字典
  18. _worker_cache = {}
  19. # 工具缓存字典
  20. _tools_cache = {}
  21. # worker_type, tool_types
  22. worker_type_mapping = {
  23. "sale_worker": ("销售报表查询", ["sale"]),
  24. "ware_worker": ("库存查询", ["ware"]),
  25. "money_worker": ("财务查询", ["money"]),
  26. "price_worker": ("价格查询", ["price"]),
  27. "salebill_worker": ("销售类单据查询", ["salebill"]),
  28. "base_data_worker": ("基础资料查询", ["mtrl_data", "cust_data"]),
  29. }
  30. class WorkerAgent:
  31. """Worker Agent - 专门处理特定领域的任务"""
  32. def __init__(self, worker_type: str, tools: list):
  33. self.worker_type = worker_type
  34. self.tools = tools
  35. self.llm = ChatOpenAI(
  36. model=settings.LLM_MODEL,
  37. temperature=settings.LLM_TEMPERATURE,
  38. api_key=settings.DEEPSEEK_API_KEY,
  39. base_url=settings.DEEPSEEK_BASE_URL,
  40. max_tokens=settings.LLM_MAX_TOKENS,
  41. )
  42. self.agent = self._create_agent()
  43. def _create_agent(self):
  44. """创建Agent实例"""
  45. return create_agent(
  46. self.llm,
  47. tools=self.tools,
  48. )
  49. def execute(self, query: str, backend_url: str, token: str) -> any:
  50. """执行任务"""
  51. try:
  52. system_prompt = f"""
  53. 你是{self.worker_type}专家,专门处理{self.worker_type}相关任务。
  54. 职责:
  55. - 接收主Agent分配的具体任务
  56. - 智能选择最合适的工具执行任务
  57. - 理解工具返回的数据结构
  58. - 组织完整的答案向用户回答
  59. 工具用到的参数:
  60. - backend_url: {backend_url}
  61. - token: {token}
  62. 当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
  63. 数据查询结果尽量以 Markdown 表格格式输出
  64. 如果工具有输出格式要求,必须严格按照要求输出。
  65. """
  66. if settings.ECHARTS_ENABLED:
  67. system_prompt = (
  68. system_prompt
  69. + """
  70. 你可以输出echarts柱状图、折线图、饼图。
  71. 饼图格式范例如下:
  72. ```echarts
  73. {{
  74. "title": {{
  75. "text": "浏览器份额", "left": "center" }}
  76. "tooltip": {{
  77. "trigger": "item" }},
  78. "legend": {{
  79. "orient": "vertical", "left": "left" }},
  80. "series": [
  81. {{
  82. "name": "Share",
  83. "type": "pie",
  84. "radius": "55%",
  85. "center": ["50%", "60%"],
  86. "data": [
  87. {{"value": 1048, "name": "Chrome" }}
  88. {{"value": 735, "name": "Firefox" }}
  89. {{"value": 580, "name": "Edge" }}
  90. ]
  91. }}
  92. ]
  93. }}
  94. ```
  95. 柱状图格式范例如下:
  96. ```echarts
  97. {{
  98. "title": {{"text": "每周销量" }}
  99. "tooltip": {{}},
  100. "xAxis": {{"type": "category", "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"] }}
  101. "yAxis": {{"type": "value" }}
  102. "series": [
  103. {{"type": "bar", "data": [120, 200, 150, 80, 70, 110, 130] }}
  104. ]
  105. }}
  106. ```
  107. 折线图,格式范例如下:
  108. ```echarts
  109. {{
  110. "title": {{ "text": "温度趋势" }},
  111. "tooltip": {{ "trigger": "axis" }},
  112. "legend": {{ "data": ["最高", "最低"] }},
  113. "xAxis": {{
  114. "type": "category",
  115. "boundaryGap": false,
  116. "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"]
  117. }},
  118. "yAxis": {{"type": "value" }},
  119. "series": [
  120. {{"name": "最高", "type": "line", "data": [11, 11, 15, 13, 12, 13, 10], "smooth": true }}
  121. {{"name": "最低", "type": "line", "data": [1, -2, 2, 5, 3, 2, 0], "smooth": true }}
  122. ]
  123. }}
  124. ```
  125. """
  126. )
  127. inputs = {
  128. "messages": [
  129. SystemMessage(content=system_prompt),
  130. HumanMessage(content=query),
  131. ]
  132. }
  133. response = self.agent.invoke(inputs)
  134. return response
  135. except Exception as e:
  136. return f"Worker执行失败: {str(e)}"
  137. # 自动发现指定类型的工具
  138. def _get_tools_by_type(tool_type: str) -> list:
  139. """根据类型自动发现工具 - 完整支持编译前后的文件"""
  140. # 检查缓存
  141. if tool_type in _tools_cache:
  142. chat_logger.info(f"使用缓存的工具: {tool_type}")
  143. return _tools_cache[tool_type]
  144. chat_logger.info(f"发现并加载工具: {tool_type}")
  145. import importlib
  146. import inspect
  147. from pathlib import Path
  148. import sys
  149. from langchain.tools import BaseTool
  150. tools = []
  151. # 获取项目根目录并添加到Python路径
  152. project_root = Path(__file__).parent.parent
  153. if str(project_root) not in sys.path:
  154. sys.path.insert(0, str(project_root))
  155. tools_dir = project_root / "tools"
  156. # 导入配置以检查知识库开关
  157. from config.settings import settings
  158. # 扫描工具文件
  159. tool_files = []
  160. # 模式1: 编译前的.py文件
  161. for file_path in tools_dir.glob("*_tools.py"):
  162. if file_path.is_file():
  163. module_name = file_path.stem
  164. # 检查文件名是否以指定类型开头(前缀匹配)
  165. if module_name.startswith(f"{tool_type}_"):
  166. # 如果知识库被禁用,跳过知识库工具
  167. if (
  168. module_name == "knowledge_tools"
  169. and not settings.KNOWLEDGE_BASE_ENABLED
  170. ):
  171. print("知识库功能已禁用,跳过知识库工具")
  172. continue
  173. tool_files.append(module_name)
  174. print(f"发现工具文件: {module_name}")
  175. # 模式2: 编译后的.pyd文件
  176. for file_path in tools_dir.glob("*_tools.cp*.pyd"):
  177. if file_path.is_file():
  178. # 从文件名中提取模块名,如: ware_tools.cp313-win_amd64.pyd -> ware_tools
  179. module_name = file_path.stem.split(".")[0]
  180. # 检查文件名是否以指定类型开头(前缀匹配)
  181. if module_name.startswith(f"{tool_type}_"):
  182. # 如果知识库被禁用,跳过知识库工具
  183. if (
  184. module_name == "knowledge_tools"
  185. and not settings.KNOWLEDGE_BASE_ENABLED
  186. ):
  187. print("知识库功能已禁用,跳过知识库工具")
  188. continue
  189. if module_name not in tool_files: # 避免重复添加
  190. tool_files.append(module_name)
  191. print(f"发现编译后工具文件: {module_name}")
  192. # 如果没有找到指定类型的工具文件,使用默认列表
  193. if not tool_files:
  194. # 根据类型提供默认工具列表
  195. default_tool_mapping = {
  196. "sale": ["sale_tools"],
  197. "ware": ["ware_tools"],
  198. "money": ["money_tools"],
  199. "price": ["price_tools"],
  200. "knowledge": ["knowledge_tools"] if settings.KNOWLEDGE_BASE_ENABLED else [],
  201. }
  202. if tool_type in default_tool_mapping:
  203. tool_files = default_tool_mapping[tool_type]
  204. print(f"使用默认工具列表: {tool_files}")
  205. # 导入模块并获取工具
  206. for module_name in tool_files:
  207. try:
  208. full_module_path = f"tools.{module_name}"
  209. module = importlib.import_module(full_module_path)
  210. # 使用更全面的工具发现方法(参考tool_factory.py)
  211. tool_count = 0
  212. # 方法1: 检查模块的所有属性
  213. for attr_name in dir(module):
  214. if attr_name.startswith("_"):
  215. continue
  216. attr = getattr(module, attr_name)
  217. # 检查是否是BaseTool实例
  218. if isinstance(attr, BaseTool):
  219. tools.append(attr)
  220. tool_count += 1
  221. continue
  222. # 检查是否是函数且具有工具属性
  223. if callable(attr) and hasattr(attr, "name"):
  224. tools.append(attr)
  225. tool_count += 1
  226. # 方法2: 检查模块中是否有get_all_tools函数
  227. if hasattr(module, "get_all_tools"):
  228. module_tools = module.get_all_tools()
  229. if isinstance(module_tools, list):
  230. tools.extend(module_tools)
  231. tool_count += len(module_tools)
  232. if tool_count > 0:
  233. print(f"✅ 从 {module_name} 加载了 {tool_count} 个工具")
  234. else:
  235. print(f"⚠️ {module_name} 中未发现工具函数")
  236. except Exception as e:
  237. print(f"❌ 导入模块 {module_name} 失败: {e}")
  238. # 缓存工具
  239. _tools_cache[tool_type] = tools
  240. return tools
  241. def _get_worker_agent(worker_type: str, tools: list) -> WorkerAgent:
  242. """获取或创建WorkerAgent实例(使用缓存)"""
  243. cache_key = f"{worker_type}" # 使用worker类型作为缓存键
  244. if cache_key not in _worker_cache:
  245. chat_logger.info(f"创建新的WorkerAgent缓存: {worker_type}")
  246. _worker_cache[cache_key] = WorkerAgent(worker_type, tools)
  247. else:
  248. chat_logger.info(f"使用缓存的WorkerAgent: {worker_type}")
  249. return _worker_cache[cache_key]
  250. def _execute_and_get_result(
  251. worker_type: str,
  252. tools: list,
  253. token: str,
  254. backend_url: str,
  255. query: str,
  256. context: str = "",
  257. ) -> str:
  258. """执行worker并获取结果"""
  259. try:
  260. chat_logger.info(f"开始执行worker: {worker_type}")
  261. # worker = WorkerAgent(worker_type, tools)
  262. worker = _get_worker_agent(worker_type, tools)
  263. # 如果有上下文,合并到查询中
  264. # chat_logger.info(f"原始上下文: {context}")
  265. safe_context = safe_context_param(context)
  266. if safe_context:
  267. enhanced_query = f"对话上下文:{safe_context}\n当前问题:{query}"
  268. else:
  269. enhanced_query = query
  270. # chat_logger.info(f"处理后的上下文: {safe_context}")
  271. result = worker.execute(enhanced_query, backend_url, token)
  272. chat_logger.info(f"Worker{worker_type}执行成功:问题{query}")
  273. ai_messages = [
  274. msg
  275. for msg in result.get("messages", [])
  276. if hasattr(msg, "type") and msg.type == "ai"
  277. ]
  278. if ai_messages:
  279. # 获取最后一条 AI 消息的内容
  280. last_ai_content = ai_messages[-1].content
  281. return last_ai_content
  282. else:
  283. return ""
  284. except Exception as e:
  285. chat_logger.error(
  286. f"Worker{worker_type}执行失败:问题{query},错误信息: {str(e)}"
  287. )
  288. return f"Worker{worker_type}执行失败:问题{query},错误信息: {str(e)}"
  289. def init_worker_cache() -> None:
  290. """初始化Worker缓存(在服务启动时调用)"""
  291. chat_logger.info("开始初始化Worker缓存...")
  292. # 直接遍历worker_type_mapping字典来初始化缓存
  293. for worker_name, (worker_type, tool_types) in worker_type_mapping.items():
  294. try:
  295. # 组合工具
  296. tools = []
  297. for tool_type in tool_types:
  298. tools.extend(_get_tools_by_type(tool_type))
  299. if tools:
  300. # 创建WorkerAgent实例并缓存
  301. cache_key = f"{worker_type}"
  302. if cache_key not in _worker_cache:
  303. _worker_cache[cache_key] = WorkerAgent(worker_type, tools)
  304. chat_logger.info(
  305. f"预缓存WorkerAgent: {worker_type} ({worker_name})"
  306. )
  307. else:
  308. chat_logger.info(f"WorkerAgent已缓存: {worker_type}")
  309. else:
  310. chat_logger.info(f"未发现工具,跳过缓存: {worker_type}")
  311. except Exception as e:
  312. chat_logger.info(f"预缓存WorkerAgent失败 {worker_name}: {e}")
  313. chat_logger.info(f"Worker缓存初始化完成,共缓存 {len(_worker_cache)} 个WorkerAgent")
  314. def clear_worker_cache() -> None:
  315. """清除Worker缓存"""
  316. global _worker_cache, _tools_cache
  317. _worker_cache.clear()
  318. _tools_cache.clear()
  319. chat_logger.info("Worker缓存已清除")
  320. def get_worker_cache_stats() -> dict:
  321. """获取Worker缓存统计信息"""
  322. return {
  323. "worker_agents_cached": len(_worker_cache),
  324. "tools_cached": len(_tools_cache),
  325. "cached_worker_types": list(_worker_cache.keys()),
  326. "cached_tool_types": list(_tools_cache.keys()),
  327. }
  328. def get_worker_config(worker_name: str) -> tuple[str, list[str]]:
  329. """根据worker名称获取配置信息"""
  330. if worker_name not in worker_type_mapping:
  331. raise ValueError(f"未知的worker名称: {worker_name}")
  332. return worker_type_mapping[worker_name]
  333. def run_worker(
  334. worker_name: str,
  335. token: str,
  336. backend_url: str,
  337. query: str,
  338. context: str = "",
  339. ) -> str:
  340. worker_type, tool_types = get_worker_config(worker_name)
  341. tools = []
  342. for tool_type in tool_types:
  343. tools.extend(_get_tools_by_type(tool_type))
  344. return _execute_and_get_result(
  345. worker_type, tools, token, backend_url, query, context
  346. )
  347. # Worker定义
  348. @tool
  349. def sale_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
  350. """
  351. 销售报表查询专家
  352. 职责:销售金额、销售数量、客户销售排名、业务员销售排名、产品型号销售数据等
  353. 使用场景:"2024客户销售业绩排名"、"查2024年产品型号销售排名"、"查看2023年1月1日至2023年12月31日的销售金额"等
  354. """
  355. # tools = _get_tools_by_type("sale") # 即 sale_tools.py下的所有工具
  356. # return _execute_and_get_result(
  357. # "销售报表查询", tools, token, backend_url, query, context
  358. # )
  359. return run_worker("sale_worker", token, backend_url, query, context)
  360. @tool
  361. def ware_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
  362. """
  363. 库存查询专家
  364. 职责:库存管理查询相关问题解答
  365. 使用场景:"查看铜管的库存"、"安吉仓库最多的库存哪个型号"
  366. """
  367. # tools = _get_tools_by_type("ware") # 即 ware_tools.py下的所有工具
  368. # return _execute_and_get_result(
  369. # "库存查询", tools, token, backend_url, query, context
  370. # )
  371. return run_worker("ware_worker", token, backend_url, query, context)
  372. @tool
  373. def money_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
  374. """
  375. 财务查询专家
  376. 职责:财务查询相关问题解答
  377. 使用场景:"查询客户A的应收帐情况"、"查询客户A,2025年的收款情况"
  378. """
  379. # 图表工具可以复用数据查询工具的结果
  380. # tools = _get_tools_by_type("money") # 即 money_tools.py下的所有工具
  381. # return _execute_and_get_result(
  382. # "财务查询", tools, token, backend_url, query, context
  383. # )
  384. return run_worker("money_worker", token, backend_url, query, context)
  385. @tool
  386. def price_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
  387. """
  388. 价格查询专家
  389. 职责:价格查询相关问题解答
  390. 使用场景:"查询铜管的销售价格"
  391. """
  392. # tools = _get_tools_by_type("price")
  393. # return _execute_and_get_result(
  394. # "价格查询", tools, token, backend_url, query, context
  395. # )
  396. return run_worker("price_worker", token, backend_url, query, context)
  397. @tool
  398. def salebill_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
  399. """
  400. 销售类单据查询专家
  401. 职责:销售类单据查询相关问题解答
  402. 使用场景:"查询客户A的订单进度"、"查询销售订单SG251119001"
  403. """
  404. # tools = _get_tools_by_type("salebill")
  405. # return _execute_and_get_result(
  406. # "销售类单据查询", tools, token, backend_url, query, context
  407. # )
  408. return run_worker("salebill_worker", token, backend_url, query, context)
  409. @tool
  410. def base_data_worker(
  411. query: str, backend_url: str, token: str, context: str = ""
  412. ) -> str:
  413. """
  414. 基础资料查询专家
  415. 职责:基础资料查询相关问题解答
  416. 使用场景:"查产品 A388餐椅"、"查一下客户 南华家具"
  417. """
  418. # tools = _get_tools_by_type("mtrl_data") + _get_tools_by_type("cust_data")
  419. # return _execute_and_get_result(
  420. # "基础资料查询", tools, token, backend_url, query, context
  421. # )
  422. return run_worker("base_data_worker", token, backend_url, query, context)
  423. def get_worker_tools() -> list:
  424. """获取所有Worker工具"""
  425. return [
  426. sale_worker,
  427. ware_worker,
  428. money_worker,
  429. price_worker,
  430. salebill_worker,
  431. base_data_worker,
  432. ]