* 增加lora检查点加载成功提示

* 修复baichuan_llm引起的bug

* update model_config

---------

Co-authored-by: hzg0601 <hzg0601@163.com>
This commit is contained in:
imClumsyPanda 2023-08-11 10:50:21 +08:00 committed by GitHub
parent 0d1d9c5ed7
commit 2a44bd870e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 6 deletions

1
.gitignore vendored
View File

@@ -178,3 +178,4 @@ flagged/*
ptuning-v2/*.json
ptuning-v2/*.bin
*.log.*

View File

@@ -130,10 +130,7 @@ llm_model_dict = {
"local_model_path": None,
"provides": "LLamaLLMChain"
},
# 直接调用返回requests.exceptions.ConnectionError错误需要通过huggingface_hub包里的snapshot_download函数
# 下载模型如果snapshot_download还是返回网络错误多试几次一般是可以的
# 如果仍然不行,则应该是网络加了防火墙(在服务器上这种情况比较常见),基本只能从别的设备上下载,
# 然后转移到目标设备了.
"bloomz-7b1": {
"name": "bloomz-7b1",
"pretrained_model_name": "bigscience/bloomz-7b1",
@@ -250,8 +247,8 @@ USE_LORA = True if LORA_NAME else False
# LLM streaming response
STREAMING = True
# 直接定义baichuan的lora完整路径即可
LORA_MODEL_PATH_BAICHUAN=""
# 直接定义baichuan的lora完整路径即可,"" != False
LORA_MODEL_PATH_BAICHUAN=None
# Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2 = False

View File

@@ -457,6 +457,7 @@ class LoaderCheckPoint:
self.model = self.model.to(device)
else:
self.model = self.model.cuda()
print("加载lora检查点成功.")
def clear_torch_cache(self):
gc.collect()