Fix: the chat endpoint used memory by default to fetch 10 history messages, which broke the final assembled prompt (#2247)
After this change, the user-supplied history is used by default; history messages are read from the database only when both conversation_id and history_len are passed. Using memory has a further problem: the conversation roles are fixed and cannot adapt to different models. Misc: comment out some useless print statements.
parent 0cc1be224d
commit 4fa28c71b1
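Before the diff, the new history-selection rule can be summarized in code form. This is an illustrative sketch only: build_chat_inputs is a hypothetical helper, not the function in the server code below, and it stands in for the ChatPromptTemplate / ConversationBufferDBMemory wiring shown in the hunks.

from typing import Dict, List, Tuple

def build_chat_inputs(query: str,
                      history: List[Dict],
                      conversation_id: str = "",
                      history_len: int = -1) -> Tuple[List[Dict], bool]:
    """Mirror the new precedence rule for chat history.

    Returns (messages, use_db_memory): the messages to place in the prompt
    and whether a database-backed memory should be attached instead.
    """
    if history:
        # Caller-supplied history always wins.
        return history + [{"role": "user", "content": query}], False
    elif conversation_id and history_len > 0:
        # Only an explicit conversation_id AND a positive history_len
        # trigger reading past messages from the database via memory.
        return [{"role": "user", "content": query}], True
    else:
        # No history at all: the prompt holds only the current query.
        return [{"role": "user", "content": query}], False

# Caller history takes priority even when a conversation_id is present:
msgs, use_db = build_chat_inputs(
    "How is the weather today?",
    history=[{"role": "user", "content": "Hi"},
             {"role": "assistant", "content": "Hello, how can I help?"}],
    conversation_id="abc123",
)
assert use_db is False and len(msgs) == 3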
@@ -19,6 +19,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                conversation_id: str = Body("", description="对话框ID"),
+               history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
                history: Union[int, List[History]] = Body([],
                                                          description="历史对话,设为一个整数可以从数据库中读取历史消息",
                                                          examples=[[
@@ -39,12 +40,14 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
         callbacks = [callback]
+        memory = None

-        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
-        # 负责保存llm response到message db
-        conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
-                                                            chat_type="llm_chat",
-                                                            query=query)
-        callbacks.append(conversation_callback)
+        if conversation_id:
+            message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
+            # 负责保存llm response到message db
+            conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
+                                                                chat_type="llm_chat",
+                                                                query=query)
+            callbacks.append(conversation_callback)

         if isinstance(max_tokens, int) and max_tokens <= 0:
             max_tokens = None
@@ -55,18 +58,24 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
             callbacks=callbacks,
         )

-        if not conversation_id:
+        if history: # 优先使用前端传入的历史消息
             history = [History.from_data(h) for h in history]
             prompt_template = get_prompt_template("llm_chat", prompt_name)
             input_msg = History(role="user", content=prompt_template).to_msg_template(False)
             chat_prompt = ChatPromptTemplate.from_messages(
                 [i.to_msg_template() for i in history] + [input_msg])
-        else:
+        elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息
             # 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
             prompt = get_prompt_template("llm_chat", "with_history")
             chat_prompt = PromptTemplate.from_template(prompt)
             # 根据conversation_id 获取message 列表进而拼凑 memory
-            memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=model)
+            memory = ConversationBufferDBMemory(conversation_id=conversation_id,
+                                                llm=model,
+                                                message_limit=history_len)
+        else:
+            prompt_template = get_prompt_template("llm_chat", prompt_name)
+            input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+            chat_prompt = ChatPromptTemplate.from_messages([input_msg])

         chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
@@ -85,7 +85,7 @@ class ApiRequest:
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
         while retry > 0:
             try:
-                print(kwargs)
+                # print(kwargs)
                 if stream:
                     return self.client.stream("POST", url, data=data, json=json, **kwargs)
                 else:
@@ -134,7 +134,7 @@ class ApiRequest:
                 if as_json:
                     try:
                         data = json.loads(chunk)
-                        pprint(data, depth=1)
+                        # pprint(data, depth=1)
                         yield data
                     except Exception as e:
                         msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
@@ -166,7 +166,7 @@ class ApiRequest:
                 if as_json:
                     try:
                         data = json.loads(chunk)
-                        pprint(data, depth=1)
+                        # pprint(data, depth=1)
                         yield data
                     except Exception as e:
                         msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
@@ -291,6 +291,7 @@ class ApiRequest:
         self,
         query: str,
         conversation_id: str = None,
+        history_len: int = -1,
         history: List[Dict] = [],
         stream: bool = True,
         model: str = LLM_MODELS[0],
@@ -300,11 +301,12 @@ class ApiRequest:
         **kwargs,
     ):
         '''
-        对应api.py/chat/chat接口 #TODO: 考虑是否返回json
+        对应api.py/chat/chat接口
        '''
         data = {
             "query": query,
             "conversation_id": conversation_id,
+            "history_len": history_len,
             "history": history,
             "stream": stream,
             "model_name": model,
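For reference, a request that exercises the new database-history path might look like the sketch below. This is not project code: the base URL, port, and model name are assumptions, while the payload keys mirror the data dict in the hunk above. history is left empty so that conversation_id plus a positive history_len actually trigger the database read.

import httpx

payload = {
    "query": "What did we discuss earlier?",
    "conversation_id": "abc123",   # conversation whose messages are stored in the DB
    "history_len": 10,             # ask the server to pull up to 10 messages from the DB
    "history": [],                 # kept empty so the DB-backed branch is taken
    "stream": False,
    "model_name": "chatglm3-6b",   # placeholder; any served model name works
}

# /chat/chat is the endpoint this client method wraps (per its docstring).
resp = httpx.post("http://127.0.0.1:7861/chat/chat", json=payload, timeout=300)
print(resp.text)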
@@ -342,8 +344,8 @@ class ApiRequest:
             "prompt_name": prompt_name,
         }

-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)

         response = self.post("/chat/agent_chat", json=data, stream=True)
         return self._httpx_stream2generator(response)
@@ -377,8 +379,8 @@ class ApiRequest:
             "prompt_name": prompt_name,
         }

-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)

         response = self.post(
             "/chat/knowledge_base_chat",
@@ -452,8 +454,8 @@ class ApiRequest:
             "prompt_name": prompt_name,
         }

-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)

         response = self.post(
             "/chat/file_chat",
@@ -491,8 +493,8 @@ class ApiRequest:
             "split_result": split_result,
         }

-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)

         response = self.post(
             "/chat/search_engine_chat",