diff --git a/chatglm_llm.py b/chatglm_llm.py
index 8c00685..b7251bb 100644
--- a/chatglm_llm.py
+++ b/chatglm_llm.py
@@ -3,7 +3,17 @@ from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel
 
-"""ChatGLM_G is a wrapper around the ChatGLM model to fit LangChain framework. May not be an optimal implementation"""
+tokenizer = AutoTokenizer.from_pretrained(
+    "THUDM/chatglm-6b",
+    trust_remote_code=True
+)
+model = (
+    AutoModel.from_pretrained(
+        "THUDM/chatglm-6b",
+        trust_remote_code=True)
+    .half()
+    .cuda()
+)
 
 
 class ChatGLM(LLM):
@@ -11,18 +21,6 @@ class ChatGLM(LLM):
     temperature: float = 0.1
     top_p = 0.9
     history = []
-    tokenizer = AutoTokenizer.from_pretrained(
-        "THUDM/chatglm-6b",
-        trust_remote_code=True
-    )
-    model = (
-        AutoModel.from_pretrained(
-            "THUDM/chatglm-6b",
-            trust_remote_code=True)
-        .half()
-        .cuda()
-    )
-
     def __init__(self):
         super().__init__()
 
@@ -34,13 +32,12 @@ class ChatGLM(LLM):
     def _call(self,
               prompt: str,
               stop: Optional[List[str]] = None) -> str:
-        response, updated_history = self.model.chat(
-            self.tokenizer,
+        response, updated_history = model.chat(
+            tokenizer,
             prompt,
             history=self.history,
             max_length=self.max_token,
             temperature=self.temperature,
-
         )
         print("history: ", self.history)
         if stop is not None:
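
Note: a minimal usage sketch, not part of the diff. It assumes the patched chatglm_llm.py is importable, a CUDA device is available, and that _call writes updated_history back to self.history further down (the diff is truncated before that point). With this change the THUDM/chatglm-6b weights are loaded once at import time as module globals and shared by every ChatGLM instance, presumably to keep the heavyweight tokenizer/model objects out of the pydantic-validated class body of LangChain's LLM.

    from chatglm_llm import ChatGLM

    llm = ChatGLM()
    # LangChain's LLM base class makes instances callable; __call__ forwards
    # to _call, which now uses the module-level model and tokenizer.
    print(llm("What is ChatGLM?"))
    # Assuming _call stores updated_history in self.history, this second turn
    # is answered with the first exchange as conversational context.
    print(llm("Summarize your last answer in one sentence."))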