agent.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. import os
  2. import dotenv
  3. import datetime
  4. from pathlib import Path
  5. from langchain.agents import create_agent, AgentState
  6. from langchain_openai import ChatOpenAI
  7. from langchain_core.messages import (
  8. SystemMessage,
  9. HumanMessage,
  10. BaseMessage,
  11. trim_messages,
  12. )
  13. from openai import chat
  14. from tools.tool_factory import get_all_tools
  15. from langchain_core.runnables import RunnableConfig
  16. from langchain.agents.middleware import before_model
  17. from langgraph.runtime import Runtime
  18. from typing import Any, List, Sequence
  19. from langchain.messages import RemoveMessage
  20. from langgraph.graph.message import REMOVE_ALL_MESSAGES
  21. import sqlite3
  22. from config.settings import settings
  23. from langchain_core.messages.utils import count_tokens_approximately
  24. from core.worker_manager import get_worker_tools
  25. from utils.context_helper import safe_context_param
  26. from utils.logger import chat_logger
  27. dotenv.load_dotenv()
  28. def create_system_prompt(
  29. backend_url: str = "", token: str = "", username: str = "default", context: str = ""
  30. ) -> str:
  31. # auth_status = "已认证" if token else "未认证"
  32. # backend_available = "API可用" if backend_url and token else "仅数据查询"
  33. # knowledge_status = (
  34. # "知识库可用" if settings.KNOWLEDGE_BASE_ENABLED else "知识库已禁用"
  35. # )
  36. # echart_status = "图表可用" if settings.ECHARTS_ENABLED else "图表已禁用"
  37. # if settings.KNOWLEDGE_BASE_ENABLED:
  38. # # 知识库启用时的提示词
  39. # system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status} 图表:{echart_status}
  40. # 职责:ERP数据查询和问题解答,按用户语言回答。
  41. # **核心安全指令 (必遵)**:
  42. # 1. **当前凭据 (每次工具调用必须使用)**:
  43. # - 后端地址: {backend_url if backend_url else '无'}
  44. # - API令牌: {token if token else '无'}
  45. # 2. **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
  46. # 3. **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
  47. # 工作流:
  48. # 1. 分析问题意图,提取模块关键词
  49. # 2. 如果是数据查询类问题,直接调用相关工具查询数据
  50. # 3. 如果是其他问题,则通过工具搜索知识库,知识库工具使用流程:a.通过关键字获取相关文章列表,b.判断哪些文章最符合,c.再通过工具获取文章内容.严格按文章内容回复,不能编造答案.
  51. # 4. 关键词要精准,避免无意义词
  52. # 工具调用规格:
  53. # - 如果连续3次调用相同工具相同参数,自动停止
  54. # - 工具返回相同结果但仍在重复调用时,自动停止
  55. # 回答规则:
  56. # - 知识库找不到时提示"正在学习该问题"
  57. # - {"需要个人数据时验证认证状态" if backend_url else "仅提供数据查询和知识库支持"}
  58. # - 保护隐私,专业准确,精炼简要
  59. # 时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
  60. # 数据查询结果尽量以 Markdown 表格格式输出,格式如下:
  61. # | 列名1 | 列名2 | 列名3 |
  62. # | :--- | :--- | :--- |
  63. # | 数据1 | 数据2 | 数据3 |
  64. # | 数据4 | 数据5 | 数据6 |
  65. # """
  66. # else:
  67. # # 知识库禁用时的提示词 - 灵活处理工具返回结果
  68. # system_prompt = f"""龙嘉软件助手- 用户:{username} 认证:{auth_status} 服务:{backend_available} 知识库:{knowledge_status} 图表:{echart_status}
  69. # 职责:处理ERP数据查询类问题,按用户语言回答。
  70. # **核心安全指令 (必遵)**:
  71. # 1. **当前凭据 (每次工具调用必须使用)**:
  72. # - 后端地址: {backend_url if backend_url else '无'}
  73. # - API令牌: {token if token else '无'}
  74. # 2. **禁止沿用历史**:**严禁**从对话历史中复制、沿用任何旧的后端地址、令牌或工具参数。历史记录仅用于理解背景,其中的工具详情**不能**作为本次调用的参数来源。
  75. # 3. **调用规范**:调用查询工具时,**必须且只能**使用上方提供的当前凭据。
  76. # 工作流:
  77. # 1. 分析问题意图,判断是否为数据查询类问题
  78. # 2. 如果是数据查询类问题,直接调用相关工具查询数据
  79. # 3. 根据工具返回的结果进行回答:
  80. # - 如果工具返回了具体数据,按数据内容回答
  81. # - 如果工具返回了错误信息(如"API返回错误","查询失败","没有权限"等),如实告知用户错误信息
  82. # - 如果工具返回空数据或"未找到数据",如实告知用户
  83. # 4. 如果是非数据查询类问题(如疑问、流程、操作等),回复:"知识库正在完善,无法回答该问题"
  84. # 工具调用规格:
  85. # - 禁止连续调用相同工具相同参数
  86. # - 工具返回相同结果但仍在重复调用时,自动停止
  87. # 回答规则:
  88. # - 如用户提出非ERP范围的问题(例如:"你好"等闲聊),明确告知用户自己的职责:主要处理ERP数据查询类问题
  89. # - 工具提示没有权限时,明确回复用户没有权限
  90. # - 严格按工具返回的内容回答,不能编造答案,可对结果进行简单总结
  91. # - 当工具返回错误信息时,如实转达给用户,不要添加额外解释
  92. # - 保持专业、准确、简洁的回答风格
  93. # {"- 需要个人数据时验证认证状态" if backend_url else "- 仅提供数据查询支持"}
  94. # 当前时间:{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
  95. # 数据查询结果尽量以 Markdown 表格格式输出,格式如下:
  96. # | 列名1 | 列名2 | 列名3 |
  97. # | :--- | :--- | :--- |
  98. # | 数据1 | 数据2 | 数据3 |
  99. # | 数据4 | 数据5 | 数据6 |
  100. # """
  101. context = safe_context_param(context)
  102. system_prompt = f"""
  103. # 龙嘉AI助手-多Agent协调系统
  104. ## 用户信息
  105. - 用户名: {username}
  106. ##上下文
  107. 上下文长度: {len(context)}
  108. ## 严格行为规则
  109. 你是一个调度器,**不是回答者**。你的唯一职责是决定是否调用工具。
  110. 每轮对话只能调用一次工具,不能连续调用。
  111. ## 决策规则
  112. - 就算提供了上下文,你也**必须调用**合适的工具,不能直接回答用户问题。
  113. - 如果无法判断调用那个工具,引导用户提供更多信息。
  114. ## 关键行为约束
  115. ### 当调用工具时:
  116. 1. 你**必须**调用合适的工具
  117. 2. 调用后**立即停止**,只能输出**"工具调用成功"**
  118. ## 重要警告
  119. - 调用工具后,**不要**基于工具返回的结果继续生成回答
  120. - 系统会自动将工具结果返回给用户
  121. - 你在调用工具后,输出"工具调用成功",你的任务就结束了
  122. 当需要工具时
  123. - **必须**调用合适的Worker工具
  124. - **必须传递以下4个参数**:
  125. 1. query: 用户问题
  126. 2. backend_url: {backend_url if backend_url else '无'}
  127. 3. token: {token if token else '无'}
  128. 4. context: {f"###对话上下文开始###\n{context}\n###对话上下文结束###" if context else '无'}
  129. - **输出格式**:必须且只能输出:`工具调用成功`
  130. ## 零容忍规则
  131. **严禁参数错误**:每次调用都必须使用当前提供的backend_url/token,不能使用历史值
  132. **严禁猜测**:如果不能确定,就调用工具
  133. ## 当前状态
  134. - 时间: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
  135. """
  136. return system_prompt
  137. def get_day_number(date=None):
  138. """获取日期编号 (YYYYMMDD 格式)"""
  139. if date is None:
  140. date = datetime.datetime.now()
  141. return date.strftime("%Y%m%d") # 格式: 20251229
  142. def get_sqlite_checkpointer():
  143. """创建按天分割的SQLite检查点保存器"""
  144. try:
  145. from langgraph.checkpoint.sqlite import SqliteSaver
  146. # 获取当前日期编号
  147. current_day = get_day_number()
  148. # 数据库文件存放目录
  149. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  150. base_dir = os.path.join(project_root, "data", "checkpoints")
  151. os.makedirs(base_dir, exist_ok=True)
  152. # 数据库文件名格式: checkpoints_20251229.db
  153. db_filename = f"checkpoints_{current_day}.db"
  154. db_path = os.path.join(base_dir, db_filename)
  155. # checkpointer = SqliteSaver.from_conn_string(db_path)
  156. conn = sqlite3.connect(db_path, check_same_thread=False)
  157. conn.execute("PRAGMA wal_autocheckpoint=500") # 2MB 就提交
  158. conn.execute("PRAGMA journal_size_limit=52428800") # 最大 50MB
  159. checkpointer = SqliteSaver(conn)
  160. return checkpointer
  161. except Exception as e:
  162. print(f"[ERROR]创建 SQLite 检查器失败: {e}")
  163. import traceback
  164. traceback.print_exc()
  165. # 回退到内存保存器
  166. from langgraph.checkpoint.memory import InMemorySaver
  167. print("[WARN]使用 InMemorySaver 作为回退")
  168. return InMemorySaver()
  169. def cleanup_old_checkpoints(max_days=7):
  170. """清理超过指定天数的旧检查点文件(可选功能)"""
  171. try:
  172. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  173. base_dir = os.path.join(project_root, "data", "checkpoints")
  174. if not os.path.exists(base_dir):
  175. return
  176. # 获取当前日期
  177. current_date = datetime.datetime.now()
  178. # 遍历目录中的所有.db文件
  179. for filename in os.listdir(base_dir):
  180. if filename.startswith("checkpoints_") and filename.endswith(".db"):
  181. try:
  182. print(f"检查旧检查点文件: {filename}")
  183. # 提取日期 (checkpoints_day_20251229.db -> 20251229)
  184. date_str = filename.replace("checkpoints_day_", "").replace(
  185. ".db", ""
  186. )
  187. file_date = datetime.datetime.strptime(date_str, "%Y%m%d")
  188. # 计算天数差
  189. days_diff = (current_date - file_date).days
  190. # 删除超过 max_days 天的旧数据
  191. if days_diff > max_days:
  192. file_path = os.path.join(base_dir, filename)
  193. os.remove(file_path)
  194. print(
  195. f"[CLEAN]清理旧检查点文件: {filename} (超过 {max_days} 天)"
  196. )
  197. except (ValueError, IndexError):
  198. # 文件名不符合预期,跳过
  199. continue
  200. except Exception as e:
  201. print(f"[WARN]清理旧检查点失败: {e}")
  202. # 创建agent
  203. def create_langchain_agent(
  204. backend_url: str = "",
  205. token: str = "",
  206. username: str = "default",
  207. thread_id: str = "default",
  208. context: str = "",
  209. ):
  210. llm = ChatOpenAI(
  211. model=settings.LLM_MODEL,
  212. temperature=settings.LLM_TEMPERATURE,
  213. api_key=settings.DEEPSEEK_API_KEY,
  214. base_url=settings.DEEPSEEK_BASE_URL,
  215. max_tokens=settings.LLM_MAX_TOKENS,
  216. )
  217. tools = get_worker_tools()
  218. # 添加调试信息
  219. # print(f"[DEBUG]Agent 创建调试信息:")
  220. # print(f" - 用户: {username}")
  221. # print(f" - Thread ID: {thread_id}")
  222. # print(f" - 后端地址: {backend_url}")
  223. # print(f" - Token: {'已提供' if token else '未提供'}")
  224. # print(f" - worker数量: {len(tools)}")
  225. # for i, tool in enumerate(tools):
  226. # print(f" - worker {i+1}: {tool.name}")
  227. # 获取动态的system_prompt
  228. system_prompt = create_system_prompt(backend_url, token, username, context)
  229. # print(f"[DEBUG]上下文长度: {len(context)}")
  230. # print(system_prompt)
  231. # chat_logger.info(f"主Agent System Prompt上下文: {system_prompt}")
  232. @before_model
  233. def trim_messages_middleware(
  234. state: AgentState, runtime: Runtime
  235. ) -> dict[str, Any] | None:
  236. """使用官方trim_messages函数修剪消息"""
  237. messages = state.get("messages", [])
  238. print(f"trim_messages_middleware[DEBUG]原始消息数: {len(messages)}")
  239. # if len(messages) <= 3:
  240. # return None # 不需要修剪
  241. trimmed_messages = trim_messages(
  242. messages,
  243. max_tokens=500,
  244. strategy="last", # 保留最近的对话
  245. token_counter=count_tokens_approximately, # token计数器
  246. start_on="human", # 从human消息开始计算轮次
  247. include_system=True, # 包含系统消息
  248. )
  249. # 添加调试信息
  250. original_count = len(messages)
  251. trimmed_count = len(trimmed_messages)
  252. print(f"trim_messages_middleware[DEBUG]修剪后消息数: {trimmed_count}")
  253. if trimmed_count < original_count:
  254. print(f"[INFO]消息修剪: {original_count} -> {trimmed_count} 条消息")
  255. return {"messages": trimmed_messages}
  256. # 使用SQLiteSaver(按天分割)
  257. checkpointer = get_sqlite_checkpointer()
  258. # print(f"打印检查点保存器: {checkpointer}")
  259. # 可选:清理旧检查点(可配置为定期执行)
  260. if os.getenv("AUTO_CLEANUP", "false").lower() == "true":
  261. cleanup_old_checkpoints(max_days=7) # 保留最近7天数据
  262. # agent = create_agent(
  263. # llm,
  264. # tools,
  265. # checkpointer=checkpointer,
  266. # system_prompt=system_prompt,
  267. # middleware=[trim_messages_middleware],
  268. # )
  269. agent = create_agent(
  270. llm,
  271. tools,
  272. system_prompt=system_prompt,
  273. )
  274. return agent