diff --git a/webui.py b/webui.py index cfe7b58..a93de0c 100644 --- a/webui.py +++ b/webui.py @@ -25,8 +25,6 @@ def get_vs_list(): return lst_default + lst -vs_list = get_vs_list() - embedding_model_dict_list = list(embedding_model_dict.keys()) llm_model_dict_list = list(llm_model_dict.keys()) @@ -44,11 +42,12 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR query=query, chat_history=history, streaming=streaming): source = "\n\n" source += "".join( - [f"""
出处 [{i + 1}] {doc.metadata["source"]} \n""" - f"""{doc.page_content}\n""" - f"""
""" - for i, doc in - enumerate(resp["source_documents"])]) + [ + f"""
出处 [{i + 1}] {doc.metadata["source"]} \n""" + f"""{doc.page_content}\n""" + f"""
""" + for i, doc in + enumerate(resp["source_documents"])]) history[-1][-1] += source yield history, "" elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path): @@ -89,7 +88,6 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR else: for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history, streaming=streaming): - resp = answer_result.llm_output["answer"] history = answer_result.history history[-1][-1] = resp + ( @@ -99,9 +97,15 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME) -def init_model(llm_model: BaseAnswer = None): +def init_model(): + args = parser.parse_args() + + args_dict = vars(args) + shared.loaderCheckPoint = LoaderCheckPoint(args_dict) + llm_model_ins = shared.loaderLLM() + llm_model_ins.set_history_len(LLM_HISTORY_LEN) try: - local_doc_qa.init_cfg(llm_model=llm_model) + local_doc_qa.init_cfg(llm_model=llm_model_ins) generator = local_doc_qa.llm.generatorAnswer("你好") for answer_result in generator: print(answer_result.llm_output) @@ -119,7 +123,9 @@ def init_model(llm_model: BaseAnswer = None): return reply -def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history): +def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, + history): + try: llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) llm_model_ins.history_len = llm_history_len @@ -138,8 +144,6 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): vs_path = os.path.join(VS_ROOT_PATH, vs_id) filelist = [] - if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)): - os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id)) if local_doc_qa.llm and local_doc_qa.embeddings: if isinstance(files, list): for file in files: @@ -166,9 +170,8 @@ def change_vs_name_input(vs_id, history): return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history else: file_status = f"已加载知识库{vs_id},请开始提问" - return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH, - vs_id), history + [ - [None, file_status]] + return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \ + os.path.join(VS_ROOT_PATH, vs_id), history + [[None, file_status]] knowledge_base_test_mode_info = ("【注意】\n\n" @@ -206,19 +209,29 @@ def change_chunk_conent(mode, label_conent, history): return gr.update(visible=False), history + [[None, f"【已关闭{conent}】"]] -def add_vs_name(vs_name, vs_list, chatbot): - if vs_name in vs_list: +def add_vs_name(vs_name, chatbot): + if vs_name in get_vs_list(): vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交" chatbot = chatbot + [[None, vs_status]] - return gr.update(visible=True), vs_list, gr.update(visible=True), gr.update(visible=True), gr.update( + return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update( visible=False), chatbot else: + # 新建上传文件存储路径 + if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_name)): + os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_name)) + # 新建向量库存储路径 + if not os.path.exists(os.path.join(VS_ROOT_PATH, vs_name)): + os.makedirs(os.path.join(VS_ROOT_PATH, vs_name)) vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """ chatbot = chatbot + [[None, vs_status]] - return gr.update(visible=True, choices=[vs_name] + vs_list, value=vs_name), [vs_name] + vs_list, gr.update( + return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update( visible=False), gr.update(visible=False), gr.update(visible=True), chatbot +def refresh_vs_list(): + return gr.update(choices=get_vs_list()) + + block_css = """.importantButton { background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important; border: none !important; @@ -232,7 +245,7 @@ webui_title = """ # 🎉langchain-ChatGLM WebUI🎉 👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM) """ -default_vs = vs_list[0] if len(vs_list) > 1 else "为空" +default_vs = get_vs_list()[0] if len(get_vs_list()) > 1 else "为空" init_message = f"""欢迎使用 langchain-ChatGLM Web UI! 请在右侧切换模式,目前支持直接与 LLM 模型对话或基于本地知识库问答。 @@ -243,16 +256,7 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI! """ # 初始化消息 -args = None -args = parser.parse_args() - -args_dict = vars(args) -shared.loaderCheckPoint = LoaderCheckPoint(args_dict) -llm_model_ins = shared.loaderLLM() -llm_model_ins.set_history_len(LLM_HISTORY_LEN) - -model_status = init_model(llm_model=llm_model_ins) - +model_status = init_model() default_theme_args = dict( font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'], @@ -260,10 +264,9 @@ default_theme_args = dict( ) with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo: - vs_path, file_status, model_status, vs_list = gr.State( - os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""), gr.State(""), gr.State( - model_status), gr.State(vs_list) - + vs_path, file_status, model_status = gr.State( + os.path.join(VS_ROOT_PATH, get_vs_list()[0]) if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State( + model_status) gr.Markdown(webui_title) with gr.Tab("对话"): with gr.Row(): @@ -283,10 +286,11 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as inputs=[mode, chatbot], outputs=[vs_setting, knowledge_set, chatbot]) with vs_setting: - select_vs = gr.Dropdown(vs_list.value, + vs_refresh = gr.Button("更新已有知识库选项") + select_vs = gr.Dropdown(get_vs_list(), label="请选择要加载的知识库", interactive=True, - value=vs_list.value[0] if len(vs_list.value) > 0 else None + value=get_vs_list()[0] if len(get_vs_list()) > 0 else None ) vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文", lines=1, @@ -302,19 +306,21 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as interactive=True, visible=True) with gr.Tab("上传文件"): files = gr.File(label="添加文件", - file_types=['.txt', '.md', '.docx', '.pdf'], + file_types=['.txt', '.md', '.docx', '.pdf', '.png', '.jpg'], file_count="multiple", show_label=False) load_file_button = gr.Button("上传文件并加载知识库") with gr.Tab("上传文件夹"): folder_files = gr.File(label="添加文件", - # file_types=['.txt', '.md', '.docx', '.pdf'], file_count="directory", show_label=False) load_folder_button = gr.Button("上传文件夹并加载知识库") + vs_refresh.click(fn=refresh_vs_list, + inputs=[], + outputs=select_vs) vs_add.click(fn=add_vs_name, - inputs=[vs_name, vs_list, chatbot], - outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot]) + inputs=[vs_name, chatbot], + outputs=[select_vs, vs_name, vs_add, file2vs, chatbot]) select_vs.change(fn=change_vs_name_input, inputs=[select_vs, chatbot], outputs=[vs_name, vs_add, file2vs, vs_path, chatbot]) @@ -366,10 +372,11 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as inputs=[chunk_conent, gr.Textbox(value="chunk_conent", visible=False), chatbot], outputs=[chunk_sizes, chatbot]) with vs_setting: - select_vs = gr.Dropdown(vs_list.value, + vs_refresh = gr.Button("更新已有知识库选项") + select_vs = gr.Dropdown(get_vs_list(), label="请选择要加载的知识库", interactive=True, - value=vs_list.value[0] if len(vs_list.value) > 0 else None) + value=get_vs_list()[0] if len(get_vs_list()) > 0 else None) vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文", lines=1, interactive=True, @@ -402,9 +409,12 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as interactive=True) load_conent_button = gr.Button("添加内容并加载知识库") # 将上传的文件保存到content文件夹下,并更新下拉框 + vs_refresh.click(fn=refresh_vs_list, + inputs=[], + outputs=select_vs) vs_add.click(fn=add_vs_name, - inputs=[vs_name, vs_list, chatbot], - outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot]) + inputs=[vs_name, chatbot], + outputs=[select_vs, vs_name, vs_add, file2vs, chatbot]) select_vs.change(fn=change_vs_name_input, inputs=[select_vs, chatbot], outputs=[vs_name, vs_add, file2vs, vs_path, chatbot]) @@ -455,8 +465,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as label="向量匹配 top k", interactive=True) load_model_button = gr.Button("重新加载模型") load_model_button.click(reinit_model, show_progress=True, - inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, - top_k, chatbot], outputs=chatbot) + inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, + use_lora, top_k, chatbot], outputs=chatbot) (demo .queue(concurrency_count=3) @@ -464,4 +474,4 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as server_port=7860, show_api=False, share=False, - inbrowser=False)) \ No newline at end of file + inbrowser=False))