Remove model_dir and NO_REMOTE_MODEL
This commit is contained in:
parent f1f742ce44
commit 218aca2e20
File diff suppressed because it is too large
@@ -91,14 +91,10 @@ llm_model_dict = {
 
 # LLM name
 LLM_MODEL = "chatglm-6b"
-# To load a local model, pass the ` --no-remote-model` flag, or change the parameter below to `True`
-NO_REMOTE_MODEL = False
 # Load the model quantized to 8-bit
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
 BF16 = False
-# Directory where local models are stored
-MODEL_DIR = "model/"
 # Directory where local LoRAs are stored
 LORA_DIR = "loras/"
 
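Note: with NO_REMOTE_MODEL and MODEL_DIR removed, a local checkpoint is now chosen per model rather than through one global model directory. A minimal sketch of the per-model configuration this points toward, assuming an llm_model_dict entry carries a local path field (the key names below are illustrative, not taken from this diff):

    llm_model_dict = {
        "chatglm-6b": {
            "name": "chatglm-6b",
            "pretrained_model_name": "THUDM/chatglm-6b",
            # Hypothetical field: set to a local directory to load without the hub
            "local_model_path": None,
        },
    }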
@@ -35,14 +35,13 @@ parser = argparse.ArgumentParser(prog='langchina-ChatGLM',
                                  description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain | '
                                              'ChatGLM Q&A based on a local knowledge base')
 
-parser.add_argument('--no-remote-model', action='store_true', default=NO_REMOTE_MODEL, help='remote in the model on '
+parser.add_argument('--no-remote-model', action='store_true', help='remote in the model on '
                                                                    'loader checkpoint, '
                                                                    'if your load local '
                                                                    'model to add the ` '
                                                                    '--no-remote-model`')
-parser.add_argument('--model', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
+parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
 parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
-parser.add_argument("--model-dir", type=str, default=MODEL_DIR, help="Path to directory with all the models")
 parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
 
 # Accelerate/transformers
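Note: `--model` is renamed to `--model-name`, `--model-dir` is dropped, and `--no-remote-model` no longer takes its default from config. A hedged usage example (the script name cli_demo.py is an assumption, not shown in this diff):

    # Load by hub name; remote download allowed
    python cli_demo.py --model-name chatglm-6b --lora-dir loras/

    # Force local loading; without a configured local model path this now raises ValueError
    python cli_demo.py --model-name chatglm-6b --no-remote-model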
@@ -26,7 +26,6 @@ class LoaderCheckPoint:
     model: object = None
     model_config: object = None
     lora_names: set = []
-    model_dir: str = None
     lora_dir: str = None
     ptuning_dir: str = None
     use_ptuning_v2: bool = False
@@ -45,28 +44,30 @@ class LoaderCheckPoint:
         Model initialization
         :param params:
         """
-        self.model_path = None
         self.model = None
         self.tokenizer = None
         self.params = params or {}
+        self.model_name = params.get('model_name', False)
+        self.model_path = params.get('model_path', None)
         self.no_remote_model = params.get('no_remote_model', False)
-        self.model_name = params.get('model', '')
         self.lora = params.get('lora', '')
         self.use_ptuning_v2 = params.get('use_ptuning_v2', False)
-        self.model_dir = params.get('model_dir', '')
         self.lora_dir = params.get('lora_dir', '')
         self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2')
         self.load_in_8bit = params.get('load_in_8bit', False)
         self.bf16 = params.get('bf16', False)
 
     def _load_model_config(self, model_name):
-        checkpoint = Path(f'{self.model_dir}/{model_name}')
 
         if self.model_path:
             checkpoint = Path(f'{self.model_path}')
         else:
             if not self.no_remote_model:
                 checkpoint = model_name
+            else:
+                raise ValueError(
+                    "Path for the local model (local_model_path) is not configured"
+                )
 
         model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
 
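Note: the constructor now reads model_name and model_path straight from params instead of deriving a path from model and model_dir. A minimal sketch of a matching params dict (values illustrative):

    params = {
        'model_name': 'chatglm-6b',
        'model_path': None,   # set to a local checkpoint directory to bypass the hub
        'no_remote_model': False,
        'lora': '',
        'lora_dir': 'loras/',
        'use_ptuning_v2': False,
        'load_in_8bit': False,
        'bf16': False,
    }
    loader = LoaderCheckPoint(params)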
@@ -81,16 +82,17 @@ class LoaderCheckPoint:
         print(f"Loading {model_name}...")
         t0 = time.time()
 
-        checkpoint = Path(f'{self.model_dir}/{model_name}')
-
-        self.is_llamacpp = len(list(checkpoint.glob('ggml*.bin'))) > 0
-
         if self.model_path:
             checkpoint = Path(f'{self.model_path}')
         else:
             if not self.no_remote_model:
                 checkpoint = model_name
+            else:
+                raise ValueError(
+                    "Path for the local model (local_model_path) is not configured"
+                )
 
+        self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
         if 'chatglm' in model_name.lower():
             LoaderClass = AutoModel
         else:
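Note: the same three-way checkpoint resolution now repeats in _load_model_config, the loader above, and the MOSS branch below. A standalone sketch of that logic (the helper name resolve_checkpoint is ours, not part of the commit):

    from pathlib import Path

    def resolve_checkpoint(model_name, model_path=None, no_remote_model=False):
        # An explicitly configured local path always wins.
        if model_path:
            return Path(f'{model_path}')
        # Otherwise fall back to the remote hub identifier, unless remote loading is disabled.
        if not no_remote_model:
            return model_name
        # Local loading was requested but no local path is configured.
        raise ValueError("Path for the local model (local_model_path) is not configured")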
@@ -274,13 +276,16 @@ class LoaderCheckPoint:
                     "`pip install bitsandbytes``pip install accelerate`."
                 ) from exc
 
-            checkpoint = Path(f'{self.model_dir}/{model_name}')
 
             if self.model_path:
                 checkpoint = Path(f'{self.model_path}')
             else:
                 if not self.no_remote_model:
                     checkpoint = model_name
+                else:
+                    raise ValueError(
+                        "Path for the local model (local_model_path) is not configured"
+                    )
 
             cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
                                                 pretrained_model_name_or_path=checkpoint)