# Langchain-Chatchat-0.3.1/libs/chatchat-server/chatchat/server/chat/chat.py

import asyncio
import json
import time
import uuid
from typing import AsyncIterable, List
from fastapi import Body
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain_core.messages import AIMessage, HumanMessage, convert_to_messages
from langchain_core.output_parsers import StrOutputParser
from sse_starlette.sse import EventSourceResponse
from chatchat.settings import Settings
from chatchat.server.agent.agent_factory.agents_registry import agents_registry
from chatchat.server.agent.container import container
from chatchat.server.api_server.api_schemas import OpenAIChatOutput
from chatchat.server.callback_handler.agent_callback_handler import (
AgentExecutorAsyncIteratorCallbackHandler,
AgentStatus,
)
from chatchat.server.chat.utils import History
from chatchat.server.memory.conversation_db_buffer_memory import (
ConversationBufferDBMemory,
)
from chatchat.server.utils import (
MsgType,
get_ChatOpenAI,
get_prompt_template,
get_tool,
wrap_done,
get_default_llm,
)


def create_models_from_config(configs, callbacks, stream, max_tokens):
configs = configs or Settings.model_settings.LLM_MODEL_CONFIG
models = {}
prompts = {}
for model_type, params in configs.items():
model_name = params.get("model", "").strip() or get_default_llm()
        # Only attach callbacks when this model type opts in; use a local name so the
        # shared `callbacks` argument is not overwritten across loop iterations.
        model_callbacks = callbacks if params.get("callbacks", False) else None
        # If max_tokens was passed in, use it (explicit API value); otherwise fall back
        # to the per-model config (UI call, or API call without a value).
        max_tokens_value = max_tokens if max_tokens is not None else params.get("max_tokens", 1000)
model_instance = get_ChatOpenAI(
model_name=model_name,
temperature=params.get("temperature", 0.5),
max_tokens=max_tokens_value,
            callbacks=model_callbacks,
streaming=stream,
local_wrap=True,
)
models[model_type] = model_instance
prompt_name = params.get("prompt_name", "default")
prompt_template = get_prompt_template(type=model_type, name=prompt_name)
prompts[model_type] = prompt_template
return models, prompts
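

# Illustrative sketch (assumption, not part of the original project): the `configs` mapping
# expected above mirrors Settings.model_settings.LLM_MODEL_CONFIG — one entry per model role,
# with only the keys read in the loop ("model", "temperature", "max_tokens", "callbacks",
# "prompt_name") being meaningful. The model name "qwen2-instruct" is a placeholder.
#
#     example_config = {
#         "llm_model": {
#             "model": "qwen2-instruct",  # empty string falls back to get_default_llm()
#             "temperature": 0.5,
#             "max_tokens": 1000,
#             "callbacks": True,          # attach the agent/streaming callbacks to this model
#             "prompt_name": "default",   # resolved via get_prompt_template(type=..., name=...)
#         },
#         "action_model": {
#             "model": "qwen2-instruct",
#             "temperature": 0.1,
#             "callbacks": True,
#             "prompt_name": "default",
#         },
#     }
#     models, prompts = create_models_from_config(
#         configs=example_config, callbacks=[], stream=True, max_tokens=None
#     )
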
def create_models_chains(
history, history_len, prompts, models, tools, callbacks, conversation_id, metadata
):
memory = None
chat_prompt = None
container.metadata = metadata
if history:
history = [History.from_data(h) for h in history]
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(
False
)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg]
)
elif conversation_id and history_len > 0:
memory = ConversationBufferDBMemory(
conversation_id=conversation_id,
llm=models["llm_model"],
message_limit=history_len,
)
else:
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(
False
)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
if "action_model" in models and tools:
llm = models["action_model"]
llm.callbacks = callbacks
agent_executor = agents_registry(
llm=llm, callbacks=callbacks, tools=tools, prompt=None, verbose=True
)
full_chain = {"input": lambda x: x["input"]} | agent_executor
else:
llm = models["llm_model"]
llm.callbacks = callbacks
chain = LLMChain(prompt=chat_prompt, llm=llm, memory=memory)
full_chain = {"input": lambda x: x["input"]} | chain
return full_chain
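

# Usage sketch (illustrative; mirrors the ainvoke call made in chat_iterator() below):
# with tools and an "action_model" entry, the chain routes through the agent executor,
# otherwise through a plain LLMChain with optional DB-backed memory. Either way it is
# driven asynchronously with a single "input" key plus the prior messages:
#
#     result = await full_chain.ainvoke(
#         {"input": "What's the weather in Beijing?", "chat_history": []}
#     )
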
async def chat(
    query: str = Body(..., description="User input", examples=["恼羞成怒"]),
    metadata: dict = Body({}, description="Attachments, such as images or other payloads", examples=[]),
    conversation_id: str = Body("", description="Conversation ID"),
    message_id: str = Body(None, description="Message ID in the database"),
    history_len: int = Body(-1, description="Number of history messages to load from the database"),
    history: List[History] = Body(
        [],
        description="Chat history; set to an integer to load history messages from the database",
        examples=[
            [
                {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant", "content": "虎头虎脑"},
            ]
        ],
    ),
    stream: bool = Body(True, description="Stream the output"),
    chat_model_config: dict = Body({}, description="LLM model configuration", examples=[]),
    tool_config: dict = Body({}, description="Tool configuration", examples=[]),
    max_tokens: int = Body(None, description="Maximum number of tokens for the LLM", example=4096),
):
    """Agent chat."""
async def chat_iterator() -> AsyncIterable[OpenAIChatOutput]:
callback = AgentExecutorAsyncIteratorCallbackHandler()
callbacks = [callback]
# Enable langchain-chatchat to support langfuse
import os
langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
langfuse_host = os.environ.get("LANGFUSE_HOST")
if langfuse_secret_key and langfuse_public_key and langfuse_host:
from langfuse import Langfuse
from langfuse.callback import CallbackHandler
langfuse_handler = CallbackHandler()
callbacks.append(langfuse_handler)
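        # Note (illustrative): tracing only activates when all three variables above are set
        # in the server's environment before startup, e.g. (placeholder values):
        #   LANGFUSE_SECRET_KEY=sk-lf-...  LANGFUSE_PUBLIC_KEY=pk-lf-...  LANGFUSE_HOST=https://cloud.langfuse.com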
models, prompts = create_models_from_config(
callbacks=callbacks, configs=chat_model_config, stream=stream, max_tokens=max_tokens
)
all_tools = get_tool().values()
tools = [tool for tool in all_tools if tool.name in tool_config]
tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
full_chain = create_models_chains(
prompts=prompts,
models=models,
conversation_id=conversation_id,
tools=tools,
callbacks=callbacks,
history=history,
history_len=history_len,
metadata=metadata,
)
_history = [History.from_data(h) for h in history]
chat_history = [h.to_msg_tuple() for h in _history]
history_message = convert_to_messages(chat_history)
task = asyncio.create_task(
wrap_done(
full_chain.ainvoke(
{
"input": query,
"chat_history": history_message,
}
),
callback.done,
)
)
last_tool = {}
async for chunk in callback.aiter():
data = json.loads(chunk)
data["tool_calls"] = []
data["message_type"] = MsgType.TEXT
if data["status"] == AgentStatus.tool_start:
last_tool = {
"index": 0,
"id": data["run_id"],
"type": "function",
"function": {
"name": data["tool"],
"arguments": data["tool_input"],
},
"tool_output": None,
"is_error": False,
}
data["tool_calls"].append(last_tool)
if data["status"] in [AgentStatus.tool_end]:
last_tool.update(
tool_output=data["tool_output"],
is_error=data.get("is_error", False),
)
data["tool_calls"] = [last_tool]
last_tool = {}
try:
tool_output = json.loads(data["tool_output"])
if message_type := tool_output.get("message_type"):
data["message_type"] = message_type
                except (json.JSONDecodeError, TypeError):
                    # Tool output is not JSON; keep the default message type.
                    pass
elif data["status"] == AgentStatus.agent_finish:
try:
tool_output = json.loads(data["text"])
if message_type := tool_output.get("message_type"):
data["message_type"] = message_type
                except (json.JSONDecodeError, TypeError):
                    # Final text is not JSON; keep the default message type.
                    pass
ret = OpenAIChatOutput(
id=f"chat{uuid.uuid4()}",
object="chat.completion.chunk",
content=data.get("text", ""),
role="assistant",
tool_calls=data["tool_calls"],
model=models["llm_model"].model_name,
status=data["status"],
message_type=data["message_type"],
message_id=message_id,
)
yield ret.model_dump_json()
# yield OpenAIChatOutput( # return blank text lastly
# id=f"chat{uuid.uuid4()}",
# object="chat.completion.chunk",
# content="",
# role="assistant",
# model=models["llm_model"].model_name,
# status = data["status"],
# message_type = data["message_type"],
# message_id=message_id,
# )
await task
if stream:
return EventSourceResponse(chat_iterator())
else:
ret = OpenAIChatOutput(
id=f"chat{uuid.uuid4()}",
object="chat.completion",
content="",
role="assistant",
finish_reason="stop",
tool_calls=[],
status=AgentStatus.agent_finish,
message_type=MsgType.TEXT,
message_id=message_id,
)
async for chunk in chat_iterator():
data = json.loads(chunk)
if text := data["choices"][0]["delta"]["content"]:
ret.content += text
if data["status"] == AgentStatus.tool_end:
ret.tool_calls += data["choices"][0]["delta"]["tool_calls"]
ret.model = data["model"]
ret.created = data["created"]
return ret.model_dump()
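

if __name__ == "__main__":
    # Hedged example client (assumption): the route path "/chat/chat" and port 7861 reflect a
    # typical Langchain-Chatchat deployment; they are not defined in this module, so adjust
    # them to match your api_server configuration.
    import httpx

    payload = {
        "query": "Hello",
        "stream": True,
        "chat_model_config": {},
        "tool_config": {},
    }
    with httpx.stream(
        "POST", "http://127.0.0.1:7861/chat/chat", json=payload, timeout=60
    ) as resp:
        for line in resp.iter_lines():
            # sse_starlette emits Server-Sent Events; payload lines start with "data:".
            if line.startswith("data:"):
                print(json.loads(line[len("data:"):].strip()))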