diff --git a/chatglm_llm.py b/chatglm_llm.py
index 8c00685..4a8b36a 100644
--- a/chatglm_llm.py
+++ b/chatglm_llm.py
@@ -5,24 +5,14 @@ from transformers import AutoTokenizer, AutoModel
 
 """ChatGLM_G is a wrapper around the ChatGLM model to fit LangChain framework. May not be an optimal implementation"""
 
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
 
 class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.1
     top_p = 0.9
     history = []
-    tokenizer = AutoTokenizer.from_pretrained(
-        "THUDM/chatglm-6b",
-        trust_remote_code=True
-    )
-    model = (
-        AutoModel.from_pretrained(
-            "THUDM/chatglm-6b",
-            trust_remote_code=True)
-        .half()
-        .cuda()
-    )
-
     def __init__(self):
         super().__init__()
 
@@ -34,8 +24,8 @@ class ChatGLM(LLM):
     def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None) -> str:
-        response, updated_history = self.model.chat(
-            self.tokenizer,
+        response, updated_history = model.chat(
+            tokenizer,
             prompt,
             history=self.history,
             max_length=self.max_token,
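
Note: the refactor moves the tokenizer/model loading from class attributes to module level, so the ChatGLM-6B weights load once at import time and are no longer attributes on the Pydantic-validated LLM subclass. A minimal usage sketch of the refactored wrapper follows; it is not part of this diff, and it assumes that LangChain's LLM base class makes instances callable and that the rest of chatglm_llm.py defines the _llm_type property that base class requires:

    # A minimal sketch, assuming a CUDA-capable GPU, since the module loads
    # the model with .half().cuda() at import time.
    from chatglm_llm import ChatGLM

    llm = ChatGLM()

    # LangChain LLM instances are callable; __call__ dispatches to _call(),
    # which passes self.history into the shared module-level model.chat().
    print(llm("Briefly introduce ChatGLM-6B."))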