diff --git a/webui.py b/webui.py
index a93de0c..5676dc2 100644
--- a/webui.py
+++ b/webui.py
@@ -50,18 +50,18 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
enumerate(resp["source_documents"])])
history[-1][-1] += source
yield history, ""
- elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
- for resp, history in local_doc_qa.get_knowledge_based_answer(
- query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
- source = "\n\n"
- source += "".join(
- [f""" 出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}
\n"""
- f"""{doc.page_content}\n"""
- f""" """
- for i, doc in
- enumerate(resp["source_documents"])])
- history[-1][-1] += source
- yield history, ""
+ elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
+ for resp, history in local_doc_qa.get_knowledge_based_answer(
+ query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
+ source = "\n\n"
+ source += "".join(
+ [f""" 出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}
\n"""
+ f"""{doc.page_content}\n"""
+ f""" """
+ for i, doc in
+ enumerate(resp["source_documents"])])
+ history[-1][-1] += source
+ yield history, ""
elif mode == "知识库测试":
if os.path.exists(vs_path):
resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path,
@@ -86,12 +86,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], ""
else:
- for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
+ for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history[:-1],
streaming=streaming):
resp = answer_result.llm_output["answer"]
history = answer_result.history
- history[-1][-1] = resp + (
- "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
+ history[-1][-1] = resp
yield history, ""
logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
@@ -169,9 +168,13 @@ def change_vs_name_input(vs_id, history):
if vs_id == "新建知识库":
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
else:
- file_status = f"已加载知识库{vs_id},请开始提问"
+ vs_path = os.path.join(VS_ROOT_PATH, vs_id)
+ if "index.faiss" in os.listdir(vs_path):
+ file_status = f"已加载知识库{vs_id},请开始提问"
+ else:
+ file_status = f"已选择知识库{vs_id},当前知识库中未上传文件,请先上传文件后,再开始提问"
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \
- os.path.join(VS_ROOT_PATH, vs_id), history + [[None, file_status]]
+ vs_path, history + [[None, file_status]]
knowledge_base_test_mode_info = ("【注意】\n\n"