diff --git a/webui.py b/webui.py
index 12808cb..d0d0862 100644
--- a/webui.py
+++ b/webui.py
@@ -8,17 +8,19 @@ import uuid
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
+
 def get_vs_list():
-    lst_default = ["新建知识库"]
+    lst_default = ["新建知识库"]
     if not os.path.exists(VS_ROOT_PATH):
         return lst_default
-    lst= os.listdir(VS_ROOT_PATH)
-    if not lst:
+    lst = os.listdir(VS_ROOT_PATH)
+    if not lst:
         return lst_default
     lst.sort(reverse=True)
-    return lst+ lst_default
+    return lst + lst_default
 
-vs_list =get_vs_list()
+
+vs_list = get_vs_list()
 
 embedding_model_dict_list = list(embedding_model_dict.keys())
@@ -29,6 +31,7 @@ local_doc_qa = LocalDocQA()
 logger = gr.CSVLogger()
 username = uuid.uuid4().hex
 
+
 def get_answer(query, vs_path, history, mode,
                streaming: bool = STREAMING):
     if mode == "知识库问答" and vs_path:
@@ -51,8 +54,9 @@ def get_answer(query, vs_path, history, mode,
                                                     streaming=streaming):
             history[-1][-1] = resp + (
                 "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
-            yield history, ""
-    logger.flag([query, vs_path, history, mode],username=username)
+            yield history, ""
+    logger.flag([query, vs_path, history, mode], username=username)
+
 
 def init_model():
     try:
@@ -78,8 +82,8 @@ def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, us
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               use_ptuning_v2=use_ptuning_v2,
-                              use_lora = use_lora,
-                              top_k=top_k,)
+                              use_lora=use_lora,
+                              top_k=top_k, )
         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
         print(model_status)
     except Exception as e:
@@ -111,12 +115,14 @@ def get_vector_store(vs_id, files, history):
     return vs_path, None, history + [[None, file_status]]
 
 
-def change_vs_name_input(vs_id,history):
+def change_vs_name_input(vs_id, history):
     if vs_id == "新建知识库":
-        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None,history
+        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
     else:
         file_status = f"已加载知识库{vs_id},请开始提问"
-        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH, vs_id),history + [[None, file_status]]
+        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH,
+                                                                                                         vs_id), history + [
+                   [None, file_status]]
 
 
 def change_mode(mode):
@@ -136,6 +142,7 @@ def add_vs_name(vs_name, vs_list, chatbot):
         chatbot = chatbot + [[None, vs_status]]
     return gr.update(visible=True, choices=vs_list + [vs_name], value=vs_name), vs_list + [vs_name], chatbot
 
+
 block_css = """.importantButton {
     background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
     border: none !important;
@@ -163,10 +170,11 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
 """
 
 model_status = init_model()
 
-default_path = os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""
+default_path = os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""
 with gr.Blocks(css=block_css) as demo:
-    vs_path, file_status, model_status, vs_list = gr.State(default_path), gr.State(""), gr.State(model_status), gr.State(vs_list)
+    vs_path, file_status, model_status, vs_list = gr.State(default_path), gr.State(""), gr.State(
+        model_status), gr.State(vs_list)
     gr.Markdown(webui_title)
     with gr.Tab("对话"):
         with gr.Row():
@@ -175,7 +183,7 @@ with gr.Blocks(css=block_css) as demo:
                                      elem_id="chat-box",
                                      show_label=False).style(height=750)
                 query = gr.Textbox(show_label=False,
-                                   placeholder="请输入提问内容,按回车进行提交").style(container=False)
+                                   placeholder="请输入提问内容,按回车进行提交").style(container=False)
             with gr.Column(scale=5):
                 mode = gr.Radio(["LLM 对话", "知识库问答"],
                                 label="请选择使用模式",
@@ -218,7 +226,7 @@ with gr.Blocks(css=block_css) as demo:
                     load_folder_button = gr.Button("上传文件夹并加载知识库")
                 # load_vs.click(fn=)
                 select_vs.change(fn=change_vs_name_input,
-                                 inputs=[select_vs,chatbot],
+                                 inputs=[select_vs, chatbot],
                                  outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
                 # 将上传的文件保存到content文件夹下,并更新下拉框
                 load_file_button.click(get_vector_store,
@@ -230,11 +238,11 @@ with gr.Blocks(css=block_css) as demo:
                                          show_progress=True,
                                          inputs=[select_vs, folder_files, chatbot],
                                          outputs=[vs_path, folder_files, chatbot],
-                                         )
-        logger.setup([query, vs_path, chatbot, mode], "flagged")
+                                         )
+    logger.setup([query, vs_path, chatbot, mode], "flagged")
     query.submit(get_answer,
-                [query, vs_path, chatbot, mode],
-                [chatbot, query])
+                 [query, vs_path, chatbot, mode],
+                 [chatbot, query])
     with gr.Tab("模型配置"):
         llm_model = gr.Radio(llm_model_dict_list,
                              label="LLM 模型",
@@ -250,8 +258,8 @@ with gr.Blocks(css=block_css) as demo:
                                    label="使用p-tuning-v2微调过的模型",
                                    interactive=True)
         use_lora = gr.Checkbox(USE_LORA,
-                              label="使用lora微调的权重",
-                              interactive=True)
+                               label="使用lora微调的权重",
+                               interactive=True)
         embedding_model = gr.Radio(embedding_model_dict_list,
                                    label="Embedding 模型",
                                    value=EMBEDDING_MODEL,
@@ -265,7 +273,8 @@ with gr.Blocks(css=block_css) as demo:
 
         load_model_button = gr.Button("重新加载模型")
         load_model_button.click(reinit_model,
                                 show_progress=True,
-                                inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k, chatbot],
+                                inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, use_lora, top_k,
+                                        chatbot],
                                 outputs=chatbot
                                 )