from pdb import run import sys from langchain.tools import tool from langchain.agents import create_agent from langchain_openai import ChatOpenAI from typing import Dict, Any, Optional from openai import chat from config.settings import settings from langchain_core.messages import HumanMessage, SystemMessage import datetime from utils.context_helper import safe_context_param from utils.logger import chat_logger """ 增加worker需要修改: worker_type_mapping """ # WorkerAgent缓存字典 _worker_cache = {} # 工具缓存字典 _tools_cache = {} # worker_type, tool_types worker_type_mapping = { "sale_worker": ("销售报表查询", ["sale"]), "ware_worker": ("库存查询", ["ware"]), "money_worker": ("财务查询", ["money"]), "price_worker": ("价格查询", ["price"]), "salebill_worker": ("销售类单据查询", ["salebill"]), "base_data_worker": ("基础资料查询", ["mtrl_data", "cust_data"]), } class WorkerAgent: """Worker Agent - 专门处理特定领域的任务""" def __init__(self, worker_type: str, tools: list): self.worker_type = worker_type self.tools = tools self.llm = ChatOpenAI( model=settings.LLM_MODEL, temperature=settings.LLM_TEMPERATURE, api_key=settings.DEEPSEEK_API_KEY, base_url=settings.DEEPSEEK_BASE_URL, max_tokens=settings.LLM_MAX_TOKENS, ) self.agent = self._create_agent() def _create_agent(self): """创建Agent实例""" return create_agent( self.llm, tools=self.tools, ) def execute(self, query: str, backend_url: str, token: str) -> any: """执行任务""" try: system_prompt = f""" 你是{self.worker_type}专家,专门处理{self.worker_type}相关任务。 职责: - 接收主Agent分配的具体任务 - 智能选择最合适的工具执行任务 - 理解工具返回的数据结构 - 组织完整的答案向用户回答 工具用到的参数: - backend_url: {backend_url} - token: {token} 当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} 数据查询结果尽量以 Markdown 表格格式输出 如果工具有输出格式要求,必须严格按照要求输出。 """ if settings.ECHARTS_ENABLED: system_prompt = ( system_prompt + """ 你可以输出echarts柱状图、折线图、饼图。 饼图格式范例如下: ```echarts {{ "title": {{ "text": "浏览器份额", "left": "center" }} "tooltip": {{ "trigger": "item" }}, "legend": {{ "orient": "vertical", "left": "left" }}, "series": [ {{ "name": "Share", "type": "pie", "radius": "55%", "center": ["50%", "60%"], "data": [ {{"value": 1048, "name": "Chrome" }} {{"value": 735, "name": "Firefox" }} {{"value": 580, "name": "Edge" }} ] }} ] }} ``` 柱状图格式范例如下: ```echarts {{ "title": {{"text": "每周销量" }} "tooltip": {{}}, "xAxis": {{"type": "category", "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"] }} "yAxis": {{"type": "value" }} "series": [ {{"type": "bar", "data": [120, 200, 150, 80, 70, 110, 130] }} ] }} ``` 折线图,格式范例如下: ```echarts {{ "title": {{ "text": "温度趋势" }}, "tooltip": {{ "trigger": "axis" }}, "legend": {{ "data": ["最高", "最低"] }}, "xAxis": {{ "type": "category", "boundaryGap": false, "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"] }}, "yAxis": {{"type": "value" }}, "series": [ {{"name": "最高", "type": "line", "data": [11, 11, 15, 13, 12, 13, 10], "smooth": true }} {{"name": "最低", "type": "line", "data": [1, -2, 2, 5, 3, 2, 0], "smooth": true }} ] }} ``` """ ) inputs = { "messages": [ SystemMessage(content=system_prompt), HumanMessage(content=query), ] } response = self.agent.invoke(inputs) return response except Exception as e: return f"Worker执行失败: {str(e)}" # 自动发现指定类型的工具 def _get_tools_by_type(tool_type: str) -> list: """根据类型自动发现工具 - 完整支持编译前后的文件""" # 检查缓存 if tool_type in _tools_cache: chat_logger.info(f"使用缓存的工具: {tool_type}") return _tools_cache[tool_type] chat_logger.info(f"发现并加载工具: {tool_type}") import importlib import inspect from pathlib import Path import sys from langchain.tools import BaseTool tools = [] # 获取项目根目录并添加到Python路径 project_root = Path(__file__).parent.parent if str(project_root) not in sys.path: sys.path.insert(0, str(project_root)) tools_dir = project_root / "tools" # 导入配置以检查知识库开关 from config.settings import settings # 扫描工具文件 tool_files = [] # 模式1: 编译前的.py文件 for file_path in tools_dir.glob("*_tools.py"): if file_path.is_file(): module_name = file_path.stem # 检查文件名是否以指定类型开头(前缀匹配) if module_name.startswith(f"{tool_type}_"): # 如果知识库被禁用,跳过知识库工具 if ( module_name == "knowledge_tools" and not settings.KNOWLEDGE_BASE_ENABLED ): print("知识库功能已禁用,跳过知识库工具") continue tool_files.append(module_name) print(f"发现工具文件: {module_name}") # 模式2: 编译后的.pyd文件 for file_path in tools_dir.glob("*_tools.cp*.pyd"): if file_path.is_file(): # 从文件名中提取模块名,如: ware_tools.cp313-win_amd64.pyd -> ware_tools module_name = file_path.stem.split(".")[0] # 检查文件名是否以指定类型开头(前缀匹配) if module_name.startswith(f"{tool_type}_"): # 如果知识库被禁用,跳过知识库工具 if ( module_name == "knowledge_tools" and not settings.KNOWLEDGE_BASE_ENABLED ): print("知识库功能已禁用,跳过知识库工具") continue if module_name not in tool_files: # 避免重复添加 tool_files.append(module_name) print(f"发现编译后工具文件: {module_name}") # 如果没有找到指定类型的工具文件,使用默认列表 if not tool_files: # 根据类型提供默认工具列表 default_tool_mapping = { "sale": ["sale_tools"], "ware": ["ware_tools"], "money": ["money_tools"], "price": ["price_tools"], "knowledge": ["knowledge_tools"] if settings.KNOWLEDGE_BASE_ENABLED else [], } if tool_type in default_tool_mapping: tool_files = default_tool_mapping[tool_type] print(f"使用默认工具列表: {tool_files}") # 导入模块并获取工具 for module_name in tool_files: try: full_module_path = f"tools.{module_name}" module = importlib.import_module(full_module_path) # 使用更全面的工具发现方法(参考tool_factory.py) tool_count = 0 # 方法1: 检查模块的所有属性 for attr_name in dir(module): if attr_name.startswith("_"): continue attr = getattr(module, attr_name) # 检查是否是BaseTool实例 if isinstance(attr, BaseTool): tools.append(attr) tool_count += 1 continue # 检查是否是函数且具有工具属性 if callable(attr) and hasattr(attr, "name"): tools.append(attr) tool_count += 1 # 方法2: 检查模块中是否有get_all_tools函数 if hasattr(module, "get_all_tools"): module_tools = module.get_all_tools() if isinstance(module_tools, list): tools.extend(module_tools) tool_count += len(module_tools) if tool_count > 0: print(f"✅ 从 {module_name} 加载了 {tool_count} 个工具") else: print(f"⚠️ {module_name} 中未发现工具函数") except Exception as e: print(f"❌ 导入模块 {module_name} 失败: {e}") # 缓存工具 _tools_cache[tool_type] = tools return tools def _get_worker_agent(worker_type: str, tools: list) -> WorkerAgent: """获取或创建WorkerAgent实例(使用缓存)""" cache_key = f"{worker_type}" # 使用worker类型作为缓存键 if cache_key not in _worker_cache: chat_logger.info(f"创建新的WorkerAgent缓存: {worker_type}") _worker_cache[cache_key] = WorkerAgent(worker_type, tools) else: chat_logger.info(f"使用缓存的WorkerAgent: {worker_type}") return _worker_cache[cache_key] def _execute_and_get_result( worker_type: str, tools: list, token: str, backend_url: str, query: str, context: str = "", ) -> str: """执行worker并获取结果""" try: chat_logger.info(f"开始执行worker: {worker_type}") # worker = WorkerAgent(worker_type, tools) worker = _get_worker_agent(worker_type, tools) # 如果有上下文,合并到查询中 # chat_logger.info(f"原始上下文: {context}") safe_context = safe_context_param(context) if safe_context: enhanced_query = f"对话上下文:{safe_context}\n当前问题:{query}" else: enhanced_query = query # chat_logger.info(f"处理后的上下文: {safe_context}") result = worker.execute(enhanced_query, backend_url, token) chat_logger.info(f"Worker{worker_type}执行成功:问题{query}") ai_messages = [ msg for msg in result.get("messages", []) if hasattr(msg, "type") and msg.type == "ai" ] if ai_messages: # 获取最后一条 AI 消息的内容 last_ai_content = ai_messages[-1].content return last_ai_content else: return "" except Exception as e: chat_logger.error( f"Worker{worker_type}执行失败:问题{query},错误信息: {str(e)}" ) return f"Worker{worker_type}执行失败:问题{query},错误信息: {str(e)}" def init_worker_cache() -> None: """初始化Worker缓存(在服务启动时调用)""" chat_logger.info("开始初始化Worker缓存...") # 直接遍历worker_type_mapping字典来初始化缓存 for worker_name, (worker_type, tool_types) in worker_type_mapping.items(): try: # 组合工具 tools = [] for tool_type in tool_types: tools.extend(_get_tools_by_type(tool_type)) if tools: # 创建WorkerAgent实例并缓存 cache_key = f"{worker_type}" if cache_key not in _worker_cache: _worker_cache[cache_key] = WorkerAgent(worker_type, tools) chat_logger.info( f"预缓存WorkerAgent: {worker_type} ({worker_name})" ) else: chat_logger.info(f"WorkerAgent已缓存: {worker_type}") else: chat_logger.info(f"未发现工具,跳过缓存: {worker_type}") except Exception as e: chat_logger.info(f"预缓存WorkerAgent失败 {worker_name}: {e}") chat_logger.info(f"Worker缓存初始化完成,共缓存 {len(_worker_cache)} 个WorkerAgent") def clear_worker_cache() -> None: """清除Worker缓存""" global _worker_cache, _tools_cache _worker_cache.clear() _tools_cache.clear() chat_logger.info("Worker缓存已清除") def get_worker_cache_stats() -> dict: """获取Worker缓存统计信息""" return { "worker_agents_cached": len(_worker_cache), "tools_cached": len(_tools_cache), "cached_worker_types": list(_worker_cache.keys()), "cached_tool_types": list(_tools_cache.keys()), } def get_worker_config(worker_name: str) -> tuple[str, list[str]]: """根据worker名称获取配置信息""" if worker_name not in worker_type_mapping: raise ValueError(f"未知的worker名称: {worker_name}") return worker_type_mapping[worker_name] def run_worker( worker_name: str, token: str, backend_url: str, query: str, context: str = "", ) -> str: worker_type, tool_types = get_worker_config(worker_name) tools = [] for tool_type in tool_types: tools.extend(_get_tools_by_type(tool_type)) return _execute_and_get_result( worker_type, tools, token, backend_url, query, context ) # Worker定义 @tool def sale_worker(query: str, backend_url: str, token: str, context: str = "") -> str: """ 销售报表查询专家 职责:销售金额、销售数量、客户销售排名、业务员销售排名、产品型号销售数据等 使用场景:"2024客户销售业绩排名"、"查2024年产品型号销售排名"、"查看2023年1月1日至2023年12月31日的销售金额"等 """ # tools = _get_tools_by_type("sale") # 即 sale_tools.py下的所有工具 # return _execute_and_get_result( # "销售报表查询", tools, token, backend_url, query, context # ) return run_worker("sale_worker", token, backend_url, query, context) @tool def ware_worker(query: str, backend_url: str, token: str, context: str = "") -> str: """ 库存查询专家 职责:库存管理查询相关问题解答 使用场景:"查看铜管的库存"、"安吉仓库最多的库存哪个型号" """ # tools = _get_tools_by_type("ware") # 即 ware_tools.py下的所有工具 # return _execute_and_get_result( # "库存查询", tools, token, backend_url, query, context # ) return run_worker("ware_worker", token, backend_url, query, context) @tool def money_worker(query: str, backend_url: str, token: str, context: str = "") -> str: """ 财务查询专家 职责:财务查询相关问题解答 使用场景:"查询客户A的应收帐情况"、"查询客户A,2025年的收款情况" """ # 图表工具可以复用数据查询工具的结果 # tools = _get_tools_by_type("money") # 即 money_tools.py下的所有工具 # return _execute_and_get_result( # "财务查询", tools, token, backend_url, query, context # ) return run_worker("money_worker", token, backend_url, query, context) @tool def price_worker(query: str, backend_url: str, token: str, context: str = "") -> str: """ 价格查询专家 职责:价格查询相关问题解答 使用场景:"查询铜管的销售价格" """ # tools = _get_tools_by_type("price") # return _execute_and_get_result( # "价格查询", tools, token, backend_url, query, context # ) return run_worker("price_worker", token, backend_url, query, context) @tool def salebill_worker(query: str, backend_url: str, token: str, context: str = "") -> str: """ 销售类单据查询专家 职责:销售类单据查询相关问题解答 使用场景:"查询客户A的订单进度"、"查询销售订单SG251119001" """ # tools = _get_tools_by_type("salebill") # return _execute_and_get_result( # "销售类单据查询", tools, token, backend_url, query, context # ) return run_worker("salebill_worker", token, backend_url, query, context) @tool def base_data_worker( query: str, backend_url: str, token: str, context: str = "" ) -> str: """ 基础资料查询专家 职责:基础资料查询相关问题解答 使用场景:"查产品 A388餐椅"、"查一下客户 南华家具" """ # tools = _get_tools_by_type("mtrl_data") + _get_tools_by_type("cust_data") # return _execute_and_get_result( # "基础资料查询", tools, token, backend_url, query, context # ) return run_worker("base_data_worker", token, backend_url, query, context) def get_worker_tools() -> list: """获取所有Worker工具""" return [ sale_worker, ware_worker, money_worker, price_worker, salebill_worker, base_data_worker, ]