# worker_manager.py
  1. from pdb import run
  2. import sys
  3. from langchain.tools import tool
  4. from langchain.agents import create_agent
  5. from langchain_openai import ChatOpenAI
  6. from typing import Dict, Any, Optional
  7. from openai import chat
  8. from config.settings import settings
  9. from langchain_core.messages import HumanMessage, SystemMessage
  10. import datetime
  11. from utils.context_helper import safe_context_param
  12. from utils.logger import chat_logger
  13. """
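To add a new worker, update:
- worker_type_mapping (worker name -> (worker type, tool types))
- a matching @tool function, which must also be listed in get_worker_tools()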
  16. """
# Cache of WorkerAgent instances, keyed by worker type
# (the first element of each worker_type_mapping value).
_worker_cache: dict[str, "WorkerAgent"] = {}
# Cache of discovered tool lists, keyed by tool type
# (the module-name prefix passed to _get_tools_by_type).
_tools_cache: dict[str, list] = {}
# worker name -> (worker type, tool types)
worker_type_mapping: dict[str, tuple[str, list[str]]] = {
    "sale_worker": ("销售报表查询", ["sale"]),
    "ware_worker": ("库存查询", ["ware"]),
    "money_worker": ("财务查询", ["money"]),
    "price_worker": ("价格查询", ["price"]),
    "salebill_worker": ("销售类单据查询", ["salebill"]),
    "base_data_worker": ("基础资料查询", ["mtrl_data", "cust_data"]),
}
  30. class WorkerAgent:
  31. """Worker Agent - 专门处理特定领域的任务"""
  32. def __init__(self, worker_type: str, tools: list):
  33. self.worker_type = worker_type
  34. self.tools = tools
  35. self.llm = ChatOpenAI(
  36. model=settings.LLM_MODEL,
  37. temperature=settings.LLM_TEMPERATURE,
  38. api_key=settings.DEEPSEEK_API_KEY,
  39. base_url=settings.DEEPSEEK_BASE_URL,
  40. max_tokens=settings.LLM_MAX_TOKENS,
  41. )
  42. self.agent = self._create_agent()
  43. def _create_agent(self):
  44. """创建Agent实例"""
  45. return create_agent(
  46. self.llm,
  47. tools=self.tools,
  48. )
  49. def execute(self, query: str, backend_url: str, token: str) -> any:
  50. """执行任务"""
  51. try:
  52. system_prompt = f"""
  53. 你是{self.worker_type}专家,专门处理{self.worker_type}相关任务。
  54. 职责:
  55. - 接收主Agent分配的具体任务
  56. - 智能选择最合适的工具执行任务
  57. - 理解工具返回的数据结构
  58. - 组织完整的答案向用户回答
  59. 工具用到的参数:
  60. - backend_url: {backend_url}
  61. - token: {token}
  62. 当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
  63. 数据查询结果尽量以 Markdown 表格格式输出
  64. 如果工具有输出格式要求,必须严格按照要求输出。
  65. """
  66. if settings.ECHARTS_ENABLED:
  67. system_prompt = (
  68. system_prompt
  69. + """
  70. 你可以输出echarts柱状图、折线图、饼图。
  71. 饼图格式范例如下:
  72. ```echarts
  73. {{
  74. "title": {{
  75. "text": "浏览器份额", "left": "center" }}
  76. "tooltip": {{
  77. "trigger": "item" }},
  78. "legend": {{
  79. "orient": "vertical", "left": "left" }},
  80. "series": [
  81. {{
  82. "name": "Share",
  83. "type": "pie",
  84. "radius": "55%",
  85. "center": ["50%", "60%"],
  86. "data": [
  87. {{"value": 1048, "name": "Chrome" }}
  88. {{"value": 735, "name": "Firefox" }}
  89. {{"value": 580, "name": "Edge" }}
  90. ]
  91. }}
  92. ]
  93. }}
  94. ```
  95. 柱状图格式范例如下:
  96. ```echarts
  97. {{
  98. "title": {{"text": "每周销量" }}
  99. "tooltip": {{}},
  100. "xAxis": {{"type": "category", "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"] }}
  101. "yAxis": {{"type": "value" }}
  102. "series": [
  103. {{"type": "bar", "data": [120, 200, 150, 80, 70, 110, 130] }}
  104. ]
  105. }}
  106. ```
  107. 折线图,格式范例如下:
  108. ```echarts
  109. {{
  110. "title": {{ "text": "温度趋势" }},
  111. "tooltip": {{ "trigger": "axis" }},
  112. "legend": {{ "data": ["最高", "最低"] }},
  113. "xAxis": {{
  114. "type": "category",
  115. "boundaryGap": false,
  116. "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"]
  117. }},
  118. "yAxis": {{"type": "value" }},
  119. "series": [
  120. {{"name": "最高", "type": "line", "data": [11, 11, 15, 13, 12, 13, 10], "smooth": true }}
  121. {{"name": "最低", "type": "line", "data": [1, -2, 2, 5, 3, 2, 0], "smooth": true }}
  122. ]
  123. }}
  124. ```
  125. """
  126. )
  127. inputs = {
  128. "messages": [
  129. SystemMessage(content=system_prompt),
  130. HumanMessage(content=query),
  131. ]
  132. }
  133. response = self.agent.invoke(inputs)
  134. return response
  135. except Exception as e:
  136. return f"Worker执行失败: {str(e)}"
  137. # 自动发现指定类型的工具
  138. def _get_tools_by_type(tool_type: str) -> list:
  139. """根据类型自动发现工具 - 完整支持编译前后的文件"""
  140. # 检查缓存
  141. if tool_type in _tools_cache:
  142. chat_logger.info(f"使用缓存的工具: {tool_type}")
  143. return _tools_cache[tool_type]
  144. chat_logger.info(f"发现并加载工具: {tool_type}")
  145. import importlib
  146. import inspect
  147. from pathlib import Path
  148. import sys
  149. from langchain.tools import BaseTool
  150. tools = []
  151. # 获取项目根目录并添加到Python路径
  152. project_root = Path(__file__).parent.parent
  153. if str(project_root) not in sys.path:
  154. sys.path.insert(0, str(project_root))
  155. tools_dir = project_root / "tools"
  156. # 导入配置以检查知识库开关
  157. from config.settings import settings
  158. # 扫描工具文件
  159. tool_files = []
  160. # 模式1: 编译前的.py文件
  161. for file_path in tools_dir.glob("*_tools.py"):
  162. if file_path.is_file():
  163. module_name = file_path.stem
  164. # 检查文件名是否以指定类型开头(前缀匹配)
  165. if module_name.startswith(f"{tool_type}_"):
  166. # 如果知识库被禁用,跳过知识库工具
  167. if (
  168. module_name == "knowledge_tools"
  169. and not settings.KNOWLEDGE_BASE_ENABLED
  170. ):
  171. print("知识库功能已禁用,跳过知识库工具")
  172. continue
  173. tool_files.append(module_name)
  174. print(f"发现工具文件: {module_name}")
  175. # 模式2: 编译后的.pyd文件
  176. for file_path in tools_dir.glob("*_tools.cp*.pyd"):
  177. if file_path.is_file():
  178. # 从文件名中提取模块名,如: ware_tools.cp313-win_amd64.pyd -> ware_tools
  179. module_name = file_path.stem.split(".")[0]
  180. # 检查文件名是否以指定类型开头(前缀匹配)
  181. if module_name.startswith(f"{tool_type}_"):
  182. # 如果知识库被禁用,跳过知识库工具
  183. if (
  184. module_name == "knowledge_tools"
  185. and not settings.KNOWLEDGE_BASE_ENABLED
  186. ):
  187. print("知识库功能已禁用,跳过知识库工具")
  188. continue
  189. if module_name not in tool_files: # 避免重复添加
  190. tool_files.append(module_name)
  191. print(f"发现编译后工具文件: {module_name}")
  192. # 如果没有找到指定类型的工具文件,使用默认列表
  193. if not tool_files:
  194. # 根据类型提供默认工具列表
  195. default_tool_mapping = {
  196. "sale": ["sale_tools"],
  197. "ware": ["ware_tools"],
  198. "money": ["money_tools"],
  199. "price": ["price_tools"],
  200. "knowledge": ["knowledge_tools"] if settings.KNOWLEDGE_BASE_ENABLED else [],
  201. }
  202. if tool_type in default_tool_mapping:
  203. tool_files = default_tool_mapping[tool_type]
  204. print(f"使用默认工具列表: {tool_files}")
  205. # 导入模块并获取工具
  206. for module_name in tool_files:
  207. try:
  208. full_module_path = f"tools.{module_name}"
  209. module = importlib.import_module(full_module_path)
  210. # 使用更全面的工具发现方法(参考tool_factory.py)
  211. tool_count = 0
  212. # 方法1: 检查模块的所有属性
  213. for attr_name in dir(module):
  214. if attr_name.startswith("_"):
  215. continue
  216. attr = getattr(module, attr_name)
  217. # 检查是否是BaseTool实例
  218. if isinstance(attr, BaseTool):
  219. tools.append(attr)
  220. tool_count += 1
  221. continue
  222. # 检查是否是函数且具有工具属性
  223. if callable(attr) and hasattr(attr, "name"):
  224. tools.append(attr)
  225. tool_count += 1
  226. # 方法2: 检查模块中是否有get_all_tools函数
  227. if hasattr(module, "get_all_tools"):
  228. module_tools = module.get_all_tools()
  229. if isinstance(module_tools, list):
  230. tools.extend(module_tools)
  231. tool_count += len(module_tools)
  232. if tool_count > 0:
  233. print(f"✅ 从 {module_name} 加载了 {tool_count} 个工具")
  234. else:
  235. print(f"⚠️ {module_name} 中未发现工具函数")
  236. except Exception as e:
  237. print(f"❌ 导入模块 {module_name} 失败: {e}")
  238. # 缓存工具
  239. _tools_cache[tool_type] = tools
  240. return tools
  241. def _get_worker_agent(worker_type: str, tools: list) -> WorkerAgent:
  242. """获取或创建WorkerAgent实例(使用缓存)"""
  243. cache_key = f"{worker_type}" # 使用worker类型作为缓存键
  244. if cache_key not in _worker_cache:
  245. chat_logger.info(f"创建新的WorkerAgent缓存: {worker_type}")
  246. _worker_cache[cache_key] = WorkerAgent(worker_type, tools)
  247. else:
  248. chat_logger.info(f"使用缓存的WorkerAgent: {worker_type}")
  249. return _worker_cache[cache_key]
  250. def _execute_and_get_result(
  251. worker_type: str,
  252. tools: list,
  253. token: str,
  254. backend_url: str,
  255. query: str,
  256. context: str = "",
  257. ) -> str:
  258. """执行worker并获取结果"""
  259. try:
  260. chat_logger.info(f"开始执行worker: {worker_type}")
  261. # worker = WorkerAgent(worker_type, tools)
  262. worker = _get_worker_agent(worker_type, tools)
  263. # 如果有上下文,合并到查询中
  264. # chat_logger.info(f"原始上下文: {context}")
  265. safe_context = safe_context_param(context)
  266. if safe_context:
  267. enhanced_query = f"对话上下文:{safe_context}\n当前问题:{query}"
  268. else:
  269. enhanced_query = query
  270. # chat_logger.info(f"处理后的上下文: {safe_context}")
  271. result = worker.execute(enhanced_query, backend_url, token)
  272. chat_logger.info(f"Worker{worker_type}执行成功:问题{query}")
  273. last_ai_content = ""
  274. all_messages = result["messages"]
  275. for i, msg in enumerate(all_messages):
  276. msg_data = {
  277. "index": i,
  278. "type": getattr(msg, "type", "unknown"),
  279. "content": "",
  280. }
  281. # 获取内容
  282. if hasattr(msg, "content"):
  283. content = msg.content
  284. if isinstance(content, str):
  285. msg_data["content"] = content
  286. else:
  287. msg_data["content"] = str(content)
  288. # 收集AI消息
  289. if msg_data["type"] == "ai":
  290. last_ai_content = msg_data["content"]
  291. return last_ai_content
  292. # ai_messages = [
  293. # msg
  294. # for msg in result.get("messages", [])
  295. # if hasattr(msg, "type") and msg.type == "ai"
  296. # ]
  297. # if ai_messages:
  298. # # 获取最后一条 AI 消息的内容
  299. # last_ai_content = ai_messages[-1].content
  300. # chat_logger.info(
  301. # f"Worker{worker_type}执行成功:last_ai_content={last_ai_content}"
  302. # )
  303. # return last_ai_content
  304. # else:
  305. # return ""
  306. except Exception as e:
  307. chat_logger.error(
  308. f"Worker{worker_type}执行失败:问题{query},错误信息: {str(e)}"
  309. )
  310. return f"Worker{worker_type}执行失败:问题{query},错误信息: {str(e)}"
  311. def init_worker_cache() -> None:
  312. """初始化Worker缓存(在服务启动时调用)"""
  313. chat_logger.info("开始初始化Worker缓存...")
  314. # 直接遍历worker_type_mapping字典来初始化缓存
  315. for worker_name, (worker_type, tool_types) in worker_type_mapping.items():
  316. try:
  317. # 组合工具
  318. tools = []
  319. for tool_type in tool_types:
  320. tools.extend(_get_tools_by_type(tool_type))
  321. if tools:
  322. # 创建WorkerAgent实例并缓存
  323. cache_key = f"{worker_type}"
  324. if cache_key not in _worker_cache:
  325. _worker_cache[cache_key] = WorkerAgent(worker_type, tools)
  326. chat_logger.info(
  327. f"预缓存WorkerAgent: {worker_type} ({worker_name})"
  328. )
  329. else:
  330. chat_logger.info(f"WorkerAgent已缓存: {worker_type}")
  331. else:
  332. chat_logger.info(f"未发现工具,跳过缓存: {worker_type}")
  333. except Exception as e:
  334. chat_logger.info(f"预缓存WorkerAgent失败 {worker_name}: {e}")
  335. chat_logger.info(f"Worker缓存初始化完成,共缓存 {len(_worker_cache)} 个WorkerAgent")
  336. def clear_worker_cache() -> None:
  337. """清除Worker缓存"""
  338. global _worker_cache, _tools_cache
  339. _worker_cache.clear()
  340. _tools_cache.clear()
  341. chat_logger.info("Worker缓存已清除")
  342. def get_worker_cache_stats() -> dict:
  343. """获取Worker缓存统计信息"""
  344. return {
  345. "worker_agents_cached": len(_worker_cache),
  346. "tools_cached": len(_tools_cache),
  347. "cached_worker_types": list(_worker_cache.keys()),
  348. "cached_tool_types": list(_tools_cache.keys()),
  349. }
  350. def get_worker_config(worker_name: str) -> tuple[str, list[str]]:
  351. """根据worker名称获取配置信息"""
  352. if worker_name not in worker_type_mapping:
  353. raise ValueError(f"未知的worker名称: {worker_name}")
  354. return worker_type_mapping[worker_name]
  355. def run_worker(
  356. worker_name: str,
  357. token: str,
  358. backend_url: str,
  359. query: str,
  360. context: str = "",
  361. ) -> str:
  362. worker_type, tool_types = get_worker_config(worker_name)
  363. tools = []
  364. for tool_type in tool_types:
  365. tools.extend(_get_tools_by_type(tool_type))
  366. return _execute_and_get_result(
  367. worker_type, tools, token, backend_url, query, context
  368. )
# Worker tool definitions
@tool
def sale_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
    """
    销售报表查询专家
    职责:销售金额、销售数量、客户销售排名、业务员销售排名、产品型号销售数据等
    使用场景:"2024客户销售业绩排名"、"查2024年产品型号销售排名"、"查看2023年1月1日至2023年12月31日的销售金额"等
    """
    # NOTE: the docstring above is deliberately left in its original language;
    # LangChain's @tool decorator uses it as the tool description.
    # Delegates to run_worker using the "sale_worker" configuration entry.
    return run_worker("sale_worker", token, backend_url, query, context)
@tool
def ware_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
    """
    库存查询专家
    职责:库存管理查询相关问题解答
    使用场景:"查看铜管的库存"、"安吉仓库最多的库存哪个型号"
    """
    # NOTE: the docstring above is deliberately left in its original language;
    # LangChain's @tool decorator uses it as the tool description.
    # Delegates to run_worker using the "ware_worker" configuration entry.
    return run_worker("ware_worker", token, backend_url, query, context)
@tool
def money_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
    """
    财务查询专家
    职责:财务查询相关问题解答
    使用场景:"查询客户A的应收帐情况"、"查询客户A,2025年的收款情况"
    """
    # NOTE: the docstring above is deliberately left in its original language;
    # LangChain's @tool decorator uses it as the tool description.
    # Chart tools can reuse the results of the data-query tools.
    # Delegates to run_worker using the "money_worker" configuration entry.
    return run_worker("money_worker", token, backend_url, query, context)
@tool
def price_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
    """
    价格查询专家
    职责:价格查询相关问题解答
    使用场景:"查询铜管的销售价格"
    """
    # NOTE: the docstring above is deliberately left in its original language;
    # LangChain's @tool decorator uses it as the tool description.
    # Delegates to run_worker using the "price_worker" configuration entry.
    return run_worker("price_worker", token, backend_url, query, context)
@tool
def salebill_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
    """
    销售类单据查询专家
    职责:销售类单据查询相关问题解答
    使用场景:"查询客户A的订单进度"、"查询销售订单SG251119001"
    """
    # NOTE: the docstring above is deliberately left in its original language;
    # LangChain's @tool decorator uses it as the tool description.
    # Delegates to run_worker using the "salebill_worker" configuration entry.
    return run_worker("salebill_worker", token, backend_url, query, context)
@tool
def base_data_worker(
    query: str, backend_url: str, token: str, context: str = ""
) -> str:
    """
    基础资料查询专家
    职责:基础资料查询相关问题解答
    使用场景:"查产品 A388餐椅"、"查一下客户 南华家具"
    """
    # NOTE: the docstring above is deliberately left in its original language;
    # LangChain's @tool decorator uses it as the tool description.
    # Delegates to run_worker using the "base_data_worker" configuration entry,
    # which combines the "mtrl_data" and "cust_data" tool types.
    return run_worker("base_data_worker", token, backend_url, query, context)
  445. def get_worker_tools() -> list:
  446. """获取所有Worker工具"""
  447. return [
  448. sale_worker,
  449. ware_worker,
  450. money_worker,
  451. price_worker,
  452. salebill_worker,
  453. base_data_worker,
  454. ]