diff --git a/webui.py b/webui.py
index 18a6a3e..f041da6 100644
--- a/webui.py
+++ b/webui.py
@@ -31,9 +31,12 @@ def upload_file(file):
 
 
 def get_answer(query, vs_path, history):
-    resp, history = local_doc_qa.get_knowledge_based_answer(
-        query=query, vs_path=vs_path, chat_history=history)
-    return history, history
+    if vs_path:
+        resp, history = local_doc_qa.get_knowledge_based_answer(
+            query=query, vs_path=vs_path, chat_history=history)
+    else:
+        history = history + [[None, "请先加载文件后,再进行提问。"]]
+    return history
 
 
 def update_status(history, status):
@@ -50,26 +53,30 @@ def init_model():
         return """模型未成功加载,请重新选择后点击"加载模型"按钮"""
 
 
-def reinit_model(llm_model, embedding_model, llm_history_len, top_k):
+def reinit_model(llm_model, embedding_model, llm_history_len, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
                               top_k=top_k)
-        return """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
+        model_status = """模型已成功重新加载,请选择文件后点击"加载文件"按钮"""
     except:
-        return """模型未成功重新加载,请重新选择后点击"加载模型"按钮"""
+        model_status = """模型未成功重新加载,请重新选择后点击"加载模型"按钮"""
+    return history + [[None, model_status]]
 
 
 
-def get_vector_store(filepath):
-    vs_path = local_doc_qa.init_knowledge_vector_store(["content/" + filepath])
-    if vs_path:
-        file_status = "文件已成功加载,请开始提问"
+def get_vector_store(filepath, history):
+    if local_doc_qa.llm and local_doc_qa.embeddings:
+        vs_path = local_doc_qa.init_knowledge_vector_store(["content/" + filepath])
+        if vs_path:
+            file_status = "文件已成功加载,请开始提问"
+        else:
+            file_status = "文件未成功加载,请重新上传文件"
     else:
-        file_status = "文件未成功加载,请重新上传文件"
-    print(file_status)
-    return vs_path, file_status
+        file_status = "模型未完成加载,请先在加载模型后再导入文件"
+        vs_path = None
+    return vs_path, history + [[None, file_status]]
 
 
 block_css = """.importantButton {
@@ -98,7 +105,7 @@ init_message = """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请
 model_status = init_model()
 
 with gr.Blocks(css=block_css) as demo:
-    vs_path, history, file_status, model_status = gr.State(""), gr.State([]), gr.State(""), gr.State(model_status)
+    vs_path, file_status, model_status = gr.State(""), gr.State(""), gr.State(model_status)
     gr.Markdown(webui_title)
     with gr.Row():
         with gr.Column(scale=2):
@@ -106,9 +113,9 @@ with gr.Blocks(css=block_css) as demo:
                                  elem_id="chat-box",
                                  show_label=False).style(height=750)
             query = gr.Textbox(show_label=False,
-                               placeholder="请提问",
-                               lines=1,
-                               value="用200字总结一下"
+                               placeholder="请输入提问内容,按回车进行提交",
+                               # lines=1,
+                               # value="用200字总结一下"
                                ).style(container=False)
 
         with gr.Column(scale=1):
@@ -144,26 +151,24 @@ with gr.Blocks(css=block_css) as demo:
             file = gr.File(label="content file",
                            file_types=['.txt', '.md', '.docx', '.pdf']
                            )  # .style(height=100)
-            load_file_button = gr.Button("重新加载文件")
+            load_file_button = gr.Button("加载文件")
     load_model_button.click(reinit_model,
                             show_progress=True,
-                            inputs=[llm_model, embedding_model, llm_history_len, top_k],
-                            outputs=model_status
-                            ).then(update_status, [chatbot, model_status], chatbot)
+                            inputs=[llm_model, embedding_model, llm_history_len, top_k, chatbot],
+                            outputs=chatbot
+                            )
     # 将上传的文件保存到content文件夹下,并更新下拉框
     file.upload(upload_file,
                 inputs=file,
                 outputs=selectFile)
     load_file_button.click(get_vector_store,
                            show_progress=True,
-                           inputs=selectFile,
-                           outputs=[vs_path, file_status],
-                           ).then(
-        update_status, [chatbot, file_status], chatbot
-    )
+                           inputs=[selectFile, chatbot],
+                           outputs=[vs_path, chatbot],
+                           )
     query.submit(get_answer,
                  [query, vs_path, chatbot],
-                 [chatbot, history],
+                 [chatbot],
                  )
 
 demo.queue(concurrency_count=3).launch(