# Langchain-Chatchat/libs/chatchat-server/chatchat/server/api_server/chat_routes.py

from __future__ import annotations
from typing import Dict, List
from fastapi import APIRouter, Request
from langchain.prompts.prompt import PromptTemplate
from sse_starlette import EventSourceResponse
from chatchat.server.api_server.api_schemas import OpenAIChatInput
from chatchat.server.chat.chat import chat
from chatchat.server.chat.kb_chat import kb_chat
from chatchat.server.chat.feedback import chat_feedback
from chatchat.server.chat.file_chat import file_chat
from chatchat.server.db.repository import add_message_to_db
from chatchat.server.utils import (
get_OpenAIClient,
get_prompt_template,
get_tool,
get_tool_config,
)
from chatchat.settings import Settings
from chatchat.utils import build_logger
from .openai_routes import openai_request, OpenAIChatOutput

# FastAPI route handlers for chat-related requests.
# Module-wide logger, plus the router that groups every chat endpoint under
# the "/chat" prefix.
logger = build_logger()
chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"])

# NOTE: the legacy "/chat" endpoint (handler `chat`, summary
# "与llm模型对话(通过LLMChain)") is currently disabled and not registered.

# Plain POST endpoints as (path, summary, handler), registered in this order.
for _path, _summary, _handler in (
    ("/feedback", "返回llm模型对话评分", chat_feedback),
    ("/kb_chat", "知识库对话", kb_chat),
    ("/file_chat", "文件对话", file_chat),
):
    chat_router.post(_path, summary=_summary)(_handler)
# Do not leave the loop temporaries behind in the module namespace.
del _path, _summary, _handler

# An extra endpoint for chats that store an LLM-specified prompt template was
# planned here, but the framework already provides this, so it stays disabled.
# chat_router.post("/llm_chat", summary="llm对话")(completion)

def _tool_to_openai_spec(tool) -> Dict:
    """Build the OpenAI-style "function" tool description for a registered tool."""
    return {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.args,
        },
    }


def _save_query(chat_type, body, conversation_id):
    """Best-effort store of the latest message; return its id, or None.

    Nothing is stored (and None is returned) when there is no conversation id.
    Any failure is logged as a warning and swallowed so the chat can proceed.
    """
    if not conversation_id:
        return None
    try:
        return add_message_to_db(
            chat_type=chat_type,
            query=body.messages[-1]["content"],
            conversation_id=conversation_id,
        )
    except Exception as e:
        # The query can be a complex object that cannot be added to the db
        # (e.g. when using qwen-vl-chat); persistence is optional here.
        logger.warning(f"failed to add message to db: {e}")
        return None


@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
async def chat_completions(
    request: Request,
    body: OpenAIChatInput,
) -> Dict:
    """
    Unified, OpenAI-compatible chat endpoint.

    Request parameters match openai.chat.completions.create; additional
    parameters can be passed through extra_body.
    tools and tool_choice may be given as plain tool names; they are converted
    using the tools registered in this project.
    Different parameter combinations select different chat features:
    - tool_choice
        - extra_body contains tool_input: call tool_choice(tool_input) directly
        - extra_body has no tool_input: call tool_choice through the agent
    - tools: agent chat
    - otherwise: plain LLM chat
    Other combinations (e.g. file chat) still need to be considered.

    A string tool_choice that is not a registered tool name (for example
    "auto") and a tool_choice naming an unknown tool do not trigger the direct
    tool call; the request falls through to agent chat or plain LLM chat.

    Returns an OpenAI-compatible Dict (or an event stream built from it).
    """
    import time

    # NOTE(review): this logs the full message list at INFO level.
    logger.info(f"body.model_config:{body.model_config},body.tools: {body.tools},body.messages:{body.messages}")

    # When "max_tokens" is not passed in the body, use the configured default.
    if body.max_tokens in [None, 0]:
        body.max_tokens = Settings.model_settings.MAX_TOKENS

    logger.info(f"current time:{time.time()}")
    started = time.perf_counter()  # monotonic clock for the elapsed time
    client = get_OpenAIClient(model_name=body.model, is_async=True)
    logger.info(f"get_OpenAIClient takes{time.perf_counter() - started}")

    # Move the extra (non-schema) fields out of the body. model_extra may be
    # None, so the fallback has to be applied before copying.
    extra = dict(body.model_extra or {})
    for key in list(extra):
        delattr(body, key)

    # Resolve tool names in tools / tool_choice into OpenAI-style dicts.
    if isinstance(body.tool_choice, str):
        logger.info("isinstance(body.tool_choice, str)")
        if t := get_tool(body.tool_choice):
            logger.info(f"function name: {t.name}")
            body.tool_choice = {"function": {"name": t.name}, "type": "function"}
    if isinstance(body.tools, list):
        logger.info("isinstance(body.tools, list)")
        logger.info(f"body.tools:{body.tools}")
        for i, item in enumerate(body.tools):
            if isinstance(item, str) and (t := get_tool(item)):
                logger.info(f"function name: {t.name}")
                logger.info(f"parameters: {t.args}")
                body.tools[i] = _tool_to_openai_spec(t)

    conversation_id = extra.get("conversation_id")
    logger.info(f"conversation_id:{conversation_id}")

    # Chat based on the result of one explicitly chosen tool. A tool_choice
    # that is still a string here was not a registered tool name, so it cannot
    # be indexed like a dict and is left for the downstream chat to handle.
    tool = None
    if body.tool_choice and not isinstance(body.tool_choice, str):
        logger.info(f"if body.tool_choice:{body.tool_choice:}")
        tool = get_tool(body.tool_choice["function"]["name"])
        if not tool:
            tool = None
            logger.warning(f"tool_choice refers to an unknown tool: {body.tool_choice}")

    if tool is not None:
        if not body.tools:
            logger.info("if not body.tools")
            body.tools = [_tool_to_openai_spec(tool)]
        if tool_input := extra.get("tool_input"):
            logger.info("if tool_input := extra.get")
            message_id = _save_query("tool_call", body, conversation_id)

            tool_result = await tool.ainvoke(tool_input)
            logger.info(f"tool result: {tool_result}")
            # Feed the tool output to the model as context for the question.
            prompt_template = PromptTemplate.from_template(
                get_prompt_template("rag", "default"), template_format="jinja2"
            )
            body.messages[-1]["content"] = prompt_template.format(
                context=tool_result, question=body.messages[-1]["content"]
            )
            # The tool has already run; the follow-up LLM call must not see it.
            del body.tools
            del body.tool_choice
            extra_json = {
                "message_id": message_id,
                "status": None,
                "model": body.model,
            }
            header = [
                {
                    **extra_json,
                    "content": f"{tool_result}",
                    "tool_call": tool.get_name(),
                    "tool_output": tool_result.data,
                    "is_ref": not tool.return_direct,
                }
            ]
            if tool.return_direct:
                # Return the tool output as the only event, without the LLM.
                def temp_gen():
                    yield OpenAIChatOutput(**header[0]).model_dump_json()

                return EventSourceResponse(temp_gen())
            return await openai_request(
                client.chat.completions.create,
                body,
                extra_json=extra_json,
                header=header,
            )

    # Agent chat with tool calls.
    if body.tools:
        logger.info(f"if body.tools:{body.tools:}")
        query = body.messages[-1]["content"]
        logger.info(f"message_str:{query}")
        message_id = _save_query("agent_chat", body, conversation_id)
        logger.info(f"message_id:{message_id}")

        chat_model_config = {}  # TODO: let the frontend configure the model
        tool_names = [x["function"]["name"] for x in body.tools]
        tool_config = {name: get_tool_config(name) for name in tool_names}
        result = await chat(
            query=query,
            metadata=extra.get("metadata", {}),
            conversation_id=extra.get("conversation_id", ""),
            message_id=message_id,
            history_len=-1,
            history=body.messages[:-1],
            stream=body.stream,
            chat_model_config=extra.get("chat_model_config", chat_model_config),
            tool_config=extra.get("tool_config", tool_config),
            max_tokens=body.max_tokens,
        )
        logger.info(f"chat result: {result}")
        return result

    # Plain LLM chat.
    logger.info("LLM chat directly")
    logger.info(f"conversation_id:{conversation_id}")
    message_id = _save_query("llm_chat", body, conversation_id)
    logger.info(f"message_id:{message_id}")

    extra_json = {
        "message_id": message_id,
        "status": None,
    }
    logger.info(f"extra_json:{extra_json}")
    if message_id is not None:
        logger.info("message_id is not None")
        return await openai_request(
            client.chat.completions.create, body, extra_json=extra_json
        )
    # No stored message: forward the request without the extra payload.
    logger.info("message_id is None")
    return await openai_request(client.chat.completions.create, body)