From 43d1bf4fb3474752f7e6885aae65702b024bb5ba Mon Sep 17 00:00:00 2001 From: zg h Date: Thu, 13 Jul 2023 22:10:54 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=91=BD=E4=BB=A4=E8=A1=8C?= =?UTF-8?q?=E8=BE=93=E5=85=A5ptuning=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/model_config.py | 2 +- models/loader/args.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/configs/model_config.py b/configs/model_config.py index 2f52c9c..c53e4df 100644 --- a/configs/model_config.py +++ b/configs/model_config.py @@ -203,7 +203,7 @@ STREAMING = True # Use p-tuning-v2 PrefixEncoder USE_PTUNING_V2 = False - +PTUNING_DIR='./ptuing-v2' # LLM running device LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" diff --git a/models/loader/args.py b/models/loader/args.py index b15ad5e..59668b0 100644 --- a/models/loader/args.py +++ b/models/loader/args.py @@ -1,3 +1,4 @@ + import argparse import os from configs.model_config import * @@ -43,7 +44,8 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras") - +parser.add_argument('--use-ptuning-v2',type=str,default=False,help="whether use ptuning-v2 checkpoint") +parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint") # Accelerate/transformers parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT, help='Load the model with 8-bit precision.')