From 4fa28c71b15cfec59773e2dfe30ee5eba7d6af09 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Fri, 1 Dec 2023 11:35:43 +0800
Subject: [PATCH] Fix: the chat endpoint used memory to fetch 10 history
 messages by default, corrupting the final assembled prompt (#2247)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

After this change, the history passed in by the caller is used by default;
history messages are read from the database only when both conversation_id
and history_len are supplied (a usage sketch follows the patch).

One remaining problem with using memory: the conversation roles are fixed,
so it cannot adapt to different models.

Other: comment out some useless print statements.
---
 server/chat/chat.py  | 27 ++++++++++++++++++---------
 webui_pages/utils.py | 26 ++++++++++++++------------
 2 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/server/chat/chat.py b/server/chat/chat.py
index 0885283..254a6dd 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -19,6 +19,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa
 
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                conversation_id: str = Body("", description="对话框ID"),
+               history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
                history: Union[int, List[History]] = Body([],
                                                          description="历史对话,设为一个整数可以从数据库中读取历史消息",
                                                          examples=[[
@@ -39,12 +40,14 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
 
         callbacks = [callback]
         memory = None
-        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
-        # 负责保存llm response到message db
-        conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
-                                                            chat_type="llm_chat",
-                                                            query=query)
-        callbacks.append(conversation_callback)
+        if conversation_id:
+            message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
+            # 负责保存llm response到message db
+            conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
+                                                                chat_type="llm_chat",
+                                                                query=query)
+            callbacks.append(conversation_callback)
+
         if isinstance(max_tokens, int) and max_tokens <= 0:
             max_tokens = None
 
@@ -55,18 +58,24 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             callbacks=callbacks,
         )
 
-        if not conversation_id:
+        if history: # 优先使用前端传入的历史消息
             history = [History.from_data(h) for h in history]
             prompt_template = get_prompt_template("llm_chat", prompt_name)
             input_msg = History(role="user", content=prompt_template).to_msg_template(False)
             chat_prompt = ChatPromptTemplate.from_messages(
                 [i.to_msg_template() for i in history] + [input_msg])
-        else:
+        elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息
             # 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
             prompt = get_prompt_template("llm_chat", "with_history")
             chat_prompt = PromptTemplate.from_template(prompt)
             # 根据conversation_id 获取message 列表进而拼凑 memory
-            memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=model)
+            memory = ConversationBufferDBMemory(conversation_id=conversation_id,
+                                                llm=model,
+                                                message_limit=history_len)
+        else:
+            prompt_template = get_prompt_template("llm_chat", prompt_name)
+            input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+            chat_prompt = ChatPromptTemplate.from_messages([input_msg])
 
         chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
 
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index d33912a..75e1419 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -85,7 +85,7 @@ class ApiRequest:
     ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
         while retry > 0:
             try:
-                print(kwargs)
+                # print(kwargs)
                 if stream:
                     return self.client.stream("POST", url, data=data, json=json, **kwargs)
                 else:
@@ -134,7 +134,7 @@ class ApiRequest:
                     if as_json:
                         try:
                             data = json.loads(chunk)
-                            pprint(data, depth=1)
+                            # pprint(data, depth=1)
                             yield data
                         except Exception as e:
                             msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
@@ -166,7 +166,7 @@ class ApiRequest:
                 if as_json:
                     try:
                         data = json.loads(chunk)
-                        pprint(data, depth=1)
+                        # pprint(data, depth=1)
                         yield data
                     except Exception as e:
                         msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
@@ -291,6 +291,7 @@ class ApiRequest:
             self,
             query: str,
             conversation_id: str = None,
+            history_len: int = -1,
             history: List[Dict] = [],
             stream: bool = True,
             model: str = LLM_MODELS[0],
@@ -300,11 +301,12 @@ class ApiRequest:
             **kwargs,
     ):
         '''
-        对应api.py/chat/chat接口 #TODO: 考虑是否返回json
+        对应api.py/chat/chat接口
         '''
         data = {
             "query": query,
             "conversation_id": conversation_id,
+            "history_len": history_len,
             "history": history,
             "stream": stream,
             "model_name": model,
@@ -342,8 +344,8 @@ class ApiRequest:
             "prompt_name": prompt_name,
         }
 
-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)
 
         response = self.post("/chat/agent_chat", json=data, stream=True)
         return self._httpx_stream2generator(response)
@@ -377,8 +379,8 @@ class ApiRequest:
             "prompt_name": prompt_name,
         }
 
-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)
 
         response = self.post(
             "/chat/knowledge_base_chat",
@@ -452,8 +454,8 @@ class ApiRequest:
             "prompt_name": prompt_name,
        }
 
-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)
 
         response = self.post(
             "/chat/file_chat",
@@ -491,8 +493,8 @@ class ApiRequest:
             "split_result": split_result,
         }
 
-        print(f"received input message:")
-        pprint(data)
+        # print(f"received input message:")
+        # pprint(data)
 
         response = self.post(
             "/chat/search_engine_chat",
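
For reference, a minimal client-side sketch of the two request shapes the patched /chat/chat endpoint now distinguishes. It posts the same fields that ApiRequest.chat_chat packs into its data dict; the base URL, the conversation_id value, and the example messages are illustrative assumptions, not part of the patch.

# Sketch only: exercising the two history modes of /chat/chat after this patch.
# BASE_URL and the payload contents below are assumptions for illustration.
import httpx

BASE_URL = "http://127.0.0.1:7861"  # assumed address of the api.py server


def stream_chat(payload: dict) -> None:
    # The endpoint streams its answer; print the raw chunks as they arrive.
    with httpx.stream("POST", f"{BASE_URL}/chat/chat", json=payload, timeout=None) as r:
        for chunk in r.iter_text():
            print(chunk, end="", flush=True)


# Mode 1: the caller supplies history directly. The server builds the prompt
# from these messages and does not read the message database.
stream_chat({
    "query": "Is there another way to do this?",
    "history": [
        {"role": "user", "content": "How do I reset my password?"},
        {"role": "assistant", "content": "You can reset it from the settings page."},
    ],
    "stream": True,
})

# Mode 2: the caller supplies both conversation_id and a positive history_len,
# so the server loads up to history_len past messages for that conversation
# from the database via ConversationBufferDBMemory.
stream_chat({
    "query": "Where did we leave off?",
    "conversation_id": "demo-conversation-id",  # hypothetical ID
    "history_len": 10,
    "stream": True,
})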