| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549 |
- from pdb import run
- import sys
- from langchain.tools import tool
- from langchain.agents import create_agent
- from langchain_openai import ChatOpenAI
- from typing import Dict, Any, Optional
- from openai import chat
- from config.settings import settings
- from langchain_core.messages import HumanMessage, SystemMessage
- import datetime
- from utils.context_helper import safe_context_param
- from utils.logger import chat_logger
- """
- 增加worker需要修改:
- worker_type_mapping
- """
# Cache of WorkerAgent instances, keyed by the worker's display type.
_worker_cache: dict[str, "WorkerAgent"] = {}

# Cache of discovered tool lists, keyed by tool type (module-name prefix).
_tools_cache: dict[str, list] = {}

# Registry of workers: worker name -> (worker display type, tool types to load).
# Adding a new worker requires adding an entry here.
worker_type_mapping: dict[str, tuple[str, list[str]]] = {
    "sale_worker": (
        "销售报表查询",
        ["sale"],
    ),
    "ware_worker": (
        "库存查询",
        ["ware"],
    ),
    "money_worker": (
        "财务查询",
        ["money"],
    ),
    "price_worker": (
        "价格查询",
        ["price"],
    ),
    "salebill_worker": (
        "销售类单据查询",
        ["salebill"],
    ),
    "base_data_worker": (
        "基础资料查询",
        ["mtrl_data", "cust_data"],
    ),
}
# Chart examples appended to the system prompt when ECharts output is enabled.
# This is a plain (non-f) string, so braces are written once; doubling them
# would send literal "{{" / "}}" to the model. Every example is valid JSON.
_ECHARTS_PROMPT = """
你可以输出echarts柱状图、折线图、饼图。
饼图格式范例如下:
```echarts
{
  "title": {"text": "浏览器份额", "left": "center"},
  "tooltip": {"trigger": "item"},
  "legend": {"orient": "vertical", "left": "left"},
  "series": [
    {
      "name": "Share",
      "type": "pie",
      "radius": "55%",
      "center": ["50%", "60%"],
      "data": [
        {"value": 1048, "name": "Chrome"},
        {"value": 735, "name": "Firefox"},
        {"value": 580, "name": "Edge"}
      ]
    }
  ]
}
```
柱状图格式范例如下:
```echarts
{
  "title": {"text": "每周销量"},
  "tooltip": {},
  "xAxis": {"type": "category", "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"]},
  "yAxis": {"type": "value"},
  "series": [
    {"type": "bar", "data": [120, 200, 150, 80, 70, 110, 130]}
  ]
}
```
折线图,格式范例如下:
```echarts
{
  "title": {"text": "温度趋势"},
  "tooltip": {"trigger": "axis"},
  "legend": {"data": ["最高", "最低"]},
  "xAxis": {
    "type": "category",
    "boundaryGap": false,
    "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"]
  },
  "yAxis": {"type": "value"},
  "series": [
    {"name": "最高", "type": "line", "data": [11, 11, 15, 13, 12, 13, 10], "smooth": true},
    {"name": "最低", "type": "line", "data": [1, -2, 2, 5, 3, 2, 0], "smooth": true}
  ]
}
```
"""


class WorkerAgent:
    """Worker agent dedicated to one business domain.

    Wraps a chat model and an agent built from the given tools; each call to
    :meth:`execute` sends a freshly built system prompt plus the user query.
    """

    def __init__(self, worker_type: str, tools: list):
        """Create the chat model and the agent.

        Args:
            worker_type: Display name of the domain; interpolated into the
                system prompt.
            tools: Tool objects handed to the agent.
        """
        self.worker_type = worker_type
        self.tools = tools
        self.llm = ChatOpenAI(
            model=settings.LLM_MODEL,
            temperature=settings.LLM_TEMPERATURE,
            api_key=settings.DEEPSEEK_API_KEY,
            base_url=settings.DEEPSEEK_BASE_URL,
            max_tokens=settings.LLM_MAX_TOKENS,
        )
        self.agent = self._create_agent()

    def _create_agent(self):
        """Build the agent from this worker's model and tools."""
        return create_agent(
            self.llm,
            tools=self.tools,
        )

    @staticmethod
    def _build_system_prompt(
        worker_type: str,
        backend_url: str,
        token: str,
        echarts_enabled: bool,
        now: Optional[datetime.datetime] = None,
    ) -> str:
        """Return the system prompt for one execution.

        Args:
            worker_type: Domain name shown to the model.
            backend_url: Backend URL the tools are told to use.
            token: Auth token the tools are told to use.
            echarts_enabled: Append the ECharts output examples when true.
            now: Timestamp to embed; defaults to the current local time.
        """
        if now is None:
            now = datetime.datetime.now()
        timestamp = now.strftime("%Y-%m-%d %H:%M:%S")
        prompt = f"""
你是{worker_type}专家,专门处理{worker_type}相关任务。
职责:
- 接收主Agent分配的具体任务
- 智能选择最合适的工具执行任务
- 理解工具返回的数据结构
- 组织完整的答案向用户回答
工具用到的参数:
- backend_url: {backend_url}
- token: {token}
当前时间:{timestamp}
数据查询结果尽量以 Markdown 表格格式输出
如果工具有输出格式要求,必须严格按照要求输出。
"""
        if echarts_enabled:
            prompt += _ECHARTS_PROMPT
        return prompt

    def execute(self, query: str, backend_url: str, token: str) -> Any:
        """Run the agent on ``query``.

        Args:
            query: The user question (optionally already merged with context).
            backend_url: Backend URL passed to the tools via the prompt.
            token: Auth token passed to the tools via the prompt.

        Returns:
            Whatever ``self.agent.invoke`` returns on success. On any
            exception, a ``"Worker执行失败: ..."`` string is returned instead
            of raising, so callers must check for a ``str`` result.
        """
        try:
            system_prompt = self._build_system_prompt(
                self.worker_type,
                backend_url,
                token,
                settings.ECHARTS_ENABLED,
            )
            inputs = {
                "messages": [
                    SystemMessage(content=system_prompt),
                    HumanMessage(content=query),
                ]
            }
            return self.agent.invoke(inputs)
        except Exception as e:
            # Log here as well: the failure is otherwise only visible as a
            # returned string that callers may not inspect.
            chat_logger.error(f"Worker执行失败: {str(e)}")
            return f"Worker执行失败: {str(e)}"
# Auto-discover the tools that belong to one tool type.
def _get_tools_by_type(tool_type: str) -> list:
    """Discover, load and cache the tools of one tool type.

    Modules are looked up in ``<project_root>/tools`` (the project root is the
    parent of this file's directory). A module belongs to ``tool_type`` when
    its name starts with ``f"{tool_type}_"``. Both plain ``*_tools.py`` files
    and compiled ``*_tools.cp*.pyd`` files are considered. When nothing is
    found on disk, a built-in default module list is tried instead.

    Args:
        tool_type: Module-name prefix, e.g. ``"sale"`` for ``sale_tools``.

    Returns:
        The list of discovered tool objects (possibly empty). The same list
        object is cached in ``_tools_cache`` and returned on later calls.
    """
    if tool_type in _tools_cache:
        chat_logger.info(f"使用缓存的工具: {tool_type}")
        return _tools_cache[tool_type]
    chat_logger.info(f"发现并加载工具: {tool_type}")

    import importlib
    import inspect
    import sys
    from pathlib import Path

    from langchain.tools import BaseTool

    # Settings are needed for the knowledge-base switch.
    from config.settings import settings

    # Make the project root importable so that "tools.<module>" resolves.
    project_root = Path(__file__).parent.parent
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    tools_dir = project_root / "tools"

    prefix = f"{tool_type}_"
    # (glob pattern, is compiled extension, log label)
    scan_patterns = (
        ("*_tools.py", False, "发现工具文件"),
        ("*_tools.cp*.pyd", True, "发现编译后工具文件"),
    )

    tool_files = []
    for pattern, is_compiled, found_label in scan_patterns:
        for file_path in tools_dir.glob(pattern):
            if not file_path.is_file():
                continue
            module_name = file_path.stem
            if is_compiled:
                # e.g. ware_tools.cp313-win_amd64.pyd -> ware_tools
                module_name = module_name.split(".")[0]
            if not module_name.startswith(prefix):
                continue
            # Skip the knowledge-base tools when that feature is disabled.
            if (
                module_name == "knowledge_tools"
                and not settings.KNOWLEDGE_BASE_ENABLED
            ):
                chat_logger.info("知识库功能已禁用,跳过知识库工具")
                continue
            # A module may exist both as source and as compiled extension.
            if module_name in tool_files:
                continue
            tool_files.append(module_name)
            chat_logger.info(f"{found_label}: {module_name}")

    # Nothing on disk: fall back to the default module list for this type.
    if not tool_files:
        default_tool_mapping = {
            "sale": ["sale_tools"],
            "ware": ["ware_tools"],
            "money": ["money_tools"],
            "price": ["price_tools"],
            "knowledge": ["knowledge_tools"] if settings.KNOWLEDGE_BASE_ENABLED else [],
        }
        if tool_type in default_tool_mapping:
            tool_files = default_tool_mapping[tool_type]
            chat_logger.info(f"使用默认工具列表: {tool_files}")

    tools = []
    seen_ids = set()

    def register(candidate) -> bool:
        """Append ``candidate`` unless the very same object was already added."""
        if id(candidate) in seen_ids:
            return False
        seen_ids.add(id(candidate))
        tools.append(candidate)
        return True

    for module_name in tool_files:
        try:
            module = importlib.import_module(f"tools.{module_name}")
            tool_count = 0

            # Method 1: inspect every public attribute of the module.
            for attr_name in dir(module):
                if attr_name.startswith("_"):
                    continue
                attr = getattr(module, attr_name)
                is_tool = isinstance(attr, BaseTool) or (
                    # Tool-like callable. Classes are excluded: an imported
                    # class such as pathlib.Path is callable and has a
                    # ``name`` attribute, but is not a tool.
                    callable(attr)
                    and hasattr(attr, "name")
                    and not inspect.isclass(attr)
                )
                if is_tool and register(attr):
                    tool_count += 1

            # Method 2: an explicit get_all_tools() factory. Objects already
            # found by method 1 are not registered a second time.
            if hasattr(module, "get_all_tools"):
                module_tools = module.get_all_tools()
                if isinstance(module_tools, list):
                    for module_tool in module_tools:
                        if register(module_tool):
                            tool_count += 1

            if tool_count > 0:
                chat_logger.info(f"✅ 从 {module_name} 加载了 {tool_count} 个工具")
            else:
                chat_logger.info(f"⚠️ {module_name} 中未发现工具函数")
        except Exception as e:
            # Best effort: one broken module must not prevent the others
            # from loading.
            chat_logger.error(f"❌ 导入模块 {module_name} 失败: {e}")

    _tools_cache[tool_type] = tools
    return tools
def _get_worker_agent(worker_type: str, tools: list) -> WorkerAgent:
    """Return the cached WorkerAgent for ``worker_type``, creating it on first use.

    The cache is keyed by worker type only; ``tools`` is used solely when a
    new agent has to be built.
    """
    key = f"{worker_type}"

    # Fast path: an agent for this worker type already exists.
    if key in _worker_cache:
        chat_logger.info(f"使用缓存的WorkerAgent: {worker_type}")
        return _worker_cache[key]

    chat_logger.info(f"创建新的WorkerAgent缓存: {worker_type}")
    agent = WorkerAgent(worker_type, tools)
    _worker_cache[key] = agent
    return agent
def _execute_and_get_result(
    worker_type: str,
    tools: list,
    token: str,
    backend_url: str,
    query: str,
    context: str = "",
) -> str:
    """Run a worker on ``query`` and return the text of its last AI message.

    Args:
        worker_type: Worker display type; also the agent cache key.
        tools: Tools used if the agent has to be created.
        token: Auth token forwarded to the worker.
        backend_url: Backend URL forwarded to the worker.
        query: The user's question.
        context: Optional conversation context; when non-empty after
            ``safe_context_param`` it is prepended to the query.

    Returns:
        The content of the last message whose ``type`` is ``"ai"`` (``""``
        when there is none), or a ``"Worker...执行失败..."`` message when the
        worker fails. This function does not raise.
    """
    try:
        chat_logger.info(f"开始执行worker: {worker_type}")
        worker = _get_worker_agent(worker_type, tools)

        # Merge the conversation context into the query when there is one.
        safe_context = safe_context_param(context)
        if safe_context:
            enhanced_query = f"对话上下文:{safe_context}\n当前问题:{query}"
        else:
            enhanced_query = query

        result = worker.execute(enhanced_query, backend_url, token)

        # WorkerAgent.execute reports failures as a plain string instead of
        # raising. Surface that message; indexing the string below would
        # replace it with an unrelated TypeError.
        if isinstance(result, str):
            chat_logger.error(
                f"Worker{worker_type}执行失败:问题{query},错误信息: {result}"
            )
            return f"Worker{worker_type}执行失败:问题{query},错误信息: {result}"

        chat_logger.info(f"Worker{worker_type}执行成功:问题{query}")

        last_ai_content = ""
        for msg in result["messages"]:
            if getattr(msg, "type", "unknown") != "ai":
                continue
            content = getattr(msg, "content", "")
            # Non-string content is stringified.
            last_ai_content = content if isinstance(content, str) else str(content)
        return last_ai_content
    except Exception as e:
        chat_logger.error(
            f"Worker{worker_type}执行失败:问题{query},错误信息: {str(e)}"
        )
        return f"Worker{worker_type}执行失败:问题{query},错误信息: {str(e)}"
def init_worker_cache() -> None:
    """Pre-build and cache a WorkerAgent for every registered worker.

    Intended to be called once at service start-up. Workers for which no
    tools are discovered are skipped. A failure for one worker is logged as
    an error and does not stop the others from being initialised.
    """
    chat_logger.info("开始初始化Worker缓存...")
    for worker_name, (worker_type, tool_types) in worker_type_mapping.items():
        try:
            # Combine the tools of every tool type this worker uses.
            tools = []
            for tool_type in tool_types:
                tools.extend(_get_tools_by_type(tool_type))

            if not tools:
                chat_logger.info(f"未发现工具,跳过缓存: {worker_type}")
                continue

            cache_key = f"{worker_type}"
            if cache_key in _worker_cache:
                chat_logger.info(f"WorkerAgent已缓存: {worker_type}")
                continue

            _worker_cache[cache_key] = WorkerAgent(worker_type, tools)
            chat_logger.info(
                f"预缓存WorkerAgent: {worker_type} ({worker_name})"
            )
        except Exception as e:
            # Failures are errors, not informational events.
            chat_logger.error(f"预缓存WorkerAgent失败 {worker_name}: {e}")
    chat_logger.info(f"Worker缓存初始化完成,共缓存 {len(_worker_cache)} 个WorkerAgent")
def clear_worker_cache() -> None:
    """Empty both the WorkerAgent cache and the tool cache in place."""
    # The dicts are mutated, never rebound, so no ``global`` declaration is needed.
    for cache in (_worker_cache, _tools_cache):
        cache.clear()
    chat_logger.info("Worker缓存已清除")
def get_worker_cache_stats() -> dict:
    """Return the sizes and keys of the WorkerAgent cache and the tool cache."""
    worker_types = list(_worker_cache.keys())
    tool_types = list(_tools_cache.keys())
    stats = {
        "worker_agents_cached": len(_worker_cache),
        "tools_cached": len(_tools_cache),
        "cached_worker_types": worker_types,
        "cached_tool_types": tool_types,
    }
    return stats
def get_worker_config(worker_name: str) -> tuple[str, list[str]]:
    """Look up ``(worker_type, tool_types)`` for a registered worker name.

    Raises:
        ValueError: If ``worker_name`` is not a key of ``worker_type_mapping``.
    """
    config = worker_type_mapping.get(worker_name)
    if config is None:
        raise ValueError(f"未知的worker名称: {worker_name}")
    return config
def run_worker(
    worker_name: str,
    token: str,
    backend_url: str,
    query: str,
    context: str = "",
) -> str:
    """Resolve a worker by name, gather its tools and execute the query.

    Raises:
        ValueError: Propagated from ``get_worker_config`` for an unknown name.
    """
    worker_type, tool_types = get_worker_config(worker_name)
    # Flatten the tools of every tool type configured for this worker.
    combined_tools = [
        loaded_tool
        for tool_type in tool_types
        for loaded_tool in _get_tools_by_type(tool_type)
    ]
    return _execute_and_get_result(
        worker_type,
        combined_tools,
        token,
        backend_url,
        query,
        context,
    )
# Worker definitions
# NOTE(review): the docstring is kept verbatim; presumably `@tool` exposes it
# as the tool description shown to the model, so it is runtime text — confirm
# before editing.
@tool
def sale_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
    """
    销售报表查询专家
    职责:销售金额、销售数量、客户销售排名、业务员销售排名、产品型号销售数据等
    使用场景:"2024客户销售业绩排名"、"查2024年产品型号销售排名"、"查看2023年1月1日至2023年12月31日的销售金额"等
    """
    return run_worker(
        "sale_worker",
        token=token,
        backend_url=backend_url,
        query=query,
        context=context,
    )
# NOTE(review): the docstring is kept verbatim; presumably `@tool` exposes it
# as the tool description shown to the model, so it is runtime text — confirm
# before editing.
@tool
def ware_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
    """
    库存查询专家
    职责:库存管理查询相关问题解答
    使用场景:"查看铜管的库存"、"安吉仓库最多的库存哪个型号"
    """
    return run_worker(
        "ware_worker",
        token=token,
        backend_url=backend_url,
        query=query,
        context=context,
    )
# NOTE(review): the docstring is kept verbatim; presumably `@tool` exposes it
# as the tool description shown to the model, so it is runtime text — confirm
# before editing.
@tool
def money_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
    """
    财务查询专家
    职责:财务查询相关问题解答
    使用场景:"查询客户A的应收帐情况"、"查询客户A,2025年的收款情况"
    """
    return run_worker(
        "money_worker",
        token=token,
        backend_url=backend_url,
        query=query,
        context=context,
    )
# NOTE(review): the docstring is kept verbatim; presumably `@tool` exposes it
# as the tool description shown to the model, so it is runtime text — confirm
# before editing.
@tool
def price_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
    """
    价格查询专家
    职责:价格查询相关问题解答
    使用场景:"查询铜管的销售价格"
    """
    return run_worker(
        "price_worker",
        token=token,
        backend_url=backend_url,
        query=query,
        context=context,
    )
# NOTE(review): the docstring is kept verbatim; presumably `@tool` exposes it
# as the tool description shown to the model, so it is runtime text — confirm
# before editing.
@tool
def salebill_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
    """
    销售类单据查询专家
    职责:销售类单据查询相关问题解答
    使用场景:"查询客户A的订单进度"、"查询销售订单SG251119001"
    """
    return run_worker(
        "salebill_worker",
        token=token,
        backend_url=backend_url,
        query=query,
        context=context,
    )
# NOTE(review): the docstring is kept verbatim; presumably `@tool` exposes it
# as the tool description shown to the model, so it is runtime text — confirm
# before editing.
@tool
def base_data_worker(
    query: str, backend_url: str, token: str, context: str = ""
) -> str:
    """
    基础资料查询专家
    职责:基础资料查询相关问题解答
    使用场景:"查产品 A388餐椅"、"查一下客户 南华家具"
    """
    return run_worker(
        "base_data_worker",
        token=token,
        backend_url=backend_url,
        query=query,
        context=context,
    )
def get_worker_tools() -> list:
    """Return a new list containing every worker tool defined in this module."""
    worker_tools = [
        sale_worker,
        ware_worker,
        money_worker,
        price_worker,
        salebill_worker,
        base_data_worker,
    ]
    return worker_tools
|