Move tokenizer/model loading from module level of chatglm_llm.py into
ChatGLM.load_model(), and select the checkpoint through llm_model_dict.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Context-line whitespace was reconstructed
from the hunk line counts and should be checked against the target files.
The index hashes are those of the original patch.

diff --git a/chatglm_llm.py b/chatglm_llm.py
index 776a8e1..84f130a 100644
--- a/chatglm_llm.py
+++ b/chatglm_llm.py
@@ -16,20 +16,15 @@ def torch_gc():
             torch.cuda.ipc_collect()
 
 
-tokenizer = AutoTokenizer.from_pretrained(
-    "THUDM/chatglm-6b",
-    trust_remote_code=True
-)
-model = (
-    AutoModel.from_pretrained(
-        "THUDM/chatglm-6b",
-        trust_remote_code=True)
-    .half()
-    .cuda()
-)
-
-
 class ChatGLM(LLM):
+    # These are declared as class-level fields with defaults so that
+    # ChatGLM() can be constructed without arguments and load_model() can
+    # assign them afterwards. This assumes LLM is langchain's pydantic-based
+    # base class (its import is outside this patch), which rejects undeclared
+    # attributes and requires a value for annotated fields without a default.
+    model_name: str = "THUDM/chatglm-6b"
+    tokenizer: object = None
+    model: object = None
     max_token: int = 10000
     temperature: float = 0.1
     top_p = 0.9
@@ -38,6 +33,32 @@ class ChatGLM(LLM):
     def __init__(self):
         super().__init__()
 
+    def load_model(self,
+                   model_name_or_path: str = "THUDM/chatglm-6b"):
+        """Load the tokenizer and model and keep them on this instance.
+
+        Both come from ``from_pretrained(model_name_or_path,
+        trust_remote_code=True)``; ``.half()`` and ``.cuda()`` are then
+        applied to the model. The results are stored in ``self.tokenizer``
+        and ``self.model``, which ``_call`` reads.
+
+        Args:
+            model_name_or_path: Passed unchanged to both ``from_pretrained``
+                calls and recorded in ``self.model_name``.
+        """
+        self.model_name = model_name_or_path
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name_or_path,
+            trust_remote_code=True
+        )
+        self.model = (
+            AutoModel.from_pretrained(
+                model_name_or_path,
+                trust_remote_code=True)
+            .half()
+            .cuda()
+        )
+
     @property
     def _llm_type(self) -> str:
         return "ChatGLM"
@@ -45,8 +66,8 @@ class ChatGLM(LLM):
     def _call(self,
               prompt: str,
               stop: Optional[List[str]] = None) -> str:
-        response, updated_history = model.chat(
-            tokenizer,
+        response, updated_history = self.model.chat(
+            self.tokenizer,
             prompt,
             history=self.history,
             max_length=self.max_token,
diff --git a/knowledge_based_chatglm.py b/knowledge_based_chatglm.py
index d1f49c7..07fbac1 100644
--- a/knowledge_based_chatglm.py
+++ b/knowledge_based_chatglm.py
@@ -15,8 +15,14 @@ embedding_model_dict = {
     "ernie-base": "nghuyong/ernie-3.0-base-zh",
     "text2vec": "GanymedeNil/text2vec-large-chinese"
 }
-chatglm = ChatGLM()
 
+llm_model_dict = {
+    "chatglm-6b": "THUDM/chatglm-6b",
+    "chatglm-6b-int4": "THUDM/chatglm-6b-int4"
+}
+
+chatglm = ChatGLM()
+chatglm.load_model(model_name_or_path=llm_model_dict["chatglm-6b"])
 
 def init_knowledge_vector_store(filepath):
     embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict["text2vec"], )