update webui.py

This commit is contained in:
imClumsyPanda 2023-05-06 12:07:32 +08:00
parent 63453f2340
commit 0a4dd1987d
1 changed files with 32 additions and 23 deletions

View File

@ -8,6 +8,7 @@ import uuid
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
def get_vs_list(): def get_vs_list():
lst_default = ["新建知识库"] lst_default = ["新建知识库"]
if not os.path.exists(VS_ROOT_PATH): if not os.path.exists(VS_ROOT_PATH):
@ -18,6 +19,7 @@ def get_vs_list():
lst.sort(reverse=True) lst.sort(reverse=True)
return lst + lst_default return lst + lst_default
vs_list = get_vs_list() vs_list = get_vs_list()
embedding_model_dict_list = list(embedding_model_dict.keys()) embedding_model_dict_list = list(embedding_model_dict.keys())
@ -29,6 +31,7 @@ local_doc_qa = LocalDocQA()
logger = gr.CSVLogger() logger = gr.CSVLogger()
username = uuid.uuid4().hex username = uuid.uuid4().hex
def get_answer(query, vs_path, history, mode, def get_answer(query, vs_path, history, mode,
streaming: bool = STREAMING): streaming: bool = STREAMING):
if mode == "知识库问答" and vs_path: if mode == "知识库问答" and vs_path:
@ -54,6 +57,7 @@ def get_answer(query, vs_path, history, mode,
yield history, "" yield history, ""
logger.flag([query, vs_path, history, mode], username=username) logger.flag([query, vs_path, history, mode], username=username)
def init_model(): def init_model():
try: try:
local_doc_qa.init_cfg() local_doc_qa.init_cfg()
@ -116,7 +120,9 @@ def change_vs_name_input(vs_id,history):
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
else: else:
file_status = f"已加载知识库{vs_id},请开始提问" file_status = f"已加载知识库{vs_id},请开始提问"
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH, vs_id),history + [[None, file_status]] return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH,
vs_id), history + [
[None, file_status]]
def change_mode(mode): def change_mode(mode):
@ -136,6 +142,7 @@ def add_vs_name(vs_name, vs_list, chatbot):
chatbot = chatbot + [[None, vs_status]] chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True, choices=vs_list + [vs_name], value=vs_name), vs_list + [vs_name], chatbot return gr.update(visible=True, choices=vs_list + [vs_name], value=vs_name), vs_list + [vs_name], chatbot
block_css = """.importantButton { block_css = """.importantButton {
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important; background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
border: none !important; border: none !important;
@ -166,7 +173,8 @@ model_status = init_model()
default_path = os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else "" default_path = os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""
with gr.Blocks(css=block_css) as demo: with gr.Blocks(css=block_css) as demo:
vs_path, file_status, model_status, vs_list = gr.State(default_path), gr.State(""), gr.State(model_status), gr.State(vs_list) vs_path, file_status, model_status, vs_list = gr.State(default_path), gr.State(""), gr.State(
model_status), gr.State(vs_list)
gr.Markdown(webui_title) gr.Markdown(webui_title)
with gr.Tab("对话"): with gr.Tab("对话"):
with gr.Row(): with gr.Row():
@ -265,7 +273,8 @@ with gr.Blocks(css=block_css) as demo:
load_model_button = gr.Button("重新加载模型") load_model_button = gr.Button("重新加载模型")
load_model_button.click(reinit_model, load_model_button.click(reinit_model,
show_progress=True, show_progress=True,
inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, chatbot], inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k,
chatbot],
outputs=chatbot outputs=chatbot
) )