From 54c983f4bcbb351b58e5f63051f2b5bb53e6a510 Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Sat, 22 Apr 2023 12:20:08 +0800
Subject: [PATCH] update chatglm_llm.py

---
 models/chatglm_llm.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 20e98a5..c951b78 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -72,29 +72,29 @@ class ChatGLM(LLM):
         response, _ = self.model.chat(
             self.tokenizer,
             prompt,
-            history=self.history[-self.history_len:] if self.history_len>0 else [],
+            history=self.history[-self.history_len:] if self.history_len > 0 else [],
             max_length=self.max_token,
             temperature=self.temperature,
         )
         torch_gc()
         if stop is not None:
             response = enforce_stop_tokens(response, stop)
-        self.history = self.history+[[None, response]]
+        self.history = self.history + [[None, response]]
         return response
 
     def chat(self,
-              prompt: str) -> str:
+             prompt: str) -> str:
         response, _ = self.model.chat(
             self.tokenizer,
             prompt,
-            history=[],#self.history[-self.history_len:] if self.history_len>0 else
+            history=self.history[-self.history_len:] if self.history_len > 0 else [],
             max_length=self.max_token,
             temperature=self.temperature,
         )
         torch_gc()
-        self.history = self.history+[[None, response]]
+        self.history = self.history + [[None, response]]
         return response
-    
+
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
@@ -126,7 +126,7 @@ class ChatGLM(LLM):
                 AutoModel.from_pretrained(
                     model_name_or_path,
                     config=model_config,
-                    trust_remote_code=True, 
+                    trust_remote_code=True,
                     **kwargs)
                 .half()
                 .cuda()
@@ -159,7 +159,8 @@ class ChatGLM(LLM):
                     new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
-            except Exception:
+            except Exception as e:
+                print(e)
                 print("加载PrefixEncoder模型参数失败")
         self.model = self.model.eval()
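
Note: the behavioral change in this patch is that chat() now passes the model the last history_len turns instead of an empty history, and keeps appending [None, response] pairs, matching what _call already did. Below is a minimal, self-contained sketch of that windowing behavior; HistoryDemo and fake_model_chat are made-up stand-ins for the real ChatGLM class and self.model.chat, kept only to show the list slicing from the diff.

```python
# Sketch of the history-window behaviour enabled by this patch.
# `history_len`, `history`, and the [None, response] pairs follow the diff;
# `fake_model_chat` is an illustrative stand-in for self.model.chat.

def fake_model_chat(prompt, history):
    # Pretend the model answers by reporting how much context it was given.
    return f"answer to {prompt!r} (saw {len(history)} past turns)", history

class HistoryDemo:
    def __init__(self, history_len=3):
        self.history_len = history_len
        self.history = []

    def chat(self, prompt):
        # Pass only the last `history_len` turns, as the patched chat() now does.
        window = self.history[-self.history_len:] if self.history_len > 0 else []
        response, _ = fake_model_chat(prompt, history=window)
        # Record the exchange the same way the diff does: a [None, response] pair.
        self.history = self.history + [[None, response]]
        return response

demo = HistoryDemo(history_len=2)
for q in ["q1", "q2", "q3", "q4"]:
    print(demo.chat(q))  # the model never sees more than 2 past turns
```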
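
For reference, the loading path touched by the second hunk follows the usual transformers pattern for ChatGLM-6B. A hedged standalone sketch of the equivalent call is below; it assumes a single CUDA device and skips the model_config, multi-GPU handling, and CPU fallbacks that the real load_model() deals with.

```python
from transformers import AutoModel, AutoTokenizer

# Simplified sketch of the load path shown in the second hunk's context,
# assuming a CUDA device is available.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = (
    AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    .half()
    .cuda()
    .eval()
)

# ChatGLM-6B exposes a chat() helper that returns the reply and updated history.
response, history = model.chat(tokenizer, "Hello", history=[])
print(response)
```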
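
The last hunk only widens error reporting: `except Exception as e` plus `print(e)` surfaces the underlying error before the existing message 加载PrefixEncoder模型参数失败 ("failed to load PrefixEncoder model parameters"). A small sketch of that pattern follows; load_prefix_state_dict and the checkpoint path are hypothetical stand-ins for the p-tuning state-dict handling in load_model().

```python
import torch

# Hypothetical stand-in for the prefix-encoder loading done in load_model();
# only the except-clause pattern introduced by the patch is the point here.
def load_prefix_state_dict(checkpoint_path):
    prefix_state = torch.load(checkpoint_path, map_location="cpu")
    # Keep only the prefix-encoder keys and strip their prefix, mirroring the
    # new_prefix_state_dict loop shown in the diff context.
    return {
        k[len("transformer.prefix_encoder."):]: v
        for k, v in prefix_state.items()
        if k.startswith("transformer.prefix_encoder.")
    }

try:
    state_dict = load_prefix_state_dict("ptuning_checkpoint/pytorch_model.bin")
except Exception as e:
    # The patch's change: print the underlying error instead of hiding it.
    print(e)
    print("Failed to load PrefixEncoder model parameters")
```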