diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index 56aea4a..a15be3d 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -41,10 +41,12 @@ class LocalDocQA: llm_model: str = LLM_MODEL, llm_device=LLM_DEVICE, top_k=VECTOR_SEARCH_TOP_K, + use_ptuning_v2: bool = USE_PTUNING_V2 ): self.llm = ChatGLM() self.llm.load_model(model_name_or_path=llm_model_dict[llm_model], - llm_device=llm_device) + llm_device=llm_device, + use_ptuning_v2=use_ptuning_v2) self.llm.history_len = llm_history_len self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], ) diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py index 6314077..c3d1a21 100644 --- a/models/chatglm_llm.py +++ b/models/chatglm_llm.py @@ -127,14 +127,6 @@ class ChatGLM(LLM): device_map = auto_configure_device_map(num_gpus) self.model = dispatch_model(model, device_map=device_map) - self.model = ( - AutoModel.from_pretrained( - model_name_or_path, - config=model_config, - trust_remote_code=True) - .half() - .cuda() - ) else: self.model = ( AutoModel.from_pretrained(