from __future__ import annotations from typing import Dict, List from fastapi import APIRouter, Request from langchain.prompts.prompt import PromptTemplate from sse_starlette import EventSourceResponse from chatchat.server.api_server.api_schemas import OpenAIChatInput from chatchat.server.chat.chat import chat from chatchat.server.chat.kb_chat import kb_chat from chatchat.server.chat.feedback import chat_feedback from chatchat.server.chat.file_chat import file_chat from chatchat.server.db.repository import add_message_to_db from chatchat.server.utils import ( get_OpenAIClient, get_prompt_template, get_tool, get_tool_config, ) from chatchat.settings import Settings from chatchat.utils import build_logger from .openai_routes import openai_request, OpenAIChatOutput # FastAPI 路由处理器,处理与对话相关的请求 logger = build_logger() chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"]) # chat_router.post( # "/chat", # summary="与llm模型对话(通过LLMChain)", # )(chat) chat_router.post( "/feedback", summary="返回llm模型对话评分", )(chat_feedback) chat_router.post("/kb_chat", summary="知识库对话")(kb_chat) chat_router.post("/file_chat", summary="文件对话")(file_chat) #本来想增加存LLM指定模版的对话,后来发现框架已经提供 # chat_router.post("/llm_chat", summary="llm对话")(completion) @chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口") async def chat_completions( request: Request, body: OpenAIChatInput, ) -> Dict: """ 请求参数与 openai.chat.completions.create 一致,可以通过 extra_body 传入额外参数 tools 和 tool_choice 可以直接传工具名称,会根据项目里包含的 tools 进行转换 通过不同的参数组合调用不同的 chat 功能: - tool_choice - extra_body 中包含 tool_input: 直接调用 tool_choice(tool_input) - extra_body 中不包含 tool_input: 通过 agent 调用 tool_choice - tools: agent 对话 - 其它:LLM 对话 以后还要考虑其它的组合(如文件对话) 返回与 openai 兼容的 Dict """ # import rich # rich.print(body) # 当调用本接口且 body 中没有传入 "max_tokens" 参数时, 默认使用配置中定义的值 # logger.info(f"body.model_config:{body.model_config},body.tools: {body.tools},body.messages:{body.messages}") if body.max_tokens in [None, 0]: body.max_tokens = Settings.model_settings.MAX_TOKENS import time current_time = time.time() logger.info(f"current time:{time.time()}") client = get_OpenAIClient(model_name=body.model, is_async=True) logger.info(f"") end_time = time.time() logger.info(f"get_OpenAIClient takes{end_time - current_time}") extra = {**body.model_extra} or {} for key in list(extra): delattr(body, key) # check tools & tool_choice in request body if isinstance(body.tool_choice, str): logger.info(f"isinstance(body.tool_choice, str)") if t := get_tool(body.tool_choice): logger.info(f"function name: {t.name}") body.tool_choice = {"function": {"name": t.name}, "type": "function"} if isinstance(body.tools, list): logger.info(f"isinstance(body.tools, list)") logger.info(f"body.tools:{body.tools}") for i in range(len(body.tools)): if isinstance(body.tools[i], str): if t := get_tool(body.tools[i]): logger.info(f"function name: {t.name}") logger.info(f"parameters: {t.args}") body.tools[i] = { "type": "function", "function": { "name": t.name, "description": t.description, "parameters": t.args, }, } # "14c68cfa65ab4d7b8f5dba05f20f4eec"# conversation_id = extra.get("conversation_id") logger.info(f"conversation_id:{conversation_id}") # chat based on result from one choiced tool if body.tool_choice: logger.info(f"if body.tool_choice:{body.tool_choice:}") tool = get_tool(body.tool_choice["function"]["name"]) if not body.tools: logger.info(f"if not body.tools") body.tools = [ { "type": "function", "function": { "name": tool.name, "description": tool.description, "parameters": tool.args, }, } ] if tool_input := extra.get("tool_input"): logger.info(f"if tool_input := extra.get") try: message_id = ( add_message_to_db( chat_type="tool_call", query=body.messages[-1]["content"], conversation_id=conversation_id, ) if conversation_id else None ) except Exception as e: logger.warning(f"failed to add message to db: {e}") message_id = None tool_result = await tool.ainvoke(tool_input) logger.info(f"tool result: {tool_result}") prompt_template = PromptTemplate.from_template( get_prompt_template("rag", "default"), template_format="jinja2" ) body.messages[-1]["content"] = prompt_template.format( context=tool_result, question=body.messages[-1]["content"] ) del body.tools del body.tool_choice extra_json = { "message_id": message_id, "status": None, "model": body.model, } header = [ { **extra_json, "content": f"{tool_result}", "tool_call": tool.get_name(), "tool_output": tool_result.data, "is_ref": False if tool.return_direct else True, } ] if tool.return_direct: def temp_gen(): yield OpenAIChatOutput(**header[0]).model_dump_json() return EventSourceResponse(temp_gen()) else: return await openai_request( client.chat.completions.create, body, extra_json=extra_json, header=header, ) # agent chat with tool calls if body.tools: logger.info(f"if body.tools:{body.tools:}") try: message_str = body.messages[-1]["content"] logger.info(f"message_str:{message_str}") message_id = ( add_message_to_db( chat_type="agent_chat", query=body.messages[-1]["content"], conversation_id=conversation_id, ) if conversation_id else None ) logger.info(f"message_id:{message_id}") except Exception as e: logger.warning(f"failed to add message to db: {e}") message_id = None chat_model_config = {} # TODO: 前端支持配置模型 tool_names = [x["function"]["name"] for x in body.tools] tool_config = {name: get_tool_config(name) for name in tool_names} result = await chat( query=body.messages[-1]["content"], metadata=extra.get("metadata", {}), conversation_id=extra.get("conversation_id", ""), message_id=message_id, history_len=-1, history=body.messages[:-1], stream=body.stream, chat_model_config=extra.get("chat_model_config", chat_model_config), tool_config=extra.get("tool_config", tool_config), max_tokens=body.max_tokens, ) logger.info(f"chat result: {result}") return result else: # LLM chat directly logger.info(f"LLM chat directly") try: # query is complex object that unable add to db when using qwen-vl-chat logger.info(f"conversation_id:{conversation_id}") message_id = ( add_message_to_db( chat_type="llm_chat", query=body.messages[-1]["content"], conversation_id=conversation_id, ) if conversation_id else None ) logger.info(f"message_id:{message_id}") except Exception as e: logger.warning(f"failed to add message to db: {e}") message_id = None logger.info(f"**message_id:{message_id}") extra_json = { "message_id": message_id, "status": None, } logger.info(f"extra_json:{extra_json}") if message_id is not None: logger.info(f"message_id is not None") return await openai_request( client.chat.completions.create, body, extra_json=extra_json ) else: logger.info(f"message_id is None") return await openai_request( client.chat.completions.create, body)