diff --git a/webui.py b/webui.py index a93de0c..5676dc2 100644 --- a/webui.py +++ b/webui.py @@ -50,18 +50,18 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR enumerate(resp["source_documents"])]) history[-1][-1] += source yield history, "" - elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path): - for resp, history in local_doc_qa.get_knowledge_based_answer( - query=query, vs_path=vs_path, chat_history=history, streaming=streaming): - source = "\n\n" - source += "".join( - [f"""
出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}\n""" - f"""{doc.page_content}\n""" - f"""
""" - for i, doc in - enumerate(resp["source_documents"])]) - history[-1][-1] += source - yield history, "" + elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): + for resp, history in local_doc_qa.get_knowledge_based_answer( + query=query, vs_path=vs_path, chat_history=history, streaming=streaming): + source = "\n\n" + source += "".join( + [f"""
出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}\n""" + f"""{doc.page_content}\n""" + f"""
""" + for i, doc in + enumerate(resp["source_documents"])]) + history[-1][-1] += source + yield history, "" elif mode == "知识库测试": if os.path.exists(vs_path): resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path, @@ -86,12 +86,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR yield history + [[query, "请选择知识库后进行测试,当前未选择知识库。"]], "" else: - for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history, + for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history[:-1], streaming=streaming): resp = answer_result.llm_output["answer"] history = answer_result.history - history[-1][-1] = resp + ( - "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "") + history[-1][-1] = resp yield history, "" logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}") flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME) @@ -169,9 +168,13 @@ def change_vs_name_input(vs_id, history): if vs_id == "新建知识库": return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history else: - file_status = f"已加载知识库{vs_id},请开始提问" + vs_path = os.path.join(VS_ROOT_PATH, vs_id) + if "index.faiss" in os.listdir(vs_path): + file_status = f"已加载知识库{vs_id},请开始提问" + else: + file_status = f"已选择知识库{vs_id},当前知识库中未上传文件,请先上传文件后,再开始提问" return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \ - os.path.join(VS_ROOT_PATH, vs_id), history + [[None, file_status]] + vs_path, history + [[None, file_status]] knowledge_base_test_mode_info = ("【注意】\n\n"