Langchain-Chatchat/libs/chatchat-server/chatchat/server/api_server/chat_routes.py

216 lines
7.6 KiB
Python
Raw Normal View History

2024-12-20 16:04:03 +08:00
from __future__ import annotations
from typing import Dict, List
from fastapi import APIRouter, Request
from langchain.prompts.prompt import PromptTemplate
from sse_starlette import EventSourceResponse
from chatchat.server.api_server.api_schemas import OpenAIChatInput
from chatchat.server.chat.chat import chat
from chatchat.server.chat.kb_chat import kb_chat
from chatchat.server.chat.feedback import chat_feedback
from chatchat.server.chat.file_chat import file_chat
from chatchat.server.db.repository import add_message_to_db
from chatchat.server.utils import (
get_OpenAIClient,
get_prompt_template,
get_tool,
get_tool_config,
)
from chatchat.settings import Settings
from chatchat.utils import build_logger
from .openai_routes import openai_request, OpenAIChatOutput
2024-12-29 15:59:16 +08:00
# FastAPI 路由处理器,处理与对话相关的请求
2024-12-20 16:04:03 +08:00
logger = build_logger()
chat_router = APIRouter(prefix="/chat", tags=["ChatChat 对话"])
# chat_router.post(
# "/chat",
# summary="与llm模型对话(通过LLMChain)",
# )(chat)
chat_router.post(
"/feedback",
summary="返回llm模型对话评分",
)(chat_feedback)
chat_router.post("/kb_chat", summary="知识库对话")(kb_chat)
chat_router.post("/file_chat", summary="文件对话")(file_chat)
@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
async def chat_completions(
request: Request,
body: OpenAIChatInput,
) -> Dict:
"""
请求参数与 openai.chat.completions.create 一致可以通过 extra_body 传入额外参数
tools tool_choice 可以直接传工具名称会根据项目里包含的 tools 进行转换
通过不同的参数组合调用不同的 chat 功能
- tool_choice
- extra_body 中包含 tool_input: 直接调用 tool_choice(tool_input)
- extra_body 中不包含 tool_input: 通过 agent 调用 tool_choice
- tools: agent 对话
- 其它LLM 对话
以后还要考虑其它的组合如文件对话
返回与 openai 兼容的 Dict
"""
# import rich
# rich.print(body)
# 当调用本接口且 body 中没有传入 "max_tokens" 参数时, 默认使用配置中定义的值
if body.max_tokens in [None, 0]:
body.max_tokens = Settings.model_settings.MAX_TOKENS
client = get_OpenAIClient(model_name=body.model, is_async=True)
extra = {**body.model_extra} or {}
for key in list(extra):
delattr(body, key)
# check tools & tool_choice in request body
if isinstance(body.tool_choice, str):
2024-12-29 15:59:16 +08:00
logger.info(f"tool_choice")
2024-12-20 16:04:03 +08:00
if t := get_tool(body.tool_choice):
2024-12-29 15:59:16 +08:00
logger.info(f"function name: {t.name}")
2024-12-20 16:04:03 +08:00
body.tool_choice = {"function": {"name": t.name}, "type": "function"}
if isinstance(body.tools, list):
2024-12-29 15:59:16 +08:00
logger.info(f"tools")
2024-12-20 16:04:03 +08:00
for i in range(len(body.tools)):
if isinstance(body.tools[i], str):
if t := get_tool(body.tools[i]):
2024-12-29 15:59:16 +08:00
logger.info(f"function name: {t.name}")
2024-12-20 16:04:03 +08:00
body.tools[i] = {
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.args,
},
}
conversation_id = extra.get("conversation_id")
# chat based on result from one choiced tool
if body.tool_choice:
tool = get_tool(body.tool_choice["function"]["name"])
if not body.tools:
body.tools = [
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.args,
},
}
]
if tool_input := extra.get("tool_input"):
try:
message_id = (
add_message_to_db(
chat_type="tool_call",
query=body.messages[-1]["content"],
conversation_id=conversation_id,
)
if conversation_id
else None
)
except Exception as e:
logger.warning(f"failed to add message to db: {e}")
message_id = None
tool_result = await tool.ainvoke(tool_input)
prompt_template = PromptTemplate.from_template(
get_prompt_template("rag", "default"), template_format="jinja2"
)
body.messages[-1]["content"] = prompt_template.format(
context=tool_result, question=body.messages[-1]["content"]
)
del body.tools
del body.tool_choice
extra_json = {
"message_id": message_id,
"status": None,
"model": body.model,
}
header = [
{
**extra_json,
"content": f"{tool_result}",
"tool_call": tool.get_name(),
"tool_output": tool_result.data,
"is_ref": False if tool.return_direct else True,
}
]
if tool.return_direct:
def temp_gen():
yield OpenAIChatOutput(**header[0]).model_dump_json()
return EventSourceResponse(temp_gen())
else:
return await openai_request(
client.chat.completions.create,
body,
extra_json=extra_json,
header=header,
)
# agent chat with tool calls
if body.tools:
try:
message_id = (
add_message_to_db(
chat_type="agent_chat",
query=body.messages[-1]["content"],
conversation_id=conversation_id,
)
if conversation_id
else None
)
except Exception as e:
logger.warning(f"failed to add message to db: {e}")
message_id = None
chat_model_config = {} # TODO: 前端支持配置模型
tool_names = [x["function"]["name"] for x in body.tools]
tool_config = {name: get_tool_config(name) for name in tool_names}
result = await chat(
query=body.messages[-1]["content"],
metadata=extra.get("metadata", {}),
conversation_id=extra.get("conversation_id", ""),
message_id=message_id,
history_len=-1,
history=body.messages[:-1],
stream=body.stream,
chat_model_config=extra.get("chat_model_config", chat_model_config),
tool_config=extra.get("tool_config", tool_config),
max_tokens=body.max_tokens,
)
return result
else: # LLM chat directly
try: # query is complex object that unable add to db when using qwen-vl-chat
message_id = (
add_message_to_db(
chat_type="llm_chat",
query=body.messages[-1]["content"],
conversation_id=conversation_id,
)
if conversation_id
else None
)
except Exception as e:
logger.warning(f"failed to add message to db: {e}")
message_id = None
extra_json = {
"message_id": message_id,
"status": None,
}
return await openai_request(
client.chat.completions.create, body, extra_json=extra_json
)