Fix: the chat endpoint defaulted to using memory to fetch 10 history messages, which corrupted the final assembled prompt (#2247)

After this change, the history passed in by the caller is used by default; history messages are read from the database only when both conversation_id and history_len are supplied.
Using memory had a further problem: the conversation roles were fixed, so they could not be adapted to different models.
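
To make the new dispatch concrete, here are three illustrative request bodies for the /chat/chat endpoint; every value below is made up:

# 1. The caller supplies history explicitly -- it is used as-is.
payload_explicit = {
    "query": "What else can you help with?",
    "history": [{"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi, how can I help?"}],
}

# 2. The caller asks the server to load history from the database:
#    both conversation_id and a positive history_len are required.
payload_from_db = {
    "query": "Continue the previous topic",
    "conversation_id": "demo-conversation",  # hypothetical ID
    "history_len": 10,
}

# 3. Neither is given -- the prompt is built from the query alone.
payload_plain = {"query": "Hello"}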

Other: comment out some useless print calls
liunux4odoo 2023-12-01 11:35:43 +08:00 committed by GitHub
parent 0cc1be224d
commit 4fa28c71b1
2 changed files with 32 additions and 21 deletions


@@ -19,6 +19,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
 async def chat(query: str = Body(..., description="user input", examples=["恼羞成怒"]),
                conversation_id: str = Body("", description="conversation ID"),
+               history_len: int = Body(-1, description="number of history messages to fetch from the database"),
                history: Union[int, List[History]] = Body([],
                                                          description="chat history; set it to an integer to read history messages from the database",
                                                          examples=[[
@@ -39,12 +40,14 @@ async def chat(query: str = Body(..., description="user input", examples=["恼羞成怒"]),
     callbacks = [callback]
     memory = None
-    message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
-    # responsible for saving the LLM response to the message DB
-    conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
-                                                        chat_type="llm_chat",
-                                                        query=query)
-    callbacks.append(conversation_callback)
+    if conversation_id:
+        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
+        # responsible for saving the LLM response to the message DB
+        conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
+                                                            chat_type="llm_chat",
+                                                            query=query)
+        callbacks.append(conversation_callback)

     if isinstance(max_tokens, int) and max_tokens <= 0:
         max_tokens = None
@@ -55,18 +58,24 @@ async def chat(query: str = Body(..., description="user input", examples=["恼羞成怒"]),
         callbacks=callbacks,
     )

-    if not conversation_id:
+    if history:  # prefer the history passed in by the caller
         history = [History.from_data(h) for h in history]
         prompt_template = get_prompt_template("llm_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
-    else:
+    elif conversation_id and history_len > 0:  # the caller asked to load history from the database
         # when using memory, the prompt must contain a variable matching memory.memory_key
         prompt = get_prompt_template("llm_chat", "with_history")
         chat_prompt = PromptTemplate.from_template(prompt)
         # fetch the message list by conversation_id and assemble the memory from it
-        memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=model)
+        memory = ConversationBufferDBMemory(conversation_id=conversation_id,
+                                            llm=model,
+                                            message_limit=history_len)
+    else:
+        prompt_template = get_prompt_template("llm_chat", prompt_name)
+        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+        chat_prompt = ChatPromptTemplate.from_messages([input_msg])

     chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
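
For reference, a minimal client-side sketch that exercises the new database-backed branch; the host/port and the plain-text streaming are assumptions about a local deployment, and the conversation ID is made up:

import httpx

base_url = "http://127.0.0.1:7861"  # assumed local API address

# conversation_id plus a positive history_len selects the
# ConversationBufferDBMemory branch on the server.
payload = {
    "query": "What did we talk about earlier?",
    "conversation_id": "demo-conversation",  # hypothetical ID
    "history_len": 10,
    "stream": True,
}

with httpx.stream("POST", f"{base_url}/chat/chat", json=payload, timeout=None) as r:
    for chunk in r.iter_text():
        print(chunk, end="", flush=True)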


@@ -85,7 +85,7 @@ class ApiRequest:
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
         while retry > 0:
             try:
-                print(kwargs)
+                # print(kwargs)
                 if stream:
                     return self.client.stream("POST", url, data=data, json=json, **kwargs)
                 else:
@@ -134,7 +134,7 @@ class ApiRequest:
                 if as_json:
                     try:
                         data = json.loads(chunk)
-                        pprint(data, depth=1)
+                        # pprint(data, depth=1)
                         yield data
                     except Exception as e:
                         msg = f"the API returned invalid JSON: '{chunk}'. The error is: {e}"
@@ -166,7 +166,7 @@ class ApiRequest:
                 if as_json:
                     try:
                         data = json.loads(chunk)
-                        pprint(data, depth=1)
+                        # pprint(data, depth=1)
                         yield data
                     except Exception as e:
                         msg = f"the API returned invalid JSON: '{chunk}'. The error is: {e}"
@@ -291,6 +291,7 @@ class ApiRequest:
         self,
         query: str,
         conversation_id: str = None,
+        history_len: int = -1,
         history: List[Dict] = [],
         stream: bool = True,
         model: str = LLM_MODELS[0],
@@ -300,11 +301,12 @@ class ApiRequest:
         **kwargs,
     ):
         '''
-        corresponds to the api.py /chat/chat endpoint  # TODO: consider whether to return JSON
+        corresponds to the api.py /chat/chat endpoint
         '''
         data = {
             "query": query,
             "conversation_id": conversation_id,
+            "history_len": history_len,
             "history": history,
             "stream": stream,
             "model_name": model,
@@ -342,8 +344,8 @@ class ApiRequest:
             "prompt_name": prompt_name,
         }

-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)

         response = self.post("/chat/agent_chat", json=data, stream=True)
         return self._httpx_stream2generator(response)
@@ -377,8 +379,8 @@ class ApiRequest:
             "prompt_name": prompt_name,
         }

-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)

         response = self.post(
             "/chat/knowledge_base_chat",
@@ -452,8 +454,8 @@ class ApiRequest:
             "prompt_name": prompt_name,
         }

-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)

         response = self.post(
             "/chat/file_chat",
@@ -491,8 +493,8 @@ class ApiRequest:
             "split_result": split_result,
         }

-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)

         response = self.post(
             "/chat/search_engine_chat",