diff --git a/configs/model_config.py b/configs/model_config.py
index ea24c36..4a51c5c 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -203,7 +203,7 @@ llm_model_dict = {
 }
 
 # LLM name
-LLM_MODEL = "fastchat-chatglm-6b-int4"
+LLM_MODEL = "fastchat-chatglm"
 # Load the model quantized to 8-bit
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
@@ -220,7 +220,7 @@ STREAMING = True
 
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
-PTUNING_DIR='./ptuing-v2'
+PTUNING_DIR='./ptuning-v2'
 
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 81878ce..0d19ee6 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -2,14 +2,14 @@ from abc import ABC
 from langchain.chains.base import Chain
 from typing import Any, Dict, List, Optional, Generator
 from langchain.callbacks.manager import CallbackManagerForChainRun
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+# from transformers.generation.logits_process import LogitsProcessor
+# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
-import torch
+# import torch
 import transformers
diff --git a/models/loader/args.py b/models/loader/args.py
index e7df76b..cd3e78b 100644
--- a/models/loader/args.py
+++ b/models/loader/args.py
@@ -44,7 +44,7 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
 parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
 parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
-parser.add_argument('--use-ptuning-v2',type=str,default=USE_PTUNING_V2,help="whether use ptuning-v2 checkpoint")
+parser.add_argument('--use-ptuning-v2',action='store_true',help="whether use ptuning-v2 checkpoint")
 parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")
 
 # Accelerate/transformers
 parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
diff --git a/models/loader/loader.py b/models/loader/loader.py
index 986f5c4..f43bb1d 100644
--- a/models/loader/loader.py
+++ b/models/loader/loader.py
@@ -441,7 +441,7 @@ class LoaderCheckPoint:
 
         if self.use_ptuning_v2:
             try:
-                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
+                prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
                 prefix_encoder_config = json.loads(prefix_encoder_file.read())
                 prefix_encoder_file.close()
                 self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
@@ -457,13 +457,14 @@ class LoaderCheckPoint:
 
         if self.use_ptuning_v2:
             try:
-                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
+                prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
                 new_prefix_state_dict = {}
                 for k, v in prefix_state_dict.items():
                     if k.startswith("transformer.prefix_encoder."):
                         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
+                print("Loaded p-tuning checkpoint successfully!")
             except Exception as e:
                 print(e)
                 print("Failed to load PrefixEncoder model parameters")
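
A note on the `--use-ptuning-v2` change in models/loader/args.py: with `type=str, default=USE_PTUNING_V2`, any value given on the command line arrives as a string, and every non-empty string, including "False", is truthy in Python, so the option could never be reliably switched off. `action='store_true'` makes it a genuine boolean flag. A minimal stand-alone sketch of the difference (the parser and option names below are illustrative only, not the repo's):

import argparse

# Toy parser reproducing the old and new option styles side by side.
parser = argparse.ArgumentParser()
parser.add_argument('--old-flag', type=str, default=False)  # old, buggy pattern
parser.add_argument('--new-flag', action='store_true')      # new boolean switch

args = parser.parse_args(['--old-flag', 'False'])
print(bool(args.old_flag))  # True: "False" is a non-empty string
print(args.new_flag)        # False: the flag was simply not given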
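
On the `os.path.abspath` changes in models/loader/loader.py: a relative `PTUNING_DIR` such as './ptuning-v2' already resolves against the current working directory, so the lookup location itself does not change; the gain is that any exception raised in the `try` block now reports the full, unambiguous path. A quick illustration (the printed absolute path depends on where the process runs):

import os

# Relative and absolute forms point at the same directory, but the
# absolute form is self-describing in error messages and logs.
print('./ptuning-v2')                   # ./ptuning-v2
print(os.path.abspath('./ptuning-v2'))  # e.g. /home/user/langchain-ChatGLM/ptuning-v2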
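
For context on the second loader.py hunk: the loop strips the `transformer.prefix_encoder.` prefix from each checkpoint key so that `load_state_dict` on the prefix-encoder sub-module receives keys relative to itself, and discards keys belonging to other sub-modules. A toy sketch of the same transformation, using hypothetical key names and string placeholders instead of tensors:

# Keys as they appear in a p-tuning-v2 checkpoint (values would be tensors).
prefix = "transformer.prefix_encoder."
state_dict = {
    "transformer.prefix_encoder.embedding.weight": "tensor-A",  # kept, key shortened
    "transformer.word_embeddings.weight": "tensor-B",           # dropped: other sub-module
}

new_state_dict = {k[len(prefix):]: v
                  for k, v in state_dict.items() if k.startswith(prefix)}
print(new_state_dict)  # {'embedding.weight': 'tensor-A'}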