async_chat_service.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. import asyncio
  2. import time
  3. from typing import Dict, Any
  4. from langchain_core.messages import HumanMessage, ToolMessage
  5. from openai import chat
  6. from utils.logger import chat_logger, log_chat_entry
  7. from core.agent_manager import agent_manager
  8. from core.chat_result_manager import chat_result_manager
  9. class AsyncChatService:
  10. """异步聊天服务 - 支持轮询方式的版本"""
  11. def __init__(self):
  12. self.agent_manager = agent_manager
  13. self._thread_pool = None
  14. self._processing_tasks = {} # 正在处理的任务缓存
  15. def _get_thread_pool(self):
  16. """获取或创建线程池"""
  17. if self._thread_pool is None:
  18. import concurrent.futures
  19. self._thread_pool = concurrent.futures.ThreadPoolExecutor(
  20. max_workers=20,
  21. thread_name_prefix="async_chat_worker",
  22. )
  23. return self._thread_pool
  24. async def submit_chat_task(self, request_data: Dict[str, Any]) -> str:
  25. """提交聊天任务(立即返回任务ID)"""
  26. username = request_data["username"]
  27. # 创建任务记录
  28. task_id = chat_result_manager.create_task(request_data)
  29. chat_logger.info(f"用户{username},已提交聊天任务: {task_id}")
  30. # 异步执行任务
  31. asyncio.create_task(self._process_chat_task(task_id, request_data))
  32. return task_id
  33. async def _process_chat_task(self, task_id: str, request_data: Dict[str, Any]):
  34. """异步处理聊天任务"""
  35. try:
  36. # 更新状态为处理中
  37. chat_result_manager.update_task_status(task_id, "processing")
  38. # 提取请求数据
  39. message = request_data["message"]
  40. thread_id = request_data["thread_id"]
  41. username = request_data["username"]
  42. backend_url = request_data["backend_url"]
  43. token = request_data["token"]
  44. user_id = username
  45. chat_logger.info(
  46. f"开始处理任务 - 任务ID={task_id}, 用户={user_id},问题={message}"
  47. )
  48. # 获取对话上下文
  49. context = context_manager.get_recent_context(thread_id)
  50. if context:
  51. chat_logger.info(
  52. f"获取到对话上下文 - 线程={thread_id}, 上下文长度={len(context)}"
  53. )
  54. # 异步获取agent实例
  55. agent = await self.agent_manager.get_agent_instance(
  56. thread_id=thread_id,
  57. username=username,
  58. backend_url=backend_url,
  59. token=token,
  60. context=context,
  61. )
  62. # 在线程池中执行同步的Langchain操作
  63. result = await self._run_agent_in_threadpool(
  64. agent, message, thread_id, user_id
  65. )
  66. # chat_logger.info(f"主Agent返回结果: {result}")
  67. print("主Agent返回结果:", result)
  68. if not isinstance(result, dict) or "messages" not in result:
  69. raise ValueError(f"Agent返回格式异常: {type(result)}")
  70. # 处理结果
  71. response_data = self._process_agent_result(result, user_id, request_data)
  72. # 更新对话上下文(只记录成功的对话)
  73. if response_data.get("success", False) and response_data.get(
  74. "final_answer"
  75. ):
  76. context_manager.update_context(
  77. thread_id, message, response_data["final_answer"]
  78. )
  79. # 更新任务状态为完成
  80. chat_result_manager.update_task_status(task_id, "completed", response_data)
  81. chat_logger.info(f"任务处理完成 - 任务ID={task_id}")
  82. except Exception as e:
  83. error_msg = f"聊天处理失败: {str(e)}"
  84. chat_logger.error(f"{error_msg} - 任务ID={task_id}")
  85. # 更新任务状态为失败
  86. chat_result_manager.update_task_status(
  87. task_id, "failed", error_message=error_msg
  88. )
  89. async def _run_agent_in_threadpool(
  90. self, agent, message: str, thread_id: str, user_id: str
  91. ):
  92. """在线程池中执行Langchain Agent"""
  93. loop = asyncio.get_event_loop()
  94. thread_pool = self._get_thread_pool()
  95. # 准备输入
  96. inputs = {"messages": [HumanMessage(content=message)]}
  97. config = {"configurable": {"thread_id": thread_id}}
  98. chat_logger.info(f"在线程池中执行Agent - 用户={user_id}")
  99. try:
  100. # 在线程池中执行同步操作
  101. result = await loop.run_in_executor(
  102. thread_pool, lambda: agent.invoke(inputs, config)
  103. )
  104. return result
  105. except Exception as e:
  106. chat_logger.error(f"Agent执行失败 - 用户={user_id}: {str(e)}")
  107. raise
  108. def _process_agent_result(
  109. self, result: Dict[str, Any], user_id: str, request_data: Dict
  110. ) -> Dict[str, Any]:
  111. """处理Agent返回结果"""
  112. all_messages = result["messages"]
  113. processed_messages = []
  114. all_ai_messages = []
  115. all_tool_calls = []
  116. final_answer = ""
  117. for i, msg in enumerate(all_messages):
  118. msg_data = {
  119. "index": i,
  120. "type": getattr(msg, "type", "unknown"),
  121. "content": "",
  122. }
  123. # 获取内容
  124. if hasattr(msg, "content"):
  125. content = msg.content
  126. if isinstance(content, str):
  127. msg_data["content"] = content
  128. else:
  129. msg_data["content"] = str(content)
  130. # 获取工具调用
  131. if hasattr(msg, "tool_calls") and msg.tool_calls:
  132. msg_data["tool_calls"] = msg.tool_calls
  133. all_tool_calls.extend(msg.tool_calls)
  134. for tool_call in msg.tool_calls:
  135. tool_name = tool_call.get("name", "unknown")
  136. tool_args = tool_call.get("args", {})
  137. chat_logger.info(f"完成工具调用 - 用户={user_id}, 工具={tool_name}")
  138. if hasattr(msg, "tool_call_id"):
  139. msg_data["tool_call_id"] = msg.tool_call_id
  140. if hasattr(msg, "name"):
  141. msg_data["name"] = msg.name
  142. processed_messages.append(msg_data)
  143. # 收集AI消息
  144. if msg_data["type"] == "ai":
  145. all_ai_messages.append(msg_data)
  146. final_answer = msg_data["content"]
  147. if final_answer == "工具调用成功":
  148. last_tool_content = None
  149. for msg in reversed(result.get("messages", [])):
  150. if isinstance(msg, ToolMessage):
  151. last_tool_content = msg.content
  152. break
  153. if last_tool_content:
  154. final_answer = last_tool_content
  155. # 构建响应
  156. response = {
  157. "final_answer": final_answer,
  158. "all_ai_messages": all_ai_messages,
  159. "all_messages": processed_messages,
  160. "tool_calls": all_tool_calls,
  161. "thread_id": request_data["thread_id"],
  162. "user_identifier": user_id,
  163. "backend_config": {
  164. "backend_url": request_data["backend_url"] or "未配置",
  165. "username": request_data["username"],
  166. "has_token": bool(request_data["token"]),
  167. },
  168. "success": True,
  169. }
  170. # 记录日志
  171. log_chat_entry(user_id, request_data["message"], response)
  172. return response
  173. async def get_task_result(self, task_id: str) -> Dict[str, Any]:
  174. """获取任务结果"""
  175. task_info = chat_result_manager.get_task(task_id)
  176. # chat_logger.info(
  177. # f"获取任务结果 - 任务ID={task_id}, 状态={task_info['status']},error_message={task_info['error_message']}"
  178. # )
  179. if not task_info:
  180. return {
  181. "success": False,
  182. "error": f"任务不存在: {task_id}",
  183. "task_id": task_id,
  184. }
  185. return {
  186. "task_id": task_id,
  187. "status": task_info["status"],
  188. "response": task_info["response_data"],
  189. "error": task_info["error_message"],
  190. "created_at": task_info["created_at"],
  191. "updated_at": task_info["updated_at"],
  192. "success": task_info["status"] == "completed",
  193. }
  194. async def shutdown(self):
  195. """关闭服务"""
  196. if self._thread_pool:
  197. self._thread_pool.shutdown(wait=False)
  198. self._thread_pool = None
  199. chat_result_manager.close()
  200. chat_logger.info("异步聊天服务已关闭")
# Module-level shared service instance
async_chat_service = AsyncChatService()
# Context-management functionality added to async_chat_service.py
  204. class ContextManager:
  205. def __init__(self, max_history=3):
  206. self.conversation_history = (
  207. {}
  208. ) # thread_id -> list of (human_message, ai_message)
  209. self.max_history = max_history
  210. def get_recent_context(self, thread_id: str) -> str:
  211. """获取最近3轮对话的上下文"""
  212. if thread_id not in self.conversation_history:
  213. return ""
  214. history = self.conversation_history[thread_id]
  215. if not history:
  216. return ""
  217. # 获取最近3轮对话
  218. recent_history = history[-self.max_history :]
  219. # 构建上下文字符串(只包含human和ai消息)
  220. context_parts = []
  221. for i, (human_msg, ai_msg) in enumerate(recent_history):
  222. context_parts.append(f"第{len(history)-len(recent_history)+i+1}轮对话:")
  223. context_parts.append(f"用户:{human_msg}")
  224. context_parts.append(f"AI:{ai_msg}")
  225. context_parts.append("") # 空行分隔
  226. return "\n".join(context_parts).strip()
  227. def update_context(self, thread_id: str, human_message: str, ai_message: str):
  228. """更新对话上下文"""
  229. if thread_id not in self.conversation_history:
  230. self.conversation_history[thread_id] = []
  231. # 添加新的对话轮次
  232. self.conversation_history[thread_id].append((human_message, ai_message))
  233. # 保持最多max_history轮对话
  234. if len(self.conversation_history[thread_id]) > self.max_history:
  235. self.conversation_history[thread_id] = self.conversation_history[thread_id][
  236. -self.max_history :
  237. ]
  238. def clear_context(self, thread_id: str):
  239. """清空特定线程的上下文"""
  240. if thread_id in self.conversation_history:
  241. self.conversation_history[thread_id] = []
# Global context manager instance
context_manager = ContextManager(max_history=3)