From 2a44bd870e087b65fdde4f465fa2e0e1a827cbd9 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Fri, 11 Aug 2023 10:50:21 +0800 Subject: [PATCH] Dev (#1046) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 增加lora检查点加载成功提示 * 修复baichuan_llm引起的bug * update model_config --------- Co-authored-by: hzg0601 --- .gitignore | 1 + configs/model_config.py | 9 +++------ models/loader/loader.py | 1 + 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 73ebd21..90de57e 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,4 @@ flagged/* ptuning-v2/*.json ptuning-v2/*.bin +*.log.* diff --git a/configs/model_config.py b/configs/model_config.py index 2ca08d6..524f1f1 100644 --- a/configs/model_config.py +++ b/configs/model_config.py @@ -130,10 +130,7 @@ llm_model_dict = { "local_model_path": None, "provides": "LLamaLLMChain" }, - # 直接调用返回requests.exceptions.ConnectionError错误,需要通过huggingface_hub包里的snapshot_download函数 - # 下载模型,如果snapshot_download还是返回网络错误,多试几次,一般是可以的, - # 如果仍然不行,则应该是网络加了防火墙(在服务器上这种情况比较常见),基本只能从别的设备上下载, - # 然后转移到目标设备了. + "bloomz-7b1": { "name": "bloomz-7b1", "pretrained_model_name": "bigscience/bloomz-7b1", @@ -250,8 +247,8 @@ USE_LORA = True if LORA_NAME else False # LLM streaming reponse STREAMING = True -# 直接定义baichuan的lora完整路径即可 -LORA_MODEL_PATH_BAICHUAN="" +# 直接定义baichuan的lora完整路径即可,"" != False +LORA_MODEL_PATH_BAICHUAN=None # Use p-tuning-v2 PrefixEncoder USE_PTUNING_V2 = False diff --git a/models/loader/loader.py b/models/loader/loader.py index 6ba46cf..f6c75db 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -457,6 +457,7 @@ class LoaderCheckPoint: self.model = self.model.to(device) else: self.model = self.model.cuda() + print("加载lora检查点成功.") def clear_torch_cache(self): gc.collect()