Fix: the chat endpoint used memory by default and fetched 10 history messages, corrupting the final assembled prompt (#2247)

After this change, the user-supplied history is used by default; history messages are read from the database only when both conversation_id and history_len are passed in.
Using memory also has another problem: the conversation roles are fixed, so they cannot be adapted to different models.

Misc: comment out some unneeded print statements
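
For reference, a minimal sketch of the two ways a client can supply history after this change. The field names and the /chat/chat path follow this commit; the server address is an assumption based on the project's defaults.

import httpx

base_url = "http://127.0.0.1:7861"  # assumed default api.py address

# Default path: the caller supplies history directly, so the message roles
# can be adapted to whatever the target model expects.
payload_frontend = {
    "query": "What else can I try?",
    "history": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi, how can I help?"},
    ],
    "stream": False,
}

# Database path: history is read from the message db only when BOTH
# conversation_id and a positive history_len are sent.
payload_db = {
    "query": "What else can I try?",
    "conversation_id": "0af3f1d9",  # hypothetical conversation id
    "history_len": 10,
    "stream": False,
}

resp = httpx.post(f"{base_url}/chat/chat", json=payload_frontend, timeout=60)
print(resp.text)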
liunux4odoo 2023-12-01 11:35:43 +08:00 committed by GitHub
parent 0cc1be224d
commit 4fa28c71b1
2 changed files with 32 additions and 21 deletions

View File

@@ -19,6 +19,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa
 async def chat(query: str = Body(..., description="user input", examples=["恼羞成怒"]),
                conversation_id: str = Body("", description="conversation ID"),
+               history_len: int = Body(-1, description="number of history messages to fetch from the database"),
                history: Union[int, List[History]] = Body([],
                                                          description="chat history; set to an integer to read history messages from the database",
                                                          examples=[[
@@ -39,12 +40,14 @@ async def chat(query: str = Body(..., description="user input", examples=["恼
         callbacks = [callback]
         memory = None
-        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
-        # saves the llm response to the message db
-        conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
-                                                            chat_type="llm_chat",
-                                                            query=query)
-        callbacks.append(conversation_callback)
+        if conversation_id:
+            message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
+            # saves the llm response to the message db
+            conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
+                                                                chat_type="llm_chat",
+                                                                query=query)
+            callbacks.append(conversation_callback)
 
         if isinstance(max_tokens, int) and max_tokens <= 0:
             max_tokens = None
@@ -55,18 +58,24 @@ async def chat(query: str = Body(..., description="user input", examples=["恼
             callbacks=callbacks,
         )
 
-        if not conversation_id:
+        if history:  # prefer history passed in from the frontend
             history = [History.from_data(h) for h in history]
             prompt_template = get_prompt_template("llm_chat", prompt_name)
             input_msg = History(role="user", content=prompt_template).to_msg_template(False)
             chat_prompt = ChatPromptTemplate.from_messages(
                 [i.to_msg_template() for i in history] + [input_msg])
-        else:
+        elif conversation_id and history_len > 0:  # the frontend requests history messages from the database
             # when using memory, the prompt must contain a variable matching memory.memory_key
             prompt = get_prompt_template("llm_chat", "with_history")
             chat_prompt = PromptTemplate.from_template(prompt)
             # fetch the message list by conversation_id and assemble the memory from it
-            memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=model)
+            memory = ConversationBufferDBMemory(conversation_id=conversation_id,
+                                                llm=model,
+                                                message_limit=history_len)
+        else:
+            prompt_template = get_prompt_template("llm_chat", prompt_name)
+            input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+            chat_prompt = ChatPromptTemplate.from_messages([input_msg])
 
         chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
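
As the comment in the database branch notes, the with_history path injects db-backed messages through the memory object, so the prompt template must expose a variable matching memory.memory_key. A minimal sketch; the template text here is hypothetical, since the real one comes from get_prompt_template("llm_chat", "with_history"):

from langchain.prompts import PromptTemplate

# Hypothetical stand-in for the project's "with_history" template.
with_history = (
    "Conversation so far:\n{history}\n\n"
    "Human: {input}\nAssistant:"
)
chat_prompt = PromptTemplate.from_template(with_history)

# LLMChain can only inject the db-backed memory if a "history" variable
# (assuming ConversationBufferDBMemory keeps the default memory_key
# "history") appears among the template's inputs.
assert "history" in chat_prompt.input_variables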

View File

@@ -85,7 +85,7 @@ class ApiRequest:
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
         while retry > 0:
             try:
-                print(kwargs)
+                # print(kwargs)
                 if stream:
                     return self.client.stream("POST", url, data=data, json=json, **kwargs)
                 else:
@@ -134,7 +134,7 @@ class ApiRequest:
                 if as_json:
                     try:
                         data = json.loads(chunk)
-                        pprint(data, depth=1)
+                        # pprint(data, depth=1)
                         yield data
                     except Exception as e:
                         msg = f"The API returned malformed JSON: '{chunk}'. Error: {e}"
@@ -166,7 +166,7 @@ class ApiRequest:
                 if as_json:
                     try:
                         data = json.loads(chunk)
-                        pprint(data, depth=1)
+                        # pprint(data, depth=1)
                         yield data
                     except Exception as e:
                         msg = f"The API returned malformed JSON: '{chunk}'. Error: {e}"
@@ -291,6 +291,7 @@ class ApiRequest:
             self,
             query: str,
             conversation_id: str = None,
+            history_len: int = -1,
             history: List[Dict] = [],
             stream: bool = True,
             model: str = LLM_MODELS[0],
@@ -300,11 +301,12 @@
             **kwargs,
     ):
         '''
-        corresponds to the api.py/chat/chat endpoint  # TODO: consider whether to return json
+        corresponds to the api.py/chat/chat endpoint
         '''
         data = {
             "query": query,
             "conversation_id": conversation_id,
+            "history_len": history_len,
             "history": history,
             "stream": stream,
             "model_name": model,
@@ -342,8 +344,8 @@
             "prompt_name": prompt_name,
         }
 
-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)
 
         response = self.post("/chat/agent_chat", json=data, stream=True)
         return self._httpx_stream2generator(response)
@@ -377,8 +379,8 @@
             "prompt_name": prompt_name,
         }
 
-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)
 
         response = self.post(
             "/chat/knowledge_base_chat",
@@ -452,8 +454,8 @@
             "prompt_name": prompt_name,
         }
 
-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)
 
         response = self.post(
             "/chat/file_chat",
@@ -491,8 +493,8 @@
             "split_result": split_result,
         }
 
-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)
 
         response = self.post(
             "/chat/search_engine_chat",