diff --git a/chatglm_llm.py b/chatglm_llm.py
index 4b974b5..a2e8f6f 100644
--- a/chatglm_llm.py
+++ b/chatglm_llm.py
@@ -17,7 +17,6 @@ def torch_gc():
 
 
 class ChatGLM(LLM):
-    model_name: str
     max_token: int = 10000
     temperature: float = 0.1
     top_p = 0.9
@@ -28,20 +27,6 @@ class ChatGLM(LLM):
     def __init__(self):
         super().__init__()
 
-    def load_model(self,
-                   model_name_or_path: str = "THUDM/chatglm-6b"):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_name_or_path,
-            trust_remote_code=True
-        )
-        self.model = (
-            AutoModel.from_pretrained(
-                model_name_or_path,
-                trust_remote_code=True)
-            .half()
-            .cuda()
-        )
-
     @property
     def _llm_type(self) -> str:
         return "ChatGLM"
@@ -62,3 +47,17 @@ class ChatGLM(LLM):
         response = enforce_stop_tokens(response, stop)
         self.history = updated_history
         return response
+
+    def load_model(self,
+                   model_name_or_path: str = "THUDM/chatglm-6b"):
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name_or_path,
+            trust_remote_code=True
+        )
+        self.model = (
+            AutoModel.from_pretrained(
+                model_name_or_path,
+                trust_remote_code=True)
+            .half()
+            .cuda()
+        )