Sfoglia il codice sorgente

逻辑架构升级为主Agent带各部门的专家副Agent

longjoedyy 3 settimane fa
parent
commit
d6a539f8e1

+ 11 - 5
api/routes.py

@@ -378,6 +378,7 @@ async def root():
 async def batch_calculate_vectors_endpoint(requests: List[ImageVectorRequest]):
     """批量计算图片特征向量"""
     try:
+        chat_logger.info("以图搜图:开始批量计算图片特征向量")
         # 构建请求数据
         image_items = []
         for req in requests:
@@ -400,7 +401,7 @@ async def batch_calculate_vectors_endpoint(requests: List[ImageVectorRequest]):
         return responses
 
     except Exception as e:
-        chat_logger.error(f"批量计算图片特征向量失败: {str(e)}")
+        chat_logger.error(f"以图搜图:批量计算图片特征向量失败: {str(e)}")
         raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
 
 
@@ -408,6 +409,8 @@ async def batch_calculate_vectors_endpoint(requests: List[ImageVectorRequest]):
 async def build_index_endpoint(request: BuildIndexRequest):
     """构建索引及映射关系"""
     try:
+        chat_logger.info("以图搜图:开始构建索引")
+
         # 构建请求数据
         image_vectors = []
         for item in request.image_vectors:
@@ -429,7 +432,7 @@ async def build_index_endpoint(request: BuildIndexRequest):
         return response
 
     except Exception as e:
-        chat_logger.error(f"构建索引失败: {str(e)}")
+        chat_logger.error(f"以图搜图:构建索引失败: {str(e)}")
         raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
 
 
@@ -437,6 +440,7 @@ async def build_index_endpoint(request: BuildIndexRequest):
 async def search_endpoint(request: SearchRequest):
     """搜索相似图片(支持以图搜图和以文搜图)"""
     try:
+        chat_logger.info(f"以图搜图:开始搜索")
         import time
 
         start_time = time.time()
@@ -487,7 +491,7 @@ async def search_endpoint(request: SearchRequest):
     except HTTPException:
         raise
     except Exception as e:
-        chat_logger.error(f"搜索失败: {str(e)}")
+        chat_logger.error(f"以图搜图:搜索失败: {str(e)}")
         raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
 
 
@@ -495,11 +499,12 @@ async def search_endpoint(request: SearchRequest):
 async def get_index_status_endpoint():
     """获取索引状态"""
     try:
+        chat_logger.info("以图搜图:开始获取索引状态")
         status = await image_search_service.get_index_status()
         return {"success": True, "status": status}
 
     except Exception as e:
-        chat_logger.error(f"获取索引状态失败: {str(e)}")
+        chat_logger.error(f"以图搜图:获取索引状态失败: {str(e)}")
         raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
 
 
@@ -507,9 +512,10 @@ async def get_index_status_endpoint():
 async def clear_index_endpoint():
     """清空索引"""
     try:
+        chat_logger.info("以图搜图:开始清空索引")
         await image_search_service.clear_index()
         return {"success": True, "message": "索引已清空"}
 
     except Exception as e:
-        chat_logger.error(f"清空索引失败: {str(e)}")
+        chat_logger.error(f"以图搜图:清空索引失败: {str(e)}")
         raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")

+ 7 - 4
config/tool_config.json

@@ -17,7 +17,7 @@
         "使用示例": "用户输入:'查看铜管的库存' -> 系统调用此工具获取铜管的库存信息"
     },
     "get_sale_amt": {
-        "基础描述": "获取指定时间范围的销售金额,可指定不同的汇总方式",
+        "基础描述": "获取指定时间范围的销售金额,可指定不同的汇总方式,如果未指定汇总方式,默认按月汇总,如果未指定时间范围,默认最近3个月",
         "入参说明": {
             "backend_url": "后端API地址",
             "token": "认证令牌",
@@ -243,7 +243,7 @@
         "使用示例": "查产品 A388餐椅"
     },
     "get_saler_ranking": {
-        "基础描述": "业务员销售排行统计表",
+        "基础描述": "业务员销售排行统计表,如果未指定时间范围,默认最近3个月",
         "入参说明": {
             "backend_url": "后端API地址",
             "token": "认证令牌",
@@ -255,6 +255,7 @@
         },
         "输出格式要求": [
             "# 业务员销售排行统计表",
+            "时间范围:{beginmonth}到{endmonth}",
             "##国内",
             "(输出国内业务员销售额echarts饼图,如果图表功能可用)",
             "| 业务员 | 销售额 |  占比  |",
@@ -275,7 +276,7 @@
         ]
     },
     "get_cust_ranking": {
-        "基础描述": "客户销售排名统计表",
+        "基础描述": "客户销售排名统计表,如果未指定时间范围,默认最近3个月",
         "入参说明": {
             "backend_url": "后端API地址",
             "token": "认证令牌",
@@ -287,6 +288,7 @@
         },
         "输出格式要求": [
             "# 客户销售排名统计表",
+            "时间范围:{beginmonth}到{endmonth}",
             "##国内",
             "(输出国内客户销售额echarts饼图,如果图表功能可用)",
             "| 客户 | 销售额 |  占比  |",
@@ -306,7 +308,7 @@
         ]
     },
     "get_mtrl_ranking": {
-        "基础描述": "产品型号销售统计表",
+        "基础描述": "产品型号销售统计表,如果未指定时间范围,默认最近3个月",
         "入参说明": {
             "backend_url": "后端API地址",
             "token": "认证令牌",
@@ -318,6 +320,7 @@
         },
         "输出格式要求": [
             "# 产品型号销售排名统计表",
+            "时间范围:{beginmonth}到{endmonth}",
             "##国内",
             "(输出国内产品型号销售额echarts饼图,如果图表功能可用)",
             "| 产品型号 | 销量 | 销售额 |  金额占比  |",

+ 145 - 218
core/agent.py

@@ -11,6 +11,7 @@ from langchain_core.messages import (
     BaseMessage,
     trim_messages,
 )
+from openai import chat
 from tools.tool_factory import get_all_tools
 from langchain_core.runnables import RunnableConfig
 from langchain.agents.middleware import before_model
@@ -21,149 +22,134 @@ from langgraph.graph.message import REMOVE_ALL_MESSAGES
 import sqlite3
 from config.settings import settings
 from langchain_core.messages.utils import count_tokens_approximately
+from core.worker_manager import get_worker_tools
+from utils.context_helper import safe_context_param
+from utils.logger import chat_logger
 
 dotenv.load_dotenv()
 
 
 def create_system_prompt(
-    backend_url: str = "", token: str = "", username: str = "default"
+    backend_url: str = "", token: str = "", username: str = "default", context: str = ""
 ) -> str:
-    auth_status = "已认证" if token else "未认证"
-    backend_available = "API可用" if backend_url and token else "仅数据查询"
-    knowledge_status = (
-        "知识库可用" if settings.KNOWLEDGE_BASE_ENABLED else "知识库已禁用"
-    )
-    echart_status = "图表可用" if settings.ECHARTS_ENABLED else "图表已禁用"
-
-    if settings.KNOWLEDGE_BASE_ENABLED:
-        # 知识库启用时的提示词
-        system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status} 图表:{echart_status}
-    职责:ERP数据查询和问题解答,按用户语言回答。
-
-    **核心安全指令 (必遵)**:
-    1.  **当前凭据 (每次工具调用必须使用)**:
-        - 后端地址: {backend_url if backend_url else '无'}
-        - API令牌: {token if token else '无'}
-    2.  **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
-    3.  **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
-    工作流:
-    1. 分析问题意图,提取模块关键词
-    2. 如果是数据查询类问题,直接调用相关工具查询数据
-    3. 如果是其他问题,则通过工具搜索知识库,知识库工具使用流程:a.通过关键字获取相关文章列表,b.判断哪些文章最符合,c.再通过工具获取文章内容.严格按文章内容回复,不能编造答案.
-    4. 关键词要精准,避免无意义词
-    工具调用规格:
-    - 如果连续3次调用相同工具相同参数,自动停止
-    - 工具返回相同结果但仍在重复调用时,自动停止
-    回答规则:
-    - 知识库找不到时提示"正在学习该问题"
-    - {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询和知识库支持"}
-    - 保护隐私,专业准确,精炼简要
-    时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
-    数据查询结果尽量以 Markdown 表格格式输出,格式如下:
-    | 列名1 | 列名2 | 列名3 |
-    | :--- | :--- | :--- |
-    | 数据1 | 数据2 | 数据3 |
-    | 数据4 | 数据5 | 数据6 |
-    """
-    else:
-        # 知识库禁用时的提示词 - 灵活处理工具返回结果
-        system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status} 图表:{echart_status}
-    职责:处理ERP数据查询类问题,按用户语言回答。
-    **核心安全指令 (必遵)**:
-    1.  **当前凭据 (每次工具调用必须使用)**:
-        - 后端地址: {backend_url if backend_url else '无'}
-        - API令牌: {token if token else '无'}
-    2.  **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
-    3.  **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
-    工作流:
-    1. 分析问题意图,判断是否为数据查询类问题
-    2. 如果是数据查询类问题,直接调用相关工具查询数据
-    3. 根据工具返回的结果进行回答:
-    - 如果工具返回了具体数据,按数据内容回答
-    - 如果工具返回了错误信息(如"API返回错误","查询失败","没有权限"等),如实告知用户错误信息
-    - 如果工具返回空数据或"未找到数据",如实告知用户
-    4. 如果是非数据查询类问题(如疑问、流程、操作等),回复:"知识库正在完善,无法回答该问题"
-    工具调用规格:
-    - 禁止连续调用相同工具相同参数
-    - 工具返回相同结果但仍在重复调用时,自动停止
-    回答规则:
-    - 如用户提出非ERP范围的问题(例如:"你好"等闲聊),明确告知用户自己的职责:主要处理ERP数据查询类问题
-    - 工具提示没有权限时,明确回复用户没有权限
-    - 严格按工具返回的内容回答,不能编造答案,可对结果进行简单总结
-    - 当工具返回错误信息时,如实转达给用户,不要添加额外解释
-    - 保持专业、准确、简洁的回答风格
-    {"- 需要个人数据时验证认证状态" if backend_url else "- 仅提供数据查询支持"}
-    当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
-    数据查询结果尽量以 Markdown 表格格式输出,格式如下:
-    | 列名1 | 列名2 | 列名3 |
-    | :--- | :--- | :--- |
-    | 数据1 | 数据2 | 数据3 |
-    | 数据4 | 数据5 | 数据6 |
-    """
-    if settings.ECHARTS_ENABLED:
-        system_prompt = (
-            system_prompt
-            + """
-    并且根据数据的格式,主动选择合适的图表输出,你可以输出柱状图、折线图、饼图。
-
-    饼图格式范例如下:
-    ```echarts
-    {{
-    "title": {{
-        "text": "浏览器份额", "left": "center" }}
-    "tooltip": {{
-        "trigger": "item" }},
-    "legend": {{
-        "orient": "vertical", "left": "left" }},
-    "series": [
-        {{
-        "name": "Share",
-        "type": "pie",
-        "radius": "55%",
-        "center": ["50%", "60%"],
-        "data": [
-            {{"value": 1048, "name": "Chrome" }}
-            {{"value": 735, "name": "Firefox" }}
-            {{"value": 580, "name": "Edge" }}
-        ]
-        }}
-    ]
-    }}
-    ```
-
-    柱状图格式范例如下:
-    ```echarts
-    {{
-    "title": {{"text": "每周销量" }}
-    "tooltip": {{}},
-    "xAxis": {{"type": "category", "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"] }}
-    "yAxis": {{"type": "value" }}
-    "series": [
-        {{"type": "bar", "data": [120, 200, 150, 80, 70, 110, 130] }}
-    ]
-    }}
-    ```
-
-    折线图,格式范例如下:
-    ```echarts
-    {{
-    "title": {{ "text": "温度趋势" }},
-    "tooltip": {{ "trigger": "axis" }},
-    "legend": {{ "data": ["最高", "最低"] }},
-    "xAxis": {{
-        "type": "category",
-        "boundaryGap": false,
-        "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"]
-    }},
-    "yAxis": {{"type": "value" }},
-    "series": [
-        {{"name": "最高", "type": "line", "data": [11, 11, 15, 13, 12, 13, 10], "smooth": true }}
-        {{"name": "最低", "type": "line", "data": [1, -2, 2, 5, 3, 2, 0], "smooth": true }}
-    ]
-    }}
-    ```
-    """
-        )
+    # auth_status = "已认证" if token else "未认证"
+    # backend_available = "API可用" if backend_url and token else "仅数据查询"
+    # knowledge_status = (
+    #     "知识库可用" if settings.KNOWLEDGE_BASE_ENABLED else "知识库已禁用"
+    # )
+    # echart_status = "图表可用" if settings.ECHARTS_ENABLED else "图表已禁用"
+
+    # if settings.KNOWLEDGE_BASE_ENABLED:
+    #     # 知识库启用时的提示词
+    #     system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status} 图表:{echart_status}
+    # 职责:ERP数据查询和问题解答,按用户语言回答。
+
+    # **核心安全指令 (必遵)**:
+    # 1.  **当前凭据 (每次工具调用必须使用)**:
+    #     - 后端地址: {backend_url if backend_url else '无'}
+    #     - API令牌: {token if token else '无'}
+    # 2.  **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
+    # 3.  **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
+    # 工作流:
+    # 1. 分析问题意图,提取模块关键词
+    # 2. 如果是数据查询类问题,直接调用相关工具查询数据
+    # 3. 如果是其他问题,则通过工具搜索知识库,知识库工具使用流程:a.通过关键字获取相关文章列表,b.判断哪些文章最符合,c.再通过工具获取文章内容.严格按文章内容回复,不能编造答案.
+    # 4. 关键词要精准,避免无意义词
+    # 工具调用规格:
+    # - 如果连续3次调用相同工具相同参数,自动停止
+    # - 工具返回相同结果但仍在重复调用时,自动停止
+    # 回答规则:
+    # - 知识库找不到时提示"正在学习该问题"
+    # - {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询和知识库支持"}
+    # - 保护隐私,专业准确,精炼简要
+    # 时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
+    # 数据查询结果尽量以 Markdown 表格格式输出,格式如下:
+    # | 列名1 | 列名2 | 列名3 |
+    # | :--- | :--- | :--- |
+    # | 数据1 | 数据2 | 数据3 |
+    # | 数据4 | 数据5 | 数据6 |
+    # """
+    # else:
+    #     # 知识库禁用时的提示词 - 灵活处理工具返回结果
+    #     system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status} 图表:{echart_status}
+    # 职责:处理ERP数据查询类问题,按用户语言回答。
+    # **核心安全指令 (必遵)**:
+    # 1.  **当前凭据 (每次工具调用必须使用)**:
+    #     - 后端地址: {backend_url if backend_url else '无'}
+    #     - API令牌: {token if token else '无'}
+    # 2.  **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
+    # 3.  **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
+    # 工作流:
+    # 1. 分析问题意图,判断是否为数据查询类问题
+    # 2. 如果是数据查询类问题,直接调用相关工具查询数据
+    # 3. 根据工具返回的结果进行回答:
+    # - 如果工具返回了具体数据,按数据内容回答
+    # - 如果工具返回了错误信息(如"API返回错误","查询失败","没有权限"等),如实告知用户错误信息
+    # - 如果工具返回空数据或"未找到数据",如实告知用户
+    # 4. 如果是非数据查询类问题(如疑问、流程、操作等),回复:"知识库正在完善,无法回答该问题"
+    # 工具调用规格:
+    # - 禁止连续调用相同工具相同参数
+    # - 工具返回相同结果但仍在重复调用时,自动停止
+    # 回答规则:
+    # - 如用户提出非ERP范围的问题(例如:"你好"等闲聊),明确告知用户自己的职责:主要处理ERP数据查询类问题
+    # - 工具提示没有权限时,明确回复用户没有权限
+    # - 严格按工具返回的内容回答,不能编造答案,可对结果进行简单总结
+    # - 当工具返回错误信息时,如实转达给用户,不要添加额外解释
+    # - 保持专业、准确、简洁的回答风格
+    # {"- 需要个人数据时验证认证状态" if backend_url else "- 仅提供数据查询支持"}
+    # 当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
+    # 数据查询结果尽量以 Markdown 表格格式输出,格式如下:
+    # | 列名1 | 列名2 | 列名3 |
+    # | :--- | :--- | :--- |
+    # | 数据1 | 数据2 | 数据3 |
+    # | 数据4 | 数据5 | 数据6 |
+    # """
+
+    context = safe_context_param(context)
+
+    system_prompt = f"""
+# 龙嘉AI助手-多Agent协调系统
+
+## 用户信息
+- 用户名: {username}
+
+##上下文
+上下文长度: {len(context)}
+
+## 严格行为规则
+你是一个调度器,**不是回答者**。你的唯一职责是决定是否调用工具。
+每轮对话只能调用一次工具,不能连续调用。
+
+## 决策规则
+- 就算提供了上下文,你也**必须调用**合适的工具,不能直接回答用户问题。
+- 如果无法判断调用那个工具,引导用户提供更多信息。
+
+## 关键行为约束
+### 当调用工具时:
+1. 你**必须**调用合适的工具
+2. 调用后**立即停止**,只能输出**"工具调用成功"**
+
+## 重要警告
+- 调用工具后,**不要**基于工具返回的结果继续生成回答
+- 系统会自动将工具结果返回给用户
+- 你在调用工具后,输出"工具调用成功",你的任务就结束了
+
+当需要工具时
+- **必须**调用合适的Worker工具
+- **必须传递以下4个参数**:
+  1. query: 用户问题
+  2. backend_url: {backend_url if backend_url else '无'}
+  3. token: {token if token else '无'}
+  4. context: {f"###对话上下文开始###\n{context}\n###对话上下文结束###" if context else '无'}
+- **输出格式**:必须且只能输出:`工具调用成功`
+
+## 零容忍规则
+**严禁参数错误**:每次调用都必须使用当前提供的backend_url/token,不能使用历史值
+**严禁猜测**:如果不能确定,就调用工具
+
+## 当前状态
+- 时间: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
+"""
 
     return system_prompt
 
@@ -260,6 +246,7 @@ def create_langchain_agent(
     token: str = "",
     username: str = "default",
     thread_id: str = "default",
+    context: str = "",
 ):
     llm = ChatOpenAI(
         model=settings.LLM_MODEL,
@@ -269,87 +256,23 @@ def create_langchain_agent(
         max_tokens=settings.LLM_MAX_TOKENS,
     )
 
-    tools = get_all_tools()
+    tools = get_worker_tools()
     # 添加调试信息
-    print(f"[DEBUG]Agent 创建调试信息:")
-    print(f"  - 用户: {username}")
-    print(f"  - Thread ID: {thread_id}")
-    print(f"  - 后端地址: {backend_url}")
-    print(f"  - Token: {'已提供' if token else '未提供'}")
-    print(f"  - 工具数量: {len(tools)}")
+    # print(f"[DEBUG]Agent 创建调试信息:")
+    # print(f"  - 用户: {username}")
+    # print(f"  - Thread ID: {thread_id}")
+    # print(f"  - 后端地址: {backend_url}")
+    # print(f"  - Token: {'已提供' if token else '未提供'}")
+    # print(f"  - worker数量: {len(tools)}")
 
-    for i, tool in enumerate(tools):
-        print(f"  - 工具 {i+1}: {tool.name}")
+    # for i, tool in enumerate(tools):
+    #     print(f"  - worker {i+1}: {tool.name}")
 
     # 获取动态的system_prompt
-    system_prompt = create_system_prompt(backend_url, token, username)
-    print(system_prompt)
-    # def simple_turn_based_trim(
-    #     messages: Sequence[BaseMessage],
-    #     keep_turns: int = 3,
-    #     system_message: BaseMessage = None,
-    # ) -> List[BaseMessage]:
-    #     """
-    #     修正版:按完整对话轮次修剪消息
-    #     每轮对话从Human开始,到下一个Human之前结束
-    #     """
-    #     if not messages:
-    #         return []
-    #     # 分离系统消息(始终保留)
-    #     system_messages = []
-    #     other_messages = []
-    #     for msg in messages:
-    #         if (
-    #             isinstance(msg, SystemMessage)
-    #             or getattr(msg, "type", None) == "system"
-    #             or getattr(msg, "role", None) == "system"
-    #             or msg.__class__.__name__ == "SystemMessage"
-    #         ):
-    #             system_messages.append(msg)
-    #         else:
-    #             other_messages.append(msg)
-
-    #     if len(other_messages) <= 1:
-    #         return system_messages + other_messages
-
-    #     # 找出所有Human消息的位置
-    #     human_indices = []
-    #     for i, msg in enumerate(other_messages):
-    #         if (
-    #             isinstance(msg, HumanMessage)
-    #             or getattr(msg, "type", None) == "human"
-    #             or getattr(msg, "role", None) == "user"
-    #         ):
-    #             human_indices.append(i)
-
-    #     # 如果Human消息不足keep_turns轮,返回所有
-    #     if not human_indices or len(human_indices) <= keep_turns:
-    #         return system_messages + other_messages
-
-    #     # 计算起始索引
-    #     start_idx = human_indices[-keep_turns]
-
-    #     # 获取要保留的消息
-    #     preserved_messages = other_messages[start_idx:]
-
-    #     # 4. 返回从该索引开始的所有消息
-    #     result = system_messages + preserved_messages
-    #     # print(f"修剪后消息数: {len(result)}")
-
-    #     return result
-
-    # @before_model
-    # def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
-    #     """Keep only the last few messages to fit context window."""
-    #     messages = state["messages"]
-
-    #     if len(messages) <= 3:
-    #         return None  # No changes needed
-
-    #     # 保留最后4轮对话
-    #     trimmed_messages = simple_turn_based_trim(messages, keep_turns=4)
-
-    #     return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)] + trimmed_messages}
+    system_prompt = create_system_prompt(backend_url, token, username, context)
+    # print(f"[DEBUG]上下文长度: {len(context)}")
+    # print(system_prompt)
+    # chat_logger.info(f"主Agent System Prompt上下文: {system_prompt}")
 
     @before_model
     def trim_messages_middleware(
@@ -363,7 +286,7 @@ def create_langchain_agent(
 
         trimmed_messages = trim_messages(
             messages,
-            max_tokens=1000,
+            max_tokens=500,
             strategy="last",  # 保留最近的对话
             token_counter=count_tokens_approximately,  # token计数器
             start_on="human",  # 从human消息开始计算轮次
@@ -388,12 +311,16 @@ def create_langchain_agent(
     if os.getenv("AUTO_CLEANUP", "false").lower() == "true":
         cleanup_old_checkpoints(max_days=7)  # 保留最近7天数据
 
+    # agent = create_agent(
+    #     llm,
+    #     tools,
+    #     checkpointer=checkpointer,
+    #     system_prompt=system_prompt,
+    #     middleware=[trim_messages_middleware],
+    # )
     agent = create_agent(
         llm,
         tools,
-        checkpointer=checkpointer,
         system_prompt=system_prompt,
-        middleware=[trim_messages_middleware],
     )
-
     return agent

+ 23 - 5
core/agent_manager.py

@@ -24,10 +24,15 @@ class AgentManager:
         chat_logger.info("Agent管理器已关闭")
 
     def _get_agent_config_key(
-        self, thread_id: str, username: str, backend_url: str, token: str
+        self,
+        thread_id: str,
+        username: str,
+        backend_url: str,
+        token: str,
+        context: str = "",
     ) -> str:
         """生成agent配置的缓存key"""
-        key_data = f"{thread_id}:{username}:{backend_url}:{token}"
+        key_data = f"{thread_id}:{username}:{backend_url}:{token}:{context}"
         return hashlib.md5(key_data.encode()).hexdigest()
 
     def _get_user_identifier(self, username: str, token: str) -> str:
@@ -45,7 +50,12 @@ class AgentManager:
         return f"{username_part}_{token_part}"
 
     async def get_agent_instance(
-        self, thread_id: str, username: str, backend_url: str, token: str
+        self,
+        thread_id: str,
+        username: str,
+        backend_url: str,
+        token: str,
+        context: str = "",
     ):
         if self._is_shutdown:
             raise RuntimeError("Agent管理器已关闭")
@@ -53,10 +63,11 @@ class AgentManager:
         clean_username = username or "anonymous"
         clean_backend = backend_url or ""
         clean_token = token or ""
+        clean_context = context or ""
 
         user_id = self._get_user_identifier(clean_username, clean_token)
         config_key = self._get_agent_config_key(
-            thread_id, clean_username, clean_backend, clean_token
+            thread_id, clean_username, clean_backend, clean_token, clean_context
         )
         print(f"config_key: {config_key}")
         # 检查本地配置缓存
@@ -73,6 +84,7 @@ class AgentManager:
             token=clean_token,
             username=clean_username,
             thread_id=thread_id,
+            context=clean_context,
         )
 
         # 缓存agent配置到本地
@@ -82,7 +94,12 @@ class AgentManager:
         return agent_instance
 
     async def _create_agent_async(
-        self, backend_url: str, token: str, username: str, thread_id: str
+        self,
+        backend_url: str,
+        token: str,
+        username: str,
+        thread_id: str,
+        context: str = "",
     ):
         """创建agent实例"""
 
@@ -92,6 +109,7 @@ class AgentManager:
                 token=token,
                 username=username,
                 thread_id=thread_id,
+                context=context,
             )
 
         loop = asyncio.get_event_loop()

+ 90 - 4
core/async_chat_service.py

@@ -1,7 +1,8 @@
 import asyncio
 import time
 from typing import Dict, Any
-from langchain_core.messages import HumanMessage
+from langchain_core.messages import HumanMessage, ToolMessage
+from openai import chat
 from utils.logger import chat_logger, log_chat_entry
 from core.agent_manager import agent_manager
 from core.chat_result_manager import chat_result_manager
@@ -52,7 +53,16 @@ class AsyncChatService:
             token = request_data["token"]
             user_id = username
 
-            chat_logger.info(f"开始处理任务 - 任务ID={task_id}, 用户={user_id}")
+            chat_logger.info(
+                f"开始处理任务 - 任务ID={task_id}, 用户={user_id},问题={message}"
+            )
+
+            # 获取对话上下文
+            context = context_manager.get_recent_context(thread_id)
+            if context:
+                chat_logger.info(
+                    f"获取到对话上下文 - 线程={thread_id}, 上下文长度={len(context)}"
+                )
 
             # 异步获取agent实例
             agent = await self.agent_manager.get_agent_instance(
@@ -60,6 +70,7 @@ class AsyncChatService:
                 username=username,
                 backend_url=backend_url,
                 token=token,
+                context=context,
             )
 
             # 在线程池中执行同步的Langchain操作
@@ -67,12 +78,23 @@ class AsyncChatService:
                 agent, message, thread_id, user_id
             )
 
+            # chat_logger.info(f"主Agent返回结果: {result}")
+            print("主Agent返回结果:", result)
+
             if not isinstance(result, dict) or "messages" not in result:
                 raise ValueError(f"Agent返回格式异常: {type(result)}")
 
             # 处理结果
             response_data = self._process_agent_result(result, user_id, request_data)
 
+            # 更新对话上下文(只记录成功的对话)
+            if response_data.get("success", False) and response_data.get(
+                "final_answer"
+            ):
+                context_manager.update_context(
+                    thread_id, message, response_data["final_answer"]
+                )
+
             # 更新任务状态为完成
             chat_result_manager.update_task_status(task_id, "completed", response_data)
 
@@ -143,7 +165,7 @@ class AsyncChatService:
                 for tool_call in msg.tool_calls:
                     tool_name = tool_call.get("name", "unknown")
                     tool_args = tool_call.get("args", {})
-                    chat_logger.info(f"工具调用 - 用户={user_id}, 工具={tool_name}")
+                    chat_logger.info(f"完成工具调用 - 用户={user_id}, 工具={tool_name}")
 
             if hasattr(msg, "tool_call_id"):
                 msg_data["tool_call_id"] = msg.tool_call_id
@@ -158,6 +180,16 @@ class AsyncChatService:
                 all_ai_messages.append(msg_data)
                 final_answer = msg_data["content"]
 
+        if final_answer == "工具调用成功":
+            last_tool_content = None
+            for msg in reversed(result.get("messages", [])):
+                if isinstance(msg, ToolMessage):
+                    last_tool_content = msg.content
+                    break
+
+            if last_tool_content:
+                final_answer = last_tool_content
+
         # 构建响应
         response = {
             "final_answer": final_answer,
@@ -182,7 +214,7 @@ class AsyncChatService:
     async def get_task_result(self, task_id: str) -> Dict[str, Any]:
         """获取任务结果"""
         task_info = chat_result_manager.get_task(task_id)
-        chat_logger.info(f"获取任务结果 - 任务ID={task_id}, 状态={task_info['status']}")
+        # chat_logger.info(f"获取任务结果 - 任务ID={task_id}, 状态={task_info['status']}")
         if not task_info:
             return {
                 "success": False,
@@ -211,3 +243,57 @@ class AsyncChatService:
 
 # 全局实例
 async_chat_service = AsyncChatService()
+
+
+# 在async_chat_service.py中添加上下文管理功能
+class ContextManager:
+    def __init__(self, max_history=3):
+        self.conversation_history = (
+            {}
+        )  # thread_id -> list of (human_message, ai_message)
+        self.max_history = max_history
+
+    def get_recent_context(self, thread_id: str) -> str:
+        """获取最近3轮对话的上下文"""
+        if thread_id not in self.conversation_history:
+            return ""
+
+        history = self.conversation_history[thread_id]
+        if not history:
+            return ""
+
+        # 获取最近3轮对话
+        recent_history = history[-self.max_history :]
+
+        # 构建上下文字符串(只包含human和ai消息)
+        context_parts = []
+        for i, (human_msg, ai_msg) in enumerate(recent_history):
+            context_parts.append(f"第{len(history)-len(recent_history)+i+1}轮对话:")
+            context_parts.append(f"用户:{human_msg}")
+            context_parts.append(f"AI:{ai_msg}")
+            context_parts.append("")  # 空行分隔
+
+        return "\n".join(context_parts).strip()
+
+    def update_context(self, thread_id: str, human_message: str, ai_message: str):
+        """更新对话上下文"""
+        if thread_id not in self.conversation_history:
+            self.conversation_history[thread_id] = []
+
+        # 添加新的对话轮次
+        self.conversation_history[thread_id].append((human_message, ai_message))
+
+        # 保持最多max_history轮对话
+        if len(self.conversation_history[thread_id]) > self.max_history:
+            self.conversation_history[thread_id] = self.conversation_history[thread_id][
+                -self.max_history :
+            ]
+
+    def clear_context(self, thread_id: str):
+        """清空特定线程的上下文"""
+        if thread_id in self.conversation_history:
+            self.conversation_history[thread_id] = []
+
+
+# 全局上下文管理器实例
+context_manager = ContextManager(max_history=3)

+ 1 - 1
core/chat_service.py

@@ -127,7 +127,7 @@ class ChatService:
                 for tool_call in msg.tool_calls:
                     tool_name = tool_call.get("name", "unknown")
                     tool_args = tool_call.get("args", {})
-                    chat_logger.info(f"工具调用 - 用户={user_id}, 工具={tool_name}")
+                    chat_logger.info(f"完成工具调用 - 用户={user_id}, 工具={tool_name}")
 
             if hasattr(msg, "tool_call_id"):
                 msg_data["tool_call_id"] = msg.tool_call_id

+ 5 - 0
core/lifespan_manager.py

@@ -4,6 +4,7 @@ from core.agent_manager import agent_manager
 from utils.registration_manager import registration_manager
 from utils.logger import chat_logger
 from core.chat_result_manager import chat_result_manager
+from core.worker_manager import init_worker_cache
 
 
 @asynccontextmanager
@@ -18,6 +19,10 @@ async def lifespan(app: FastAPI):
             chat_logger.info("服务启动:注册检查通过")
 
         await agent_manager.initialize()
+
+        # 初始化Worker缓存
+        init_worker_cache()
+
         chat_logger.info("AI助手服务启动")
         yield
     finally:

+ 524 - 0
core/worker_manager.py

@@ -0,0 +1,524 @@
+from pdb import run
+import sys
+
+from langchain.tools import tool
+from langchain.agents import create_agent
+from langchain_openai import ChatOpenAI
+from typing import Dict, Any, Optional
+
+from openai import chat
+from config.settings import settings
+from langchain_core.messages import HumanMessage, SystemMessage
+import datetime
+from utils.context_helper import safe_context_param
+from utils.logger import chat_logger
+
+"""
+增加worker需要修改:
+worker_type_mapping
+"""
+
+# WorkerAgent缓存字典
+_worker_cache = {}
+# 工具缓存字典
+_tools_cache = {}
+
+# worker_type, tool_types
+worker_type_mapping = {
+    "sale_worker": ("销售报表查询", ["sale"]),
+    "ware_worker": ("库存查询", ["ware"]),
+    "money_worker": ("财务查询", ["money"]),
+    "price_worker": ("价格查询", ["price"]),
+    "salebill_worker": ("销售类单据查询", ["salebill"]),
+    "base_data_worker": ("基础资料查询", ["mtrl_data", "cust_data"]),
+}
+
+
+class WorkerAgent:
+    """Worker Agent - 专门处理特定领域的任务"""
+
+    def __init__(self, worker_type: str, tools: list):
+        self.worker_type = worker_type
+        self.tools = tools
+        self.llm = ChatOpenAI(
+            model=settings.LLM_MODEL,
+            temperature=settings.LLM_TEMPERATURE,
+            api_key=settings.DEEPSEEK_API_KEY,
+            base_url=settings.DEEPSEEK_BASE_URL,
+            max_tokens=settings.LLM_MAX_TOKENS,
+        )
+        self.agent = self._create_agent()
+
+    def _create_agent(self):
+        """创建Agent实例"""
+
+        return create_agent(
+            self.llm,
+            tools=self.tools,
+        )
+
+    def execute(self, query: str, backend_url: str, token: str) -> any:
+        """执行任务"""
+        try:
+            system_prompt = f"""
+ 你是{self.worker_type}专家,专门处理{self.worker_type}相关任务。
+
+ 职责:
+ - 接收主Agent分配的具体任务
+ - 智能选择最合适的工具执行任务
+ - 理解工具返回的数据结构
+ - 组织完整的答案向用户回答
+
+ 工具用到的参数:
+ - backend_url: {backend_url}
+ - token: {token}
+
+ 当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
+
+ 数据查询结果尽量以 Markdown 表格格式输出
+
+ 如果工具有输出格式要求,必须严格按照要求输出。
+ """
+
+            if settings.ECHARTS_ENABLED:
+                system_prompt = (
+                    system_prompt
+                    + """
+ 你可以输出echarts柱状图、折线图、饼图。
+
+ 饼图格式范例如下:
+ ```echarts
+ {{
+ "title": {{
+     "text": "浏览器份额", "left": "center" }}
+ "tooltip": {{
+     "trigger": "item" }},
+ "legend": {{
+     "orient": "vertical", "left": "left" }},
+ "series": [
+     {{
+     "name": "Share",
+     "type": "pie",
+     "radius": "55%",
+     "center": ["50%", "60%"],
+     "data": [
+         {{"value": 1048, "name": "Chrome" }}
+         {{"value": 735, "name": "Firefox" }}
+         {{"value": 580, "name": "Edge" }}
+     ]
+     }}
+ ]
+ }}
+ ```
+
+ 柱状图格式范例如下:
+ ```echarts
+ {{
+ "title": {{"text": "每周销量" }}
+ "tooltip": {{}},
+ "xAxis": {{"type": "category", "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"] }}
+ "yAxis": {{"type": "value" }}
+ "series": [
+     {{"type": "bar", "data": [120, 200, 150, 80, 70, 110, 130] }}
+ ]
+ }}
+ ```
+
+ 折线图,格式范例如下:
+ ```echarts
+ {{
+ "title": {{ "text": "温度趋势" }},
+ "tooltip": {{ "trigger": "axis" }},
+ "legend": {{ "data": ["最高", "最低"] }},
+ "xAxis": {{
+     "type": "category",
+     "boundaryGap": false,
+     "data": ["Mon","Tue","Wed","Thu","Fri","Sat","Sun"]
+ }},
+ "yAxis": {{"type": "value" }},
+ "series": [
+     {{"name": "最高", "type": "line", "data": [11, 11, 15, 13, 12, 13, 10], "smooth": true }}
+     {{"name": "最低", "type": "line", "data": [1, -2, 2, 5, 3, 2, 0], "smooth": true }}
+ ]
+ }}
+ ```
+ """
+                )
+
+            inputs = {
+                "messages": [
+                    SystemMessage(content=system_prompt),
+                    HumanMessage(content=query),
+                ]
+            }
+            response = self.agent.invoke(inputs)
+            return response
+
+        except Exception as e:
+            return f"Worker执行失败: {str(e)}"
+
+
+# 自动发现指定类型的工具
+def _get_tools_by_type(tool_type: str) -> list:
+    """根据类型自动发现工具 - 完整支持编译前后的文件"""
+
+    # 检查缓存
+    if tool_type in _tools_cache:
+        chat_logger.info(f"使用缓存的工具: {tool_type}")
+        return _tools_cache[tool_type]
+
+    chat_logger.info(f"发现并加载工具: {tool_type}")
+
+    import importlib
+    import inspect
+    from pathlib import Path
+    import sys
+    from langchain.tools import BaseTool
+
+    tools = []
+
+    # 获取项目根目录并添加到Python路径
+    project_root = Path(__file__).parent.parent
+    if str(project_root) not in sys.path:
+        sys.path.insert(0, str(project_root))
+
+    tools_dir = project_root / "tools"
+    # 导入配置以检查知识库开关
+    from config.settings import settings
+
+    # 扫描工具文件
+    tool_files = []
+
+    # 模式1: 编译前的.py文件
+    for file_path in tools_dir.glob("*_tools.py"):
+        if file_path.is_file():
+            module_name = file_path.stem
+
+            # 检查文件名是否以指定类型开头(前缀匹配)
+            if module_name.startswith(f"{tool_type}_"):
+                # 如果知识库被禁用,跳过知识库工具
+                if (
+                    module_name == "knowledge_tools"
+                    and not settings.KNOWLEDGE_BASE_ENABLED
+                ):
+                    print("知识库功能已禁用,跳过知识库工具")
+                    continue
+                tool_files.append(module_name)
+                print(f"发现工具文件: {module_name}")
+
+    # 模式2: 编译后的.pyd文件
+    for file_path in tools_dir.glob("*_tools.cp*.pyd"):
+        if file_path.is_file():
+            # 从文件名中提取模块名,如: ware_tools.cp313-win_amd64.pyd -> ware_tools
+            module_name = file_path.stem.split(".")[0]
+
+            # 检查文件名是否以指定类型开头(前缀匹配)
+            if module_name.startswith(f"{tool_type}_"):
+                # 如果知识库被禁用,跳过知识库工具
+                if (
+                    module_name == "knowledge_tools"
+                    and not settings.KNOWLEDGE_BASE_ENABLED
+                ):
+                    print("知识库功能已禁用,跳过知识库工具")
+                    continue
+                if module_name not in tool_files:  # 避免重复添加
+                    tool_files.append(module_name)
+                    print(f"发现编译后工具文件: {module_name}")
+
+    # 如果没有找到指定类型的工具文件,使用默认列表
+    if not tool_files:
+        # 根据类型提供默认工具列表
+        default_tool_mapping = {
+            "sale": ["sale_tools"],
+            "ware": ["ware_tools"],
+            "money": ["money_tools"],
+            "price": ["price_tools"],
+            "knowledge": ["knowledge_tools"] if settings.KNOWLEDGE_BASE_ENABLED else [],
+        }
+
+        if tool_type in default_tool_mapping:
+            tool_files = default_tool_mapping[tool_type]
+            print(f"使用默认工具列表: {tool_files}")
+
+    # 导入模块并获取工具
+    for module_name in tool_files:
+        try:
+            full_module_path = f"tools.{module_name}"
+            module = importlib.import_module(full_module_path)
+
+            # 使用更全面的工具发现方法(参考tool_factory.py)
+            tool_count = 0
+
+            # 方法1: 检查模块的所有属性
+            for attr_name in dir(module):
+                if attr_name.startswith("_"):
+                    continue
+
+                attr = getattr(module, attr_name)
+
+                # 检查是否是BaseTool实例
+                if isinstance(attr, BaseTool):
+                    tools.append(attr)
+                    tool_count += 1
+                    continue
+
+                # 检查是否是函数且具有工具属性
+                if callable(attr) and hasattr(attr, "name"):
+                    tools.append(attr)
+                    tool_count += 1
+
+            # 方法2: 检查模块中是否有get_all_tools函数
+            if hasattr(module, "get_all_tools"):
+                module_tools = module.get_all_tools()
+                if isinstance(module_tools, list):
+                    tools.extend(module_tools)
+                    tool_count += len(module_tools)
+
+            if tool_count > 0:
+                print(f"✅ 从 {module_name} 加载了 {tool_count} 个工具")
+            else:
+                print(f"⚠️  {module_name} 中未发现工具函数")
+
+        except Exception as e:
+            print(f"❌ 导入模块 {module_name} 失败: {e}")
+
+    # 缓存工具
+    _tools_cache[tool_type] = tools
+
+    return tools
+
+
+def _get_worker_agent(worker_type: str, tools: list) -> WorkerAgent:
+    """获取或创建WorkerAgent实例(使用缓存)"""
+    cache_key = f"{worker_type}"  # 使用worker类型作为缓存键
+
+    if cache_key not in _worker_cache:
+        chat_logger.info(f"创建新的WorkerAgent缓存: {worker_type}")
+        _worker_cache[cache_key] = WorkerAgent(worker_type, tools)
+    else:
+        chat_logger.info(f"使用缓存的WorkerAgent: {worker_type}")
+
+    return _worker_cache[cache_key]
+
+
+def _execute_and_get_result(
+    worker_type: str,
+    tools: list,
+    token: str,
+    backend_url: str,
+    query: str,
+    context: str = "",
+) -> str:
+    """执行worker并获取结果"""
+    try:
+        chat_logger.info(f"开始执行worker: {worker_type}")
+        # worker = WorkerAgent(worker_type, tools)
+        worker = _get_worker_agent(worker_type, tools)
+        # 如果有上下文,合并到查询中
+        # chat_logger.info(f"原始上下文: {context}")
+
+        safe_context = safe_context_param(context)
+        if safe_context:
+            enhanced_query = f"对话上下文:{safe_context}\n当前问题:{query}"
+        else:
+            enhanced_query = query
+
+        # chat_logger.info(f"处理后的上下文: {safe_context}")
+
+        result = worker.execute(enhanced_query, backend_url, token)
+        chat_logger.info(f"Worker{worker_type}执行成功:问题{query}")
+        ai_messages = [
+            msg
+            for msg in result.get("messages", [])
+            if hasattr(msg, "type") and msg.type == "ai"
+        ]
+
+        if ai_messages:
+            # 获取最后一条 AI 消息的内容
+            last_ai_content = ai_messages[-1].content
+            return last_ai_content
+        else:
+            return ""
+
+    except Exception as e:
+        chat_logger.error(
+            f"Worker{worker_type}执行失败:问题{query},错误信息: {str(e)}"
+        )
+        return f"Worker{worker_type}执行失败:问题{query},错误信息: {str(e)}"
+
+
+def init_worker_cache() -> None:
+    """初始化Worker缓存(在服务启动时调用)"""
+    chat_logger.info("开始初始化Worker缓存...")
+
+    # 直接遍历worker_type_mapping字典来初始化缓存
+    for worker_name, (worker_type, tool_types) in worker_type_mapping.items():
+        try:
+            # 组合工具
+            tools = []
+            for tool_type in tool_types:
+                tools.extend(_get_tools_by_type(tool_type))
+
+            if tools:
+                # 创建WorkerAgent实例并缓存
+                cache_key = f"{worker_type}"
+                if cache_key not in _worker_cache:
+                    _worker_cache[cache_key] = WorkerAgent(worker_type, tools)
+                    chat_logger.info(
+                        f"预缓存WorkerAgent: {worker_type} ({worker_name})"
+                    )
+                else:
+                    chat_logger.info(f"WorkerAgent已缓存: {worker_type}")
+            else:
+                chat_logger.info(f"未发现工具,跳过缓存: {worker_type}")
+
+        except Exception as e:
+            chat_logger.info(f"预缓存WorkerAgent失败 {worker_name}: {e}")
+    chat_logger.info(f"Worker缓存初始化完成,共缓存 {len(_worker_cache)} 个WorkerAgent")
+
+
+def clear_worker_cache() -> None:
+    """清除Worker缓存"""
+    global _worker_cache, _tools_cache
+    _worker_cache.clear()
+    _tools_cache.clear()
+    chat_logger.info("Worker缓存已清除")
+
+
+def get_worker_cache_stats() -> dict:
+    """获取Worker缓存统计信息"""
+    return {
+        "worker_agents_cached": len(_worker_cache),
+        "tools_cached": len(_tools_cache),
+        "cached_worker_types": list(_worker_cache.keys()),
+        "cached_tool_types": list(_tools_cache.keys()),
+    }
+
+
+def get_worker_config(worker_name: str) -> tuple[str, list[str]]:
+    """根据worker名称获取配置信息"""
+    if worker_name not in worker_type_mapping:
+        raise ValueError(f"未知的worker名称: {worker_name}")
+
+    return worker_type_mapping[worker_name]
+
+
+def run_worker(
+    worker_name: str,
+    token: str,
+    backend_url: str,
+    query: str,
+    context: str = "",
+) -> str:
+    worker_type, tool_types = get_worker_config(worker_name)
+    tools = []
+    for tool_type in tool_types:
+        tools.extend(_get_tools_by_type(tool_type))
+
+    return _execute_and_get_result(
+        worker_type, tools, token, backend_url, query, context
+    )
+
+
+# Worker定义
+@tool
+def sale_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
+    """
+    销售报表查询专家
+    职责:销售金额、销售数量、客户销售排名、业务员销售排名、产品型号销售数据等
+    使用场景:"2024客户销售业绩排名"、"查2024年产品型号销售排名"、"查看2023年1月1日至2023年12月31日的销售金额"等
+    """
+    # tools = _get_tools_by_type("sale")  # 即 sale_tools.py下的所有工具
+    # return _execute_and_get_result(
+    #     "销售报表查询", tools, token, backend_url, query, context
+    # )
+    return run_worker("sale_worker", token, backend_url, query, context)
+
+
+@tool
+def ware_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
+    """
+    库存查询专家
+    职责:库存管理查询相关问题解答
+    使用场景:"查看铜管的库存"、"安吉仓库最多的库存哪个型号"
+    """
+    # tools = _get_tools_by_type("ware")  # 即 ware_tools.py下的所有工具
+    # return _execute_and_get_result(
+    #     "库存查询", tools, token, backend_url, query, context
+    # )
+    return run_worker("ware_worker", token, backend_url, query, context)
+
+
+@tool
+def money_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
+    """
+    财务查询专家
+
+    职责:财务查询相关问题解答
+    使用场景:"查询客户A的应收帐情况"、"查询客户A,2025年的收款情况"
+    """
+    # 图表工具可以复用数据查询工具的结果
+    # tools = _get_tools_by_type("money")  # 即 money_tools.py下的所有工具
+    # return _execute_and_get_result(
+    #     "财务查询", tools, token, backend_url, query, context
+    # )
+    return run_worker("money_worker", token, backend_url, query, context)
+
+
+@tool
+def price_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
+    """
+    价格查询专家
+
+    职责:价格查询相关问题解答
+    使用场景:"查询铜管的销售价格"
+    """
+    # tools = _get_tools_by_type("price")
+    # return _execute_and_get_result(
+    #     "价格查询", tools, token, backend_url, query, context
+    # )
+    return run_worker("price_worker", token, backend_url, query, context)
+
+
+@tool
+def salebill_worker(query: str, backend_url: str, token: str, context: str = "") -> str:
+    """
+    销售类单据查询专家
+
+    职责:销售类单据查询相关问题解答
+    使用场景:"查询客户A的订单进度"、"查询销售订单SG251119001"
+    """
+    # tools = _get_tools_by_type("salebill")
+    # return _execute_and_get_result(
+    #     "销售类单据查询", tools, token, backend_url, query, context
+    # )
+    return run_worker("salebill_worker", token, backend_url, query, context)
+
+
+@tool
+def base_data_worker(
+    query: str, backend_url: str, token: str, context: str = ""
+) -> str:
+    """
+    基础资料查询专家
+
+    职责:基础资料查询相关问题解答
+    使用场景:"查产品 A388餐椅"、"查一下客户 南华家具"
+    """
+    # tools = _get_tools_by_type("mtrl_data") + _get_tools_by_type("cust_data")
+    # return _execute_and_get_result(
+    #     "基础资料查询", tools, token, backend_url, query, context
+    # )
+    return run_worker("base_data_worker", token, backend_url, query, context)
+
+
+def get_worker_tools() -> list:
+    """获取所有Worker工具"""
+    return [
+        sale_worker,
+        ware_worker,
+        money_worker,
+        price_worker,
+        salebill_worker,
+        base_data_worker,
+    ]

+ 11 - 6
tools/base_tool.py

@@ -4,6 +4,7 @@ import requests
 import json
 from typing import List, Dict, Any, Optional, Callable
 from pathlib import Path
+from utils.logger import chat_logger
 
 
 def html_to_text(html_content: str) -> str:
@@ -196,13 +197,13 @@ def call_csharp_api(
     }
 
     try:
-        print(f"🌐 发送API请求到: {backend_url}")
-        response = requests.post(backend_url, headers=headers, json=payload, timeout=30)
-        print(f"📡 响应状态码: {response.status_code}")
+        response = requests.post(backend_url, headers=headers, json=payload, timeout=60)
 
         if response.status_code == 200:
             data = response.json()
-
+            chat_logger.info(
+                f"API响应数据:uoName={uoName}, functionName={functionName}, SParms={SParms}"
+            )
             # 检查是否存在ErrMsg字段,如果有则直接返回错误信息
             if "ErrMsg" in data and data["ErrMsg"]:
                 error_msg = f"API返回错误: {data['ErrMsg']}"
@@ -212,11 +213,15 @@ def call_csharp_api(
             return process_api_response(data)
         else:
             error_msg = f"API请求失败,状态码: {response.status_code}"
-            print(f"❌ {error_msg}")
+            chat_logger.error(
+                f"uoName={uoName}, functionName={functionName}, SParms={SParms}, error_msg: {error_msg}"
+            )
             return error_msg
     except Exception as e:
         error_msg = f"API调用异常: {str(e)}"
-        print(f"❌ {error_msg}")
+        chat_logger.error(
+            f"uoName={uoName}, functionName={functionName}, SParms={SParms}, error_msg: {error_msg}"
+        )
         return error_msg
 
 

+ 4 - 3
tools/price_tools.py

@@ -6,10 +6,11 @@ def get_mtrl_saleprice_default_config():
     """get_mtrl_saleprice 工具的默认配置"""
     return {
         "get_mtrl_saleprice": {
-            "基础描述": "获取指定时间范围的销售金额,按月汇总",
+            "基础描述": "获取指定物料的销售价格",
+            "功能说明": "从销售管理系统中查询物料的销售价格,包括含税价、辅助单位、转换率等详细信息",
             "入参说明": {
                 "backend_url": "后端API地址",
-                "token": "认证令牌",
+                "token": "用户认证令牌,用于身份验证",
                 "mtrlname": "物料名称 或 物料编码, 支持模糊查询",
             },
             "返回值说明": {
@@ -17,7 +18,7 @@ def get_mtrl_saleprice_default_config():
                 "字段含义": "listname:价格表,currency:币种, mtrlcode:物料编码, mtrlname:物料名称, unit:单位, price:含税价, unit_buy:辅助单位, rate_buy:转换率,price_unit1:辅助单位含税价,saleqty:销售数量下限,saleqty1:销售数量上限",
             },
             "输出格式要求": [
-                "以表格输出",
+                "以表格输出,标题按字段含义显示",
                 "币种、物料编码、物料名称:若所有行该列值完全一致,则整列隐藏",
                 "辅助单位,含税价辅助单位,转换率,数量区间:若所有行该列值为初始值(如'',0),则整列隐藏",
                 "其他列原样显示",

+ 25 - 0
utils/context_helper.py

@@ -0,0 +1,25 @@
+def safe_context_param(context_str: str) -> str:
+    """确保上下文参数可以安全地序列化为JSON"""
+    if not context_str:
+        return ""
+
+    # 移除ECharts代码块,这些对上下文没有帮助且容易导致格式错误
+    import re
+
+    # 移除 ```echarts 代码块
+    cleaned = re.sub(r"```echarts[\s\S]*?```", "", context_str)
+    # 移除其他代码块
+    cleaned = re.sub(r"```[\s\S]*?```", "", cleaned)
+
+    # 清理可能导致JSON解析问题的字符
+    cleaned = str(cleaned).replace('"', "'").replace("\\", "/").strip()
+    # 移除多余的空行和空格
+    cleaned = re.sub(r"\n\s*\n", "\n", cleaned)  # 移除连续空行
+    cleaned = re.sub(r" +", " ", cleaned)  # 合并多个空格
+
+    # 限制上下文长度,避免过长导致序列化问题
+    max_length = 2000
+    if len(cleaned) > max_length:
+        cleaned = cleaned[:max_length] + "... [上下文已截断]"
+
+    return cleaned.strip()