From bed03a6ff11a2fd1c9981540c879bd1df01e4b85 Mon Sep 17 00:00:00 2001
From: myml
Date: Wed, 5 Apr 2023 01:05:06 +0800
Subject: [PATCH] fix: chatglm model gets copied, causing excessive GPU memory usage
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

As a class member, the model is copied every time the class is
instantiated, so GPU memory usage doubled with every question asked.
Fix this by turning the model into a module-level global variable.
---
 chatglm_llm.py | 18 ++++--------------
 1 file changed, 4 insertions(+), 14 deletions(-)

diff --git a/chatglm_llm.py b/chatglm_llm.py
index 8c00685..4a8b36a 100644
--- a/chatglm_llm.py
+++ b/chatglm_llm.py
@@ -5,24 +5,14 @@ from transformers import AutoTokenizer, AutoModel
 
 """ChatGLM_G is a wrapper around the ChatGLM model to fit LangChain framework. May not be an optimal implementation"""
 
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
 
 class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.1
     top_p = 0.9
     history = []
-    tokenizer = AutoTokenizer.from_pretrained(
-        "THUDM/chatglm-6b",
-        trust_remote_code=True
-    )
-    model = (
-        AutoModel.from_pretrained(
-            "THUDM/chatglm-6b",
-            trust_remote_code=True)
-        .half()
-        .cuda()
-    )
-
     def __init__(self):
         super().__init__()
 
@@ -34,8 +24,8 @@ class ChatGLM(LLM):
     def _call(self,
               prompt: str,
               stop: Optional[List[str]] = None) -> str:
-        response, updated_history = self.model.chat(
-            self.tokenizer,
+        response, updated_history = model.chat(
+            tokenizer,
             prompt,
             history=self.history,
             max_length=self.max_token,
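
Why the class-attribute version duplicates the model: in the LangChain
versions this code targets, the LLM base class is a pydantic model, and
pydantic deep-copies mutable class-level defaults each time the class is
instantiated. Below is a minimal sketch of that behavior; the Holder and
weights names are made up for illustration, with a small list standing in
for the multi-gigabyte chatglm weights:

    from pydantic import BaseModel

    # Class-level default on a pydantic model: every instantiation
    # receives its own deep copy of the default value.
    class Holder(BaseModel):
        weights: list = [0.0] * 4  # stand-in for the model weights

    a = Holder()
    b = Holder()
    print(a.weights == b.weights)  # True  -- equal contents
    print(a.weights is b.weights)  # False -- two separate objects in memory

    # A module-level global is built once and shared by all instances,
    # mirroring what this patch does with `tokenizer` and `model`.
    weights = [0.0] * 4

    class SharedHolder(BaseModel):
        def get(self) -> list:
            return weights  # every instance reads the same object

    print(SharedHolder().get() is SharedHolder().get())  # True -- one shared object

With `model` and `tokenizer` at module scope they are created once at
import time, so repeated questions no longer grow GPU memory usage.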