diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index 7293817..e0411f0 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -252,7 +252,7 @@ class LocalDocQA:
logger.info(f"{file} 未能成功加载")
if len(docs) > 0:
logger.info("文件加载完毕,正在生成向量库")
- if vs_path and os.path.isdir(vs_path):
+ if vs_path and os.path.isdir(vs_path) and "index.faiss" in os.listdir(vs_path):
vector_store = load_vector_store(vs_path, self.embeddings)
vector_store.add_documents(docs)
torch_gc()
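Note: `os.path.isdir(vs_path)` alone is also true for an empty knowledge-base directory, in which case `load_vector_store` would try to read a FAISS index that was never written; requiring `index.faiss` lets such a path fall through to the branch that builds a fresh store. A minimal sketch of the guard in isolation (the helper name is illustrative, not from the patch):

    import os

    def has_faiss_index(vs_path: str) -> bool:
        # A directory can exist yet contain no index; require the index file too.
        return bool(vs_path) and os.path.isdir(vs_path) and "index.faiss" in os.listdir(vs_path)
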
diff --git a/webui.py b/webui.py
index 799e08d..c4f6193 100644
--- a/webui.py
+++ b/webui.py
@@ -50,18 +50,19 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
enumerate(resp["source_documents"])])
history[-1][-1] += source
yield history, ""
- elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
- for resp, history in local_doc_qa.get_knowledge_based_answer(
- query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
- source = "\n\n"
- source += "".join(
- [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
- f"""{doc.page_content}\n"""
- f"""</details>"""
- for i, doc in
- enumerate(resp["source_documents"])])
- history[-1][-1] += source
- yield history, ""
+ elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path) and "index.faiss" in os.listdir(
+ vs_path):
+ for resp, history in local_doc_qa.get_knowledge_based_answer(
+ query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
+ source = "\n\n"
+ source += "".join(
+ [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
+ f"""{doc.page_content}\n"""
+ f"""</details>"""
+ for i, doc in
+ enumerate(resp["source_documents"])])
+ history[-1][-1] += source
+ yield history, ""
elif mode == "知识库测试":
if os.path.exists(vs_path):
resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path,
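Note: the new `elif` only rewraps the long condition; the body is unchanged, and the same `index.faiss` guard added in `local_doc_qa.py` now gates the knowledge-base branch here. The `<details>`/`<summary>` markup is kept because the Gradio chatbot renders HTML, so each source document becomes a collapsible block. A reduced sketch of how the source footer is assembled, using stand-in data instead of `resp["source_documents"]`:

    import os

    docs = [("docs/a.md", "first chunk"), ("docs/b.md", "second chunk")]  # stand-in (path, content) pairs
    source = "\n\n" + "".join(
        f"""<details> <summary>出处 [{i + 1}] {os.path.split(path)[-1]}</summary>\n{content}\n</details>"""
        for i, (path, content) in enumerate(docs))
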
@@ -124,7 +125,6 @@ def init_model():
def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k,
history):
-
try:
llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
llm_model_ins.history_len = llm_history_len
@@ -229,16 +229,17 @@ def add_vs_name(vs_name, chatbot):
chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update(
visible=False), gr.update(visible=False), gr.update(visible=True), chatbot
-
+
+
# Automatically load the files in a fixed directory
-def init_set_vector_store(content_dir,vs_id,history):
+def reinit_vector_store(vs_id, history):
try:
shutil.rmtree(VS_ROOT_PATH)
vs_path = os.path.join(VS_ROOT_PATH, vs_id)
sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0,
- label="文本入库分句长度限制",
- interactive=True, visible=True)
- vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(content_dir, vs_path, sentence_size)
+ label="文本入库分句长度限制",
+ interactive=True, visible=True)
+ vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(UPLOAD_ROOT_PATH, vs_path, sentence_size)
model_status = """知识库构建成功"""
except Exception as e:
logger.error(e)
@@ -487,8 +488,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2,
use_lora, top_k, chatbot], outputs=chatbot)
load_knowlege_button = gr.Button("重新构建知识库")
- load_knowlege_button.click(init_set_vector_store, show_progress=True,
- inputs=[UPLOAD_ROOT_PATH, select_vs,chatbot], outputs=chatbot)
+ load_knowlege_button.click(reinit_vector_store, show_progress=True,
+ inputs=[select_vs, chatbot], outputs=chatbot)
(demo
.queue(concurrency_count=3)
.launch(server_name='0.0.0.0',
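Note: the old call passed `UPLOAD_ROOT_PATH`, a module-level string, in `inputs=`, but Gradio collects `inputs` values from components at event time, so constants belong inside the handler. A minimal self-contained sketch of the corrected wiring (the stub handler stands in for the real `reinit_vector_store`):

    import gradio as gr

    def reinit_vector_store(vs_id, history):
        # Stand-in for the patched handler; the real one rebuilds the FAISS store.
        return history + [[None, f"rebuilt {vs_id}"]]

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        select_vs = gr.Dropdown(choices=["default"], value="default")
        load_knowlege_button = gr.Button("重新构建知识库")
        # inputs must be components; their current values are handed to the handler
        load_knowlege_button.click(reinit_vector_store, show_progress=True,
                                   inputs=[select_vs, chatbot], outputs=chatbot)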