可选择lora权重加载 (#231)
* Add files via upload 增加lora权重使用 * Update model_config.py * Add files via upload 修复一个小错误,少写了模型加载 * 使用lora微调的权重 使用lora微调的权重 * Update model_config.py
This commit is contained in:
parent
47922d2ee3
commit
14d998b8e6
|
|
@ -61,9 +61,7 @@ def seperate_list(ls: List[int]) -> List[List[int]]:
|
|||
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
self, embedding: List[float], k: int = 4,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
|
||||
docs = []
|
||||
|
|
@ -122,12 +120,12 @@ class LocalDocQA:
|
|||
llm_model: str = LLM_MODEL,
|
||||
llm_device=LLM_DEVICE,
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
use_ptuning_v2: bool = USE_PTUNING_V2
|
||||
use_ptuning_v2: bool = USE_PTUNING_V2,
|
||||
use_lora: bool = USE_LORA,
|
||||
):
|
||||
self.llm = ChatGLM()
|
||||
self.llm.load_model(model_name_or_path=llm_model_dict[llm_model],
|
||||
llm_device=llm_device,
|
||||
use_ptuning_v2=use_ptuning_v2)
|
||||
llm_device=llm_device, use_ptuning_v2=use_ptuning_v2, use_lora=use_lora)
|
||||
self.llm.history_len = llm_history_len
|
||||
|
||||
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
|
||||
|
|
|
|||
|
|
@ -27,6 +27,11 @@ llm_model_dict = {
|
|||
# LLM model name
|
||||
LLM_MODEL = "chatglm-6b"
|
||||
|
||||
# LLM lora path,默认为空,如果有请直接指定文件夹路径
|
||||
# 推荐使用 chatglm-6b-belle-zh-lora
|
||||
LLM_LORA_PATH = ""
|
||||
USE_LORA = True if LLM_LORA_PATH else False
|
||||
|
||||
# LLM streaming reponse
|
||||
STREAMING = True
|
||||
|
||||
|
|
|
|||
|
|
@ -78,11 +78,11 @@ class ChatGLM(LLM):
|
|||
torch_gc()
|
||||
else:
|
||||
response, _ = self.model.chat(
|
||||
self.tokenizer,
|
||||
prompt,
|
||||
history=history[-self.history_len:] if self.history_len > 0 else [],
|
||||
max_length=self.max_token,
|
||||
temperature=self.temperature,
|
||||
self.tokenizer,
|
||||
prompt,
|
||||
history=history[-self.history_len:] if self.history_len > 0 else [],
|
||||
max_length=self.max_token,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
torch_gc()
|
||||
history += [[prompt, response]]
|
||||
|
|
@ -106,6 +106,7 @@ class ChatGLM(LLM):
|
|||
model_name_or_path: str = "THUDM/chatglm-6b",
|
||||
llm_device=LLM_DEVICE,
|
||||
use_ptuning_v2=False,
|
||||
use_lora=False,
|
||||
device_map: Optional[Dict[str, int]] = None,
|
||||
**kwargs):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
|
@ -125,45 +126,32 @@ class ChatGLM(LLM):
|
|||
except Exception as e:
|
||||
print(e)
|
||||
print("加载PrefixEncoder config.json失败")
|
||||
self.model = AutoModel.from_pretrained(model_name_or_path, config=model_config, trust_remote_code=True,
|
||||
**kwargs)
|
||||
if LLM_LORA_PATH and use_lora:
|
||||
from peft import PeftModel
|
||||
self.model = PeftModel.from_pretrained(self.model, LLM_LORA_PATH)
|
||||
|
||||
if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
|
||||
# 根据当前设备GPU数量决定是否进行多卡部署
|
||||
num_gpus = torch.cuda.device_count()
|
||||
if num_gpus < 2 and device_map is None:
|
||||
self.model = (
|
||||
AutoModel.from_pretrained(
|
||||
model_name_or_path,
|
||||
config=model_config,
|
||||
trust_remote_code=True,
|
||||
**kwargs)
|
||||
.half()
|
||||
.cuda()
|
||||
)
|
||||
self.model = self.model.half().cuda()
|
||||
else:
|
||||
from accelerate import dispatch_model
|
||||
|
||||
model = (
|
||||
AutoModel.from_pretrained(
|
||||
model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
config=model_config,
|
||||
**kwargs)
|
||||
.half())
|
||||
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
|
||||
config=model_config, **kwargs)
|
||||
if LLM_LORA_PATH and use_lora:
|
||||
from peft import PeftModel
|
||||
model_auto = PeftModel.from_pretrained(model, LLM_LORA_PATH)
|
||||
# 可传入device_map自定义每张卡的部署情况
|
||||
if device_map is None:
|
||||
device_map = auto_configure_device_map(num_gpus)
|
||||
|
||||
self.model = dispatch_model(model, device_map=device_map)
|
||||
self.model = dispatch_model(model_auto.half(), device_map=device_map)
|
||||
else:
|
||||
self.model = (
|
||||
AutoModel.from_pretrained(
|
||||
model_name_or_path,
|
||||
config=model_config,
|
||||
trust_remote_code=True,
|
||||
**kwargs)
|
||||
.float()
|
||||
.to(llm_device)
|
||||
)
|
||||
self.model = self.model.float().to(llm_device)
|
||||
|
||||
if use_ptuning_v2:
|
||||
try:
|
||||
|
|
@ -185,7 +173,7 @@ if __name__ == "__main__":
|
|||
llm = ChatGLM()
|
||||
llm.load_model(model_name_or_path=llm_model_dict[LLM_MODEL],
|
||||
llm_device=LLM_DEVICE, )
|
||||
last_print_len=0
|
||||
last_print_len = 0
|
||||
for resp, history in llm._call("你好", streaming=True):
|
||||
print(resp[last_print_len:], end="", flush=True)
|
||||
last_print_len = len(resp)
|
||||
|
|
|
|||
|
|
@ -12,4 +12,5 @@ accelerate
|
|||
gradio==3.24.1
|
||||
fastapi
|
||||
uvicorn
|
||||
peft
|
||||
#detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
|
||||
|
|
|
|||
8
webui.py
8
webui.py
|
|
@ -72,12 +72,13 @@ def init_model():
|
|||
return reply
|
||||
|
||||
|
||||
def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
|
||||
def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, history):
|
||||
try:
|
||||
local_doc_qa.init_cfg(llm_model=llm_model,
|
||||
embedding_model=embedding_model,
|
||||
llm_history_len=llm_history_len,
|
||||
use_ptuning_v2=use_ptuning_v2,
|
||||
use_lora = use_lora,
|
||||
top_k=top_k,)
|
||||
model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
|
||||
print(model_status)
|
||||
|
|
@ -246,6 +247,9 @@ with gr.Blocks(css=block_css) as demo:
|
|||
use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
|
||||
label="使用p-tuning-v2微调过的模型",
|
||||
interactive=True)
|
||||
use_lora = gr.Checkbox(USE_LORA,
|
||||
label="使用lora微调的权重",
|
||||
interactive=True)
|
||||
embedding_model = gr.Radio(embedding_model_dict_list,
|
||||
label="Embedding 模型",
|
||||
value=EMBEDDING_MODEL,
|
||||
|
|
@ -259,7 +263,7 @@ with gr.Blocks(css=block_css) as demo:
|
|||
load_model_button = gr.Button("重新加载模型")
|
||||
load_model_button.click(reinit_model,
|
||||
show_progress=True,
|
||||
inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
|
||||
inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, chatbot],
|
||||
outputs=chatbot
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue