diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index 7293817..e0411f0 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -252,7 +252,7 @@ class LocalDocQA: logger.info(f"{file} 未能成功加载") if len(docs) > 0: logger.info("文件加载完毕,正在生成向量库") - if vs_path and os.path.isdir(vs_path): + if vs_path and os.path.isdir(vs_path) and "index.faiss" in os.listdir(vs_path): vector_store = load_vector_store(vs_path, self.embeddings) vector_store.add_documents(docs) torch_gc() diff --git a/webui.py b/webui.py index 799e08d..c4f6193 100644 --- a/webui.py +++ b/webui.py @@ -50,18 +50,19 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR enumerate(resp["source_documents"])]) history[-1][-1] += source yield history, "" - elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): - for resp, history in local_doc_qa.get_knowledge_based_answer( - query=query, vs_path=vs_path, chat_history=history, streaming=streaming): - source = "\n\n" - source += "".join( - [f"""
出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}\n""" - f"""{doc.page_content}\n""" - f"""
""" - for i, doc in - enumerate(resp["source_documents"])]) - history[-1][-1] += source - yield history, "" + elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path) and "index.faiss" in os.listdir( + vs_path): + for resp, history in local_doc_qa.get_knowledge_based_answer( + query=query, vs_path=vs_path, chat_history=history, streaming=streaming): + source = "\n\n" + source += "".join( + [f"""
出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}\n""" + f"""{doc.page_content}\n""" + f"""
""" + for i, doc in + enumerate(resp["source_documents"])]) + history[-1][-1] += source + yield history, "" elif mode == "知识库测试": if os.path.exists(vs_path): resp, prompt = local_doc_qa.get_knowledge_based_conent_test(query=query, vs_path=vs_path, @@ -124,7 +125,6 @@ def init_model(): def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history): - try: llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) llm_model_ins.history_len = llm_history_len @@ -229,16 +229,17 @@ def add_vs_name(vs_name, chatbot): chatbot = chatbot + [[None, vs_status]] return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update( visible=False), gr.update(visible=False), gr.update(visible=True), chatbot - + + # 自动化加载固定文件间中文件 -def init_set_vector_store(content_dir,vs_id,history): +def reinit_vector_store(vs_id, history): try: shutil.rmtree(VS_ROOT_PATH) vs_path = os.path.join(VS_ROOT_PATH, vs_id) sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0, - label="文本入库分句长度限制", - interactive=True, visible=True) - vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(content_dir, vs_path, sentence_size) + label="文本入库分句长度限制", + interactive=True, visible=True) + vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(UPLOAD_ROOT_PATH, vs_path, sentence_size) model_status = """知识库构建成功""" except Exception as e: logger.error(e) @@ -487,8 +488,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, chatbot], outputs=chatbot) load_knowlege_button = gr.Button("重新构建知识库") - load_knowlege_button.click(init_set_vector_store, show_progress=True, - inputs=[UPLOAD_ROOT_PATH, select_vs,chatbot], outputs=chatbot) + load_knowlege_button.click(reinit_vector_store, show_progress=True, + inputs=[select_vs, chatbot], outputs=chatbot) (demo .queue(concurrency_count=3) .launch(server_name='0.0.0.0',