Optional LoRA weight loading (#231)
* Add files via upload: add LoRA weight usage * Update model_config.py * Add files via upload: fix a small mistake, the model loading step was missing * Use LoRA fine-tuned weights * Update model_config.py
parent 47922d2ee3
commit 14d998b8e6
@@ -61,9 +61,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
 
 
 def similarity_search_with_score_by_vector(
-        self,
-        embedding: List[float],
-        k: int = 4,
+        self, embedding: List[float], k: int = 4,
 ) -> List[Tuple[Document, float]]:
     scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
     docs = []
@@ -122,12 +120,12 @@ class LocalDocQA:
                  llm_model: str = LLM_MODEL,
                  llm_device=LLM_DEVICE,
                  top_k=VECTOR_SEARCH_TOP_K,
-                 use_ptuning_v2: bool = USE_PTUNING_V2
+                 use_ptuning_v2: bool = USE_PTUNING_V2,
+                 use_lora: bool = USE_LORA,
                  ):
         self.llm = ChatGLM()
         self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
-                            llm_device=llm_device,
-                            use_ptuning_v2=use_ptuning_v2)
+                            llm_device=llm_device, use_ptuning_v2=use_ptuning_v2, use_lora=use_lora)
         self.llm.history_len = llm_history_len
 
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
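With the new keyword argument, enabling LoRA from calling code becomes a one-flag change. A minimal usage sketch, assuming the constants exported by model_config.py (the call below is illustrative, not part of the diff):

local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=LLM_MODEL,
                      embedding_model=EMBEDDING_MODEL,
                      use_lora=USE_LORA,  # new flag; forwarded to ChatGLM.load_model
                      top_k=VECTOR_SEARCH_TOP_K)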
@@ -27,6 +27,11 @@ llm_model_dict = {
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# LLM LoRA path; empty by default. If you have LoRA weights, set this directly to the folder path
+# chatglm-6b-belle-zh-lora is recommended
+LLM_LORA_PATH = ""
+USE_LORA = True if LLM_LORA_PATH else False
+
 # LLM streaming reponse
 STREAMING = True
 
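To turn the feature on, point LLM_LORA_PATH at a local adapter directory; USE_LORA then evaluates to True on its own. A configuration sketch (the path is a placeholder, not part of the commit):

# model_config.py, illustrative values
LLM_LORA_PATH = "lora/chatglm-6b-belle-zh-lora"  # folder holding adapter_config.json and the adapter weights
USE_LORA = True if LLM_LORA_PATH else False      # now True, so the web UI checkbox defaults to enabled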
@@ -106,6 +106,7 @@ class ChatGLM(LLM):
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
                    use_ptuning_v2=False,
+                   use_lora=False,
                    device_map: Optional[Dict[str, int]] = None,
                    **kwargs):
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -125,45 +126,32 @@ class ChatGLM(LLM):
             except Exception as e:
                 print(e)
                 print("加载PrefixEncoder config.json失败")
+        self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True,
+                                               **kwargs)
+        if LLM_LORA_PATH and use_lora:
+            from peft import PeftModel
+            self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
 
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             # Decide whether to shard across multiple GPUs based on how many are available
             num_gpus = torch.cuda.device_count()
             if num_gpus < 2 and device_map is None:
-                self.model = (
-                    AutoModel.from_pretrained(
-                        model_name_or_path,
-                        config=model_config,
-                        trust_remote_code=True,
-                        **kwargs)
-                    .half()
-                    .cuda()
-                )
+                self.model = self.model.half().cuda()
             else:
                 from accelerate import dispatch_model
 
-                model = (
-                    AutoModel.from_pretrained(
-                        model_name_or_path,
-                        trust_remote_code=True,
-                        config=model_config,
-                        **kwargs)
-                    .half())
+                model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
+                                                  config=model_config, **kwargs)
+                if LLM_LORA_PATH and use_lora:
+                    from peft import PeftModel
+                    model_auto = PeftModel.from_pretrained(model, LLM_LORA_PATH)
                 # A device_map can be passed in to customize placement on each GPU
                 if device_map is None:
                     device_map = auto_configure_device_map(num_gpus)
 
-                self.model = dispatch_model(model, device_map=device_map)
+                self.model = dispatch_model(model_auto.half(), device_map=device_map)
         else:
-            self.model = (
-                AutoModel.from_pretrained(
-                    model_name_or_path,
-                    config=model_config,
-                    trust_remote_code=True,
-                    **kwargs)
-                .float()
-                .to(llm_device)
-            )
+            self.model = self.model.float().to(llm_device)
 
         if use_ptuning_v2:
             try:
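The core of the change is that the base model is now created once with AutoModel.from_pretrained and, when a LoRA path is configured, wrapped with a PEFT adapter before the half()/cuda() (or float()/CPU) placement. A minimal standalone sketch of that pattern, with a placeholder adapter path:

from transformers import AutoModel
from peft import PeftModel

base = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
# Attach LoRA weights saved by peft (a folder with adapter_config.json and the adapter weights)
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")
model = model.half().cuda()  # cast and device placement happen after the adapter is attached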
@@ -185,7 +173,7 @@ if __name__ == "__main__":
     llm = ChatGLM()
     llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
                    llm_device=LLM_DEVICE, )
-    last_print_len=0
+    last_print_len = 0
     for resp, history in llm._call("你好", streaming=True):
         print(resp[last_print_len:], end="", flush=True)
         last_print_len = len(resp)
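The smoke test above does not pass the new flag, so LoRA stays off there. A hedged example of exercising it directly (illustrative call, not part of the commit; it assumes LLM_LORA_PATH is set in model_config.py):

llm = ChatGLM()
llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
               llm_device=LLM_DEVICE,
               use_lora=True)  # only has an effect when LLM_LORA_PATH is non-empty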
@@ -12,4 +12,5 @@ accelerate
 gradio==3.24.1
 fastapi
 uvicorn
+peft
 #detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
webui.py
@@ -72,12 +72,13 @@ def init_model():
         return reply
 
 
-def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
+                              use_lora = use_lora,
                               top_k=top_k,)
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
@@ -246,6 +247,9 @@ with gr.Blocks(css=block_css) as demo:
                 use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
                                              label="使用p-tuning-v2微调过的模型",
                                              interactive=True)
+                use_lora = gr.Checkbox(USE_LORA,
+                                       label="使用lora微调的权重",
+                                       interactive=True)
                 embedding_model = gr.Radio(embedding_model_dict_list,
                                            label="Embedding 模型",
                                            value=EMBEDDING_MODEL,
@@ -259,7 +263,7 @@ with gr.Blocks(css=block_css) as demo:
         load_model_button = gr.Button("重新加载模型")
         load_model_button.click(reinit_model,
                                 show_progress=True,
-                                inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
+                                inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, chatbot],
                                 outputs=chatbot
                                 )