From 58a5de92a5a7453a11895a32163a8474e5820732 Mon Sep 17 00:00:00 2001 From: hzg0601 Date: Wed, 26 Jul 2023 17:05:37 +0800 Subject: [PATCH] =?UTF-8?q?1.=E6=9B=B4=E6=94=B9=E5=8A=A0=E8=BD=BDlora?= =?UTF-8?q?=E7=9A=84=E6=96=B9=E5=BC=8F;2.=E5=85=81=E8=AE=B8api.py=E8=B0=83?= =?UTF-8?q?=E7=94=A8args.py=E7=9A=84=E5=91=BD=E4=BB=A4=E8=A1=8C;3.=20FastC?= =?UTF-8?q?hat=E8=B7=AF=E5=BE=84=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api.py | 2 +- configs/model_config.py | 11 ++++------- models/loader/args.py | 2 +- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/api.py b/api.py index bc2fe63..452313e 100644 --- a/api.py +++ b/api.py @@ -583,7 +583,7 @@ if __name__ == "__main__": parser.add_argument("--ssl_keyfile", type=str) parser.add_argument("--ssl_certfile", type=str) # 初始化消息 - args = None + args = parser.parse_args() args_dict = vars(args) shared.loaderCheckPoint = LoaderCheckPoint(args_dict) diff --git a/configs/model_config.py b/configs/model_config.py index 846da98..91ed20a 100644 --- a/configs/model_config.py +++ b/configs/model_config.py @@ -32,7 +32,7 @@ EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backe # llm_model_dict 处理了loader的一些预设行为,如加载位置,模型名称,模型处理器实例 # 在以下字典中修改属性值,以指定本地 LLM 模型存储位置 # 如将 "chatglm-6b" 的 "local_model_path" 由 None 修改为 "User/Downloads/chatglm-6b" -# 此处请写绝对路径 +# 此处请写绝对路径,且路径中必须包含模型名称,FastChat会根据路径名称提取repo-id llm_model_dict = { "chatglm-6b-int4-qe": { "name": "chatglm-6b-int4-qe", @@ -104,17 +104,14 @@ LOAD_IN_8BIT = False # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. BF16 = False # 本地lora存放的位置 -LORA_DIR = "loras/" - -# LLM lora path,默认为空,如果有请直接指定文件夹路径 -LLM_LORA_PATH = "" -USE_LORA = True if LLM_LORA_PATH else False +LORA_DIR = "./loras/" # LLM streaming reponse STREAMING = True # Use p-tuning-v2 PrefixEncoder -USE_PTUNING_V2 = False + +PTUNING_DIR="./ptuning-v2" # LLM running device LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" diff --git a/models/loader/args.py b/models/loader/args.py index cd3e78b..9e2e8b2 100644 --- a/models/loader/args.py +++ b/models/loader/args.py @@ -42,7 +42,7 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th 'model to add the ` ' '--no-remote-model`') parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.') -parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') +parser.add_argument('--lora', type=str, action="store_true",help='Name of the LoRA to apply to the model by default.') parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras") parser.add_argument('--use-ptuning-v2',action='store_true',help="whether use ptuning-v2 checkpoint") parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")