Dev (#1046)
* Add a success message when the LoRA checkpoint is loaded
* Fix a bug caused by baichuan_llm
* update model_config
---------
Co-authored-by: hzg0601 <hzg0601@163.com>
parent 0d1d9c5ed7
commit 2a44bd870e
@@ -178,3 +178,4 @@ flagged/*
 ptuning-v2/*.json
 ptuning-v2/*.bin
 
+*.log.*
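For reference, the new *.log.* pattern ignores rotated log files such as api.log.1 or api.log.2023-08-01 (filenames illustrative), which a plain *.log rule would not match.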
@@ -130,10 +130,7 @@ llm_model_dict = {
         "local_model_path": None,
         "provides": "LLamaLLMChain"
     },
-    # Calling this directly returns a requests.exceptions.ConnectionError; the model has to be downloaded
-    # via the snapshot_download function from the huggingface_hub package. If snapshot_download also returns a network error, retry a few times; it usually works.
-    # If it still fails, the network is probably firewalled (common on servers); about the only option is to download on another device
-    # and then transfer the files to the target machine.
     "bloomz-7b1": {
         "name": "bloomz-7b1",
         "pretrained_model_name": "bigscience/bloomz-7b1",
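The removed comments described a download workaround worth keeping in readable form. A minimal sketch of what they described, using huggingface_hub's snapshot_download; the retry count and repo id are illustrative assumptions, not part of the commit:

from huggingface_hub import snapshot_download

# Retry a few times: transient network errors are common, as the old comment noted.
local_path = None
for attempt in range(3):
    try:
        local_path = snapshot_download(repo_id="bigscience/bloomz-7b1")
        break
    except Exception as err:
        print(f"download attempt {attempt + 1} failed: {err}")
# If every attempt fails (e.g. behind a firewall), download on another machine
# and copy the snapshot folder over to the target device.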
@@ -250,8 +247,8 @@ USE_LORA = True if LORA_NAME else False
 # LLM streaming response
 STREAMING = True
 
-# Just set the full path of the baichuan LoRA here
-LORA_MODEL_PATH_BAICHUAN=""
+# Just set the full path of the baichuan LoRA here; "" != False
+LORA_MODEL_PATH_BAICHUAN=None
 
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
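This hunk is the baichuan_llm bug fix from the commit message: an empty string is falsy in Python yet not equal to False, so a `!= False` guard and a plain truthiness guard disagree about whether a LoRA path is configured. A standalone sketch of the pitfall (illustrative only):

path = ""
print(path != False)     # True  -> a `!= False` guard would try to load a LoRA from ""
print(bool(path))        # False -> a truthiness guard would skip it
path = None
print(path is not None)  # False -> None is an unambiguous "nothing configured" sentinel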
@@ -457,6 +457,7 @@ class LoaderCheckPoint:
                 self.model = self.model.to(device)
             else:
                 self.model = self.model.cuda()
+        print("LoRA checkpoint loaded successfully.")
 
     def clear_torch_cache(self):
         gc.collect()
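For context, the added print fires once the LoRA weights are attached and the model has been moved onto a device. A minimal sketch of such a load step, assuming the peft library is used for LoRA; the helper name and device handling are illustrative, not the project's actual code:

import torch
from peft import PeftModel

def load_lora(base_model, lora_path: str):
    # Attach LoRA weights to an already-loaded base model.
    model = PeftModel.from_pretrained(base_model, lora_path)
    if torch.cuda.is_available():
        model = model.cuda()
    else:
        model = model.to(torch.device("cpu"))
    print("LoRA checkpoint loaded successfully.")  # the message this commit adds
    return model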