支持命令行输入ptuning路径

This commit is contained in:
zg h 2023-07-13 22:10:54 +08:00
parent c5bc21781c
commit 43d1bf4fb3
2 changed files with 4 additions and 2 deletions

View File

@ -203,7 +203,7 @@ STREAMING = True
# Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2 = False
PTUNING_DIR='./ptuing-v2'
# LLM running device
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

View File

@ -1,3 +1,4 @@
import argparse
import os
from configs.model_config import *
@ -43,7 +44,8 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
parser.add_argument('--use-ptuning-v2',type=str,default=False,help="whether use ptuning-v2 checkpoint")
parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")
# Accelerate/transformers
parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
help='Load the model with 8-bit precision.')