diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 866bedf..2d696a3 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -61,9 +61,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
 
 
 def similarity_search_with_score_by_vector(
-        self,
-        embedding: List[float],
-        k: int = 4,
+        self, embedding: List[float], k: int = 4,
 ) -> List[Tuple[Document, float]]:
     scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
     docs = []
@@ -122,12 +120,12 @@ class LocalDocQA:
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
                  top_k=VECTOR_SEARCH_TOP_K,
-                 use_ptuning_v2: bool = USE_PTUNING_V2
+                 use_ptuning_v2: bool = USE_PTUNING_V2,
+                 use_lora: bool = USE_LORA,
                  ):
         self.llm = ChatGLM()
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
-                            llm_device=llm_device,
-                            use_ptuning_v2=use_ptuning_v2)
+                            llm_device=llm_device, use_ptuning_v2=use_ptuning_v2, use_lora=use_lora)
         self.llm.history_len = llm_history_len
 
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
diff --git a/configs/model_config.py b/configs/model_config.py
index 38a8970..c8a8a4c 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -27,6 +27,11 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# LLM lora path,默认为空,如果有请直接指定文件夹路径
+# 推荐使用 chatglm-6b-belle-zh-lora
+LLM_LORA_PATH = ""
+USE_LORA = True if LLM_LORA_PATH else False
+
 # LLM streaming reponse
 STREAMING = True
 
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index d33b62e..b789fea 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -78,11 +78,11 @@ class ChatGLM(LLM):
                 torch_gc()
         else:
             response, _ = self.model.chat(
-                    self.tokenizer,
-                    prompt,
-                    history=history[-self.history_len:] if self.history_len > 0 else [],
-                    max_length=self.max_token,
-                    temperature=self.temperature,
+                self.tokenizer,
+                prompt,
+                history=history[-self.history_len:] if self.history_len > 0 else [],
+                max_length=self.max_token,
+                temperature=self.temperature,
             )
             torch_gc()
             history += [[prompt, response]]
@@ -106,6 +106,7 @@ class ChatGLM(LLM):
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
                    use_ptuning_v2=False,
+                   use_lora=False,
                    device_map: Optional[Dict[str, int]] = None,
                    **kwargs):
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -125,45 +126,32 @@ class ChatGLM(LLM):
             except Exception as e:
                 print(e)
                 print("加载PrefixEncoder config.json失败")
+        self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True,
+                                               **kwargs)
+        if LLM_LORA_PATH and use_lora:
+            from peft import PeftModel
+            self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
 
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             # 根据当前设备GPU数量决定是否进行多卡部署
             num_gpus = torch.cuda.device_count()
             if num_gpus < 2 and device_map is None:
-                self.model = (
-                    AutoModel.from_pretrained(
-                        model_name_or_path,
-                        config=model_config,
-                        trust_remote_code=True,
-                        **kwargs)
-                    .half()
-                    .cuda()
-                )
+                self.model = self.model.half().cuda()
             else:
                 from accelerate import dispatch_model
 
-                model = (
-                    AutoModel.from_pretrained(
-                        model_name_or_path,
-                        trust_remote_code=True,
-                        config=model_config,
-                        **kwargs)
-                    .half())
+                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
+                                                  config=model_config, **kwargs)
+                if LLM_LORA_PATH and use_lora:
+                    from peft import PeftModel
+                    model = PeftModel.from_pretrained(model, LLM_LORA_PATH)
                 # 可传入device_map自定义每张卡的部署情况
                 if device_map is None:
                     device_map = auto_configure_device_map(num_gpus)
 
-                self.model = dispatch_model(model, device_map=device_map)
+                self.model = dispatch_model(model.half(), device_map=device_map)
         else:
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    config=model_config,
-                    trust_remote_code=True,
-                    **kwargs)
-                .float()
-                .to(llm_device)
-            )
+            self.model = self.model.float().to(llm_device)
 
         if use_ptuning_v2:
             try:
@@ -185,7 +173,7 @@ if __name__ == "__main__":
     llm = ChatGLM()
     llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
                    llm_device=LLM_DEVICE, )
-    last_print_len=0
+    last_print_len = 0
    for resp, history in llm._call("你好", streaming=True):
         print(resp[last_print_len:], end="", flush=True)
         last_print_len = len(resp)
diff --git a/requirements.txt b/requirements.txt
index 470ee69..6bc9c08 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,5 @@ accelerate
 gradio==3.24.1
 fastapi
 uvicorn
+peft
 #detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
diff --git a/webui.py b/webui.py
index 24cc2ff..308e95c 100644
--- a/webui.py
+++ b/webui.py
@@ -72,12 +72,13 @@ def init_model():
         return reply
 
 
-def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
+                              use_lora=use_lora,
                               top_k=top_k,)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
@@ -246,6 +247,9 @@ with gr.Blocks(css=block_css) as demo:
                 use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
                                              label="使用p-tuning-v2微调过的模型",
                                              interactive=True)
+                use_lora = gr.Checkbox(USE_LORA,
+                                       label="使用lora微调的权重",
+                                       interactive=True)
                 embedding_model = gr.Radio(embedding_model_dict_list,
                                            label="Embedding 模型",
                                            value=EMBEDDING_MODEL,
@@ -259,7 +263,7 @@ with gr.Blocks(css=block_css) as demo:
 
                 load_model_button = gr.Button("重新加载模型")
                 load_model_button.click(reinit_model,
                                         show_progress=True,
-                                        inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
+                                        inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, chatbot],
                                         outputs=chatbot
                                         )
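
The heart of this patch is the new peft dependency: the base ChatGLM weights are loaded with AutoModel.from_pretrained, and PeftModel.from_pretrained then wraps them with the LoRA adapter found at LLM_LORA_PATH before the model is moved to its device. Below is a minimal standalone sketch of that single-device flow, not the patch itself; the adapter path ("chatglm-6b-belle-zh-lora", the directory recommended in the model_config.py comment) and the test prompt are illustrative assumptions.

    # Minimal sketch: load ChatGLM-6B and apply a LoRA adapter with peft.
    # Assumes the adapter directory contains adapter_config.json plus adapter
    # weights; the path below is illustrative, not part of this patch.
    import torch
    from transformers import AutoModel, AutoTokenizer
    from peft import PeftModel

    MODEL_NAME = "THUDM/chatglm-6b"
    LORA_PATH = "chatglm-6b-belle-zh-lora"  # hypothetical adapter directory

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)

    # Wrap the base model with the LoRA adapter; skipping this step runs the
    # unmodified base weights, mirroring the use_lora switch in the patch.
    model = PeftModel.from_pretrained(model, LORA_PATH)

    # Move to GPU in fp16 when available, otherwise keep fp32 on CPU,
    # matching the single-GPU / CPU branches in ChatGLM.load_model.
    if torch.cuda.is_available():
        model = model.half().cuda()
    else:
        model = model.float()
    model = model.eval()

    # ChatGLM exposes a chat() helper via trust_remote_code.
    response, history = model.chat(tokenizer, "你好", history=[])
    print(response)

In the multi-GPU branch the patch loads the base model, applies the adapter the same way, and only then hands the wrapped module to accelerate's dispatch_model with an auto-generated device_map, so the adapter weights are placed on the cards together with the base weights.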