From 22d08f5ec50ac089266eae0f0f40d19119595a94 Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Sun, 16 Jul 2023 02:17:52 +0800
Subject: [PATCH] Validate required parameters
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/model_config.py       |  4 ++--
 models/fastchat_openai_llm.py | 41 ++++++++++++++++++++---------------
 2 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/configs/model_config.py b/configs/model_config.py
index 2f52c9c..b596c64 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -16,7 +16,7 @@ embedding_model_dict = {
     "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
     "ernie-base": "nghuyong/ernie-3.0-base-zh",
     "text2vec-base": "shibing624/text2vec-base-chinese",
-    "text2vec": "GanymedeNil/text2vec-large-chinese",
+    "text2vec": "/media/checkpoint/text2vec-large-chinese/",
     "m3e-small": "moka-ai/m3e-small",
     "m3e-base": "moka-ai/m3e-base",
 }
@@ -186,7 +186,7 @@ llm_model_dict = {
 }
 
 # LLM name
-LLM_MODEL = "chatglm-6b"
+LLM_MODEL = "fastchat-chatglm-6b"
 # Load the 8-bit quantized model
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
diff --git a/models/fastchat_openai_llm.py b/models/fastchat_openai_llm.py
index 5cb617b..d0972f7 100644
--- a/models/fastchat_openai_llm.py
+++ b/models/fastchat_openai_llm.py
@@ -52,16 +52,21 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str,
     system_build_message['role'] = 'system'
     system_build_message['content'] = "You are a helpful assistant."
     build_messages.append(system_build_message)
+    if history:
+        for i, (user, assistant) in enumerate(history):
+            if user:
-    for i, (old_query, response) in enumerate(history):
-        user_build_message = _build_message_template()
-        user_build_message['role'] = 'user'
-        user_build_message['content'] = old_query
-        system_build_message = _build_message_template()
-        system_build_message['role'] = 'assistant'
-        system_build_message['content'] = response
-        build_messages.append(user_build_message)
-        build_messages.append(system_build_message)
+                user_build_message = _build_message_template()
+                user_build_message['role'] = 'user'
+                user_build_message['content'] = user
+                build_messages.append(user_build_message)
+
+            if not assistant:
+                raise RuntimeError("Invalid history data structure")
+            system_build_message = _build_message_template()
+            system_build_message['role'] = 'assistant'
+            system_build_message['content'] = assistant
+            build_messages.append(system_build_message)
 
     user_build_message = _build_message_template()
     user_build_message['role'] = 'user'
@@ -181,10 +186,10 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
              run_manager: Optional[CallbackManagerForChainRun] = None,
              generate_with_callback: AnswerResultStream = None) -> None:
 
-        history = inputs[self.history_key]
-        streaming = inputs[self.streaming_key]
+        history = inputs.get(self.history_key, [])
+        streaming = inputs.get(self.streaming_key, False)
         prompt = inputs[self.prompt_key]
-        stop = inputs.get("stop", None)
+        stop = inputs.get("stop", "stop")
         print(f"__call:{prompt}")
 
         try:
@@ -205,16 +210,18 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
             params = {"stream": streaming,
                       "model": self.model_name,
                       "stop": stop}
+            out_str = ""
             for stream_resp in self.completion_with_retry(
                     messages=msg,
                     **params
             ):
                 role = stream_resp["choices"][0]["delta"].get("role", "")
                 token = stream_resp["choices"][0]["delta"].get("content", "")
-                history += [[prompt, token]]
+                out_str += token
+                history[-1] = [prompt, out_str]
                 answer_result = AnswerResult()
                 answer_result.history = history
-                answer_result.llm_output = {"answer": token}
+                answer_result.llm_output = {"answer": out_str}
                 generate_with_callback(answer_result)
 
         else:
@@ -239,10 +246,10 @@ if __name__ == "__main__":
     chain = FastChatOpenAILLMChain()
 
     chain.set_api_key("sk-...")
-    chain.set_api_base_url("https://api.openai.com/v1")
-    chain.call_model_name("gpt-3.5-turbo")
+    # chain.set_api_base_url("https://api.openai.com/v1")
+    # chain.call_model_name("gpt-3.5-turbo")
 
-    answer_result_stream_result = chain({"streaming": False,
+    answer_result_stream_result = chain({"streaming": True,
                                          "prompt": "你好",
                                          "history": []
                                          })
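
Note for reviewers: below is a self-contained sketch of build_message_list()
as it reads after this patch. _build_message_template() is not part of the
diff, so the empty OpenAI-style message dict here is an assumption about its
shape; the RuntimeError message is translated from the Chinese original.

from typing import Dict, List

def _build_message_template() -> Dict[str, str]:
    # Assumed helper shape: an empty OpenAI-style chat message.
    return {"role": "", "content": ""}

def build_message_list(query: str, history: List[List[str]]) -> List[Dict[str, str]]:
    build_messages: List[Dict[str, str]] = []

    # System prompt always comes first.
    system_build_message = _build_message_template()
    system_build_message['role'] = 'system'
    system_build_message['content'] = "You are a helpful assistant."
    build_messages.append(system_build_message)

    # Each history entry is a [user, assistant] pair; the patched loop
    # tolerates a missing user turn but rejects a missing assistant reply.
    if history:
        for user, assistant in history:
            if user:
                user_build_message = _build_message_template()
                user_build_message['role'] = 'user'
                user_build_message['content'] = user
                build_messages.append(user_build_message)
            if not assistant:
                raise RuntimeError("Invalid history data structure")
            system_build_message = _build_message_template()
            system_build_message['role'] = 'assistant'
            system_build_message['content'] = assistant
            build_messages.append(system_build_message)

    # The current query goes last.
    user_build_message = _build_message_template()
    user_build_message['role'] = 'user'
    user_build_message['content'] = query
    build_messages.append(user_build_message)
    return build_messages

# Example:
# build_message_list("你好", [["hi", "hello!"]])
# -> system, user("hi"), assistant("hello!"), user("你好")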
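
The streaming hunk replaces per-chunk appends of [prompt, token] with a
running out_str that rewrites history[-1] in place, so each callback sees the
full answer accumulated so far. A minimal sketch of that pattern follows; the
fake chunk iterator stands in for self.completion_with_retry(), which is not
shown in the diff and is assumed to yield OpenAI-style delta chunks:

from typing import Dict, Iterator, List

def fake_stream() -> Iterator[Dict]:
    # Stand-in for completion_with_retry(): yields OpenAI-style stream chunks.
    for piece in ["你", "好", "！"]:
        yield {"choices": [{"delta": {"role": "assistant", "content": piece}}]}

def run(prompt: str, history: List[List[str]]) -> List[List[str]]:
    # Seed the slot that history[-1] rewrites. The patched loop itself does
    # not seed it, so an empty incoming history is assumed to be filled by the
    # caller before streaming starts; otherwise history[-1] raises IndexError.
    history.append([prompt, ""])
    out_str = ""
    for stream_resp in fake_stream():
        token = stream_resp["choices"][0]["delta"].get("content", "")
        out_str += token                 # accumulate the whole answer
        history[-1] = [prompt, out_str]  # overwrite, don't append per token
    return history

print(run("你好", []))  # -> [['你好', '你好！']]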