chatglm init twice (#313)
parent 2987c9cd52
commit 31655339f0
@@ -141,16 +141,16 @@ class ChatGLM(LLM):
             else:
                 from accelerate import dispatch_model

-                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
-                                                  config=model_config, **kwargs)
+                # model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
+                #                                   config=model_config, **kwargs)
                 if LLM_LORA_PATH and use_lora:
                     from peft import PeftModel
-                    model = PeftModel.from_pretrained(model, LLM_LORA_PATH)
+                    model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
                 # 可传入device_map自定义每张卡的部署情况
                 if device_map is None:
                     device_map = auto_configure_device_map(num_gpus)

-                self.model = dispatch_model(model.half(), device_map=device_map)
+                self.model = dispatch_model(self.model.half(), device_map=device_map)
         else:
             self.model = self.model.float().to(llm_device)
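For context, the sketch below shows the multi-GPU load path this commit settles on: the checkpoint is loaded into self.model exactly once, optionally wrapped with a LoRA adapter, and then sharded across GPUs with accelerate's dispatch_model. The class name ChatGLMSketch, the lora_path parameter, and the body of auto_configure_device_map here are illustrative assumptions for a self-contained example, not the repository's exact code.

# A minimal sketch, assuming the helpers named below; only the control flow
# mirrors the fixed code in this commit.
from typing import Dict, Optional

import torch
from transformers import AutoModel


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    # Hypothetical stand-in for the project's helper: assign the embedding,
    # output head, and the 28 ChatGLM-6B transformer layers to GPUs round-robin.
    modules = ["transformer.word_embeddings", "transformer.final_layernorm", "lm_head"]
    modules += [f"transformer.layers.{i}" for i in range(28)]
    return {name: i % num_gpus for i, name in enumerate(modules)}


class ChatGLMSketch:
    model = None

    def load_model(self,
                   model_name_or_path: str = "THUDM/chatglm-6b",
                   lora_path: Optional[str] = None,
                   device_map: Optional[Dict[str, int]] = None,
                   **kwargs):
        # Load the weights once and keep them on self.model; the bug this
        # commit fixes was a second AutoModel.from_pretrained call in the
        # multi-GPU branch below.
        self.model = AutoModel.from_pretrained(model_name_or_path,
                                               trust_remote_code=True, **kwargs)
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            if num_gpus < 2 and device_map is None:
                self.model = self.model.half().cuda()
            else:
                from accelerate import dispatch_model
                if lora_path:
                    from peft import PeftModel
                    # Wrap the already-loaded model instead of reloading it.
                    self.model = PeftModel.from_pretrained(self.model, lora_path)
                # A device_map can be passed in to customize which modules are
                # placed on each GPU (the intent of the Chinese comment in the diff).
                if device_map is None:
                    device_map = auto_configure_device_map(num_gpus)
                self.model = dispatch_model(self.model.half(), device_map=device_map)
        else:
            self.model = self.model.float().to("cpu")
        return self.model

However the details differ from the repository, the point of the commit is the same: AutoModel.from_pretrained runs once, and every later step (LoRA wrapping, half-precision cast, dispatch) reuses self.model rather than loading a second copy.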