fix bug in chatglm_llm.py
This commit is contained in:
parent
5c34dd94e4
commit
f147043253
|
|
@ -41,10 +41,12 @@ class LocalDocQA:
|
||||||
llm_model: str = LLM_MODEL,
|
llm_model: str = LLM_MODEL,
|
||||||
llm_device=LLM_DEVICE,
|
llm_device=LLM_DEVICE,
|
||||||
top_k=VECTOR_SEARCH_TOP_K,
|
top_k=VECTOR_SEARCH_TOP_K,
|
||||||
|
use_ptuning_v2: bool = USE_PTUNING_V2
|
||||||
):
|
):
|
||||||
self.llm = ChatGLM()
|
self.llm = ChatGLM()
|
||||||
self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
|
self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
|
||||||
llm_device=llm_device)
|
llm_device=llm_device,
|
||||||
|
use_ptuning_v2=use_ptuning_v2)
|
||||||
self.llm.history_len = llm_history_len
|
self.llm.history_len = llm_history_len
|
||||||
|
|
||||||
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
|
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], )
|
||||||
|
|
|
||||||
|
|
@ -127,14 +127,6 @@ class ChatGLM(LLM):
|
||||||
device_map = auto_configure_device_map(num_gpus)
|
device_map = auto_configure_device_map(num_gpus)
|
||||||
|
|
||||||
self.model = dispatch_model(model, device_map=device_map)
|
self.model = dispatch_model(model, device_map=device_map)
|
||||||
self.model = (
|
|
||||||
AutoModel.from_pretrained(
|
|
||||||
model_name_or_path,
|
|
||||||
config=model_config,
|
|
||||||
trust_remote_code=True)
|
|
||||||
.half()
|
|
||||||
.cuda()
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.model = (
|
self.model = (
|
||||||
AutoModel.from_pretrained(
|
AutoModel.from_pretrained(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue