From a887df17150ab0b41f09d3f8e3e563d756aa0c50 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Tue, 13 Jun 2023 23:54:29 +0800 Subject: [PATCH] add knowledge_base folder and move vector_store and content inside --- api.py | 111 ++++++++++-------- chains/local_doc_qa.py | 5 +- configs/model_config.py | 10 +- .../samples/content}/README.md | 48 ++++++-- .../samples/content}/test.jpg | Bin .../samples/content}/test.pdf | Bin .../samples/content}/test.txt | 0 loader/image_loader.py | 2 +- loader/pdf_loader.py | 2 +- webui.py | 29 ++--- webui_st.py | 16 +-- 11 files changed, 128 insertions(+), 95 deletions(-) rename {content/samples => knowledge_base/samples/content}/README.md (84%) rename {content/samples => knowledge_base/samples/content}/test.jpg (100%) rename {content/samples => knowledge_base/samples/content}/test.pdf (100%) rename {content/samples => knowledge_base/samples/content}/test.txt (100%) diff --git a/api.py b/api.py index 7926fb2..9206ac5 100644 --- a/api.py +++ b/api.py @@ -15,7 +15,7 @@ from typing_extensions import Annotated from starlette.responses import RedirectResponse from chains.local_doc_qa import LocalDocQA -from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE, +from configs.model_config import (KB_ROOT_PATH, EMBEDDING_DEVICE, EMBEDDING_MODEL, NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN) import models.shared as shared @@ -80,15 +80,15 @@ class ChatMessage(BaseModel): def get_folder_path(local_doc_id: str): - return os.path.join(UPLOAD_ROOT_PATH, local_doc_id) + return os.path.join(KB_ROOT_PATH, local_doc_id, "content") def get_vs_path(local_doc_id: str): - return os.path.join(VS_ROOT_PATH, local_doc_id) + return os.path.join(KB_ROOT_PATH, local_doc_id, "vector_store") def get_file_path(local_doc_id: str, doc_name: str): - return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name) + return os.path.join(KB_ROOT_PATH, local_doc_id, "content", doc_name) async def upload_file( @@ -147,64 +147,76 @@ async def upload_files( return BaseResponse(code=500, msg=file_status) +async def list_kbs(): + # Get List of Knowledge Base + if not os.path.exists(KB_ROOT_PATH): + all_doc_ids = [] + else: + all_doc_ids = [ + folder + for folder in os.listdir(KB_ROOT_PATH) + if os.path.isdir(os.path.join(KB_ROOT_PATH, folder)) + and os.path.exists(os.path.join(KB_ROOT_PATH, folder, "vector_store", "index.faiss")) + ] + + return ListDocsResponse(data=all_doc_ids) + + async def list_docs( knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1") ): - if knowledge_base_id: - local_doc_folder = get_folder_path(knowledge_base_id) - if not os.path.exists(local_doc_folder): - return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} - all_doc_names = [ - doc - for doc in os.listdir(local_doc_folder) - if os.path.isfile(os.path.join(local_doc_folder, doc)) - ] - return ListDocsResponse(data=all_doc_names) - else: - if not os.path.exists(UPLOAD_ROOT_PATH): - all_doc_ids = [] - else: - all_doc_ids = [ - folder - for folder in os.listdir(UPLOAD_ROOT_PATH) - if os.path.isdir(os.path.join(UPLOAD_ROOT_PATH, folder)) - ] + local_doc_folder = get_folder_path(knowledge_base_id) + if not os.path.exists(local_doc_folder): + return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} + all_doc_names = [ + doc + for doc in os.listdir(local_doc_folder) + if os.path.isfile(os.path.join(local_doc_folder, doc)) + ] + return ListDocsResponse(data=all_doc_names) - return ListDocsResponse(data=all_doc_ids) + +async def delete_kbs( + knowledge_base_id: str = Query(..., + description="Knowledge Base Name", + example="kb1"), +): + # TODO: 确认是否支持批量删除知识库 + knowledge_base_id = urllib.parse.unquote(knowledge_base_id) + if not os.path.exists(get_folder_path(knowledge_base_id)): + return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} + shutil.rmtree(get_folder_path(knowledge_base_id)) + return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success") async def delete_docs( knowledge_base_id: str = Query(..., description="Knowledge Base Name", example="kb1"), - doc_name: Optional[str] = Query( + doc_name: str = Query( None, description="doc name", example="doc_name_1.pdf" ), ): + # TODO: 确认是否支持批量删除文件 knowledge_base_id = urllib.parse.unquote(knowledge_base_id) - if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, knowledge_base_id)): + if not os.path.exists(get_folder_path(knowledge_base_id)): return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} - if doc_name: - doc_path = get_file_path(knowledge_base_id, doc_name) - if os.path.exists(doc_path): - os.remove(doc_path) + doc_path = get_file_path(knowledge_base_id, doc_name) + if os.path.exists(doc_path): + os.remove(doc_path) - # 删除上传的文件后重新生成知识库(FAISS)内的数据 - remain_docs = await list_docs(knowledge_base_id) - if len(remain_docs.data) == 0: - shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True) - else: - local_doc_qa.init_knowledge_vector_store( - get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id) - ) - - return BaseResponse(code=200, msg=f"document {doc_name} delete success") + # 删除上传的文件后重新生成知识库(FAISS)内的数据 + # TODO: 删除向量库中对应文件 + remain_docs = await list_docs(knowledge_base_id) + if len(remain_docs.data) == 0: + shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True) else: - return BaseResponse(code=1, msg=f"document {doc_name} not found") - + local_doc_qa.init_knowledge_vector_store( + get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id) + ) + return BaseResponse(code=200, msg=f"document {doc_name} delete success") else: - shutil.rmtree(get_folder_path(knowledge_base_id)) - return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success") + return BaseResponse(code=1, msg=f"document {doc_name} not found") async def local_doc_chat( @@ -221,7 +233,7 @@ async def local_doc_chat( ], ), ): - vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id) + vs_path = get_vs_path(knowledge_base_id) if not os.path.exists(vs_path): # return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found") return ChatMessage( @@ -278,6 +290,7 @@ async def bing_search_chat( source_documents=source_documents, ) + async def chat( question: str = Body(..., description="Question", example="工伤保险是什么?"), history: List[List[str]] = Body( @@ -310,8 +323,9 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str): turn = 1 while True: input_json = await websocket.receive_json() - question, history, knowledge_base_id = input_json["question"], input_json["history"], input_json["knowledge_base_id"] - vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id) + question, history, knowledge_base_id = input_json["question"], input_json["history"], input_json[ + "knowledge_base_id"] + vs_path = get_vs_path(knowledge_base_id) if not os.path.exists(vs_path): await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"}) @@ -351,9 +365,6 @@ async def document(): return RedirectResponse(url="/docs") - - - def api_start(host, port): global app global local_doc_qa diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index dac13bb..1be9f22 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -187,8 +187,9 @@ class LocalDocQA: torch_gc() else: if not vs_path: - vs_path = os.path.join(VS_ROOT_PATH, - f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""") + vs_path = os.path.join(KB_ROOT_PATH, + f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""", + "vector_store") vector_store = MyFAISS.from_documents(docs, self.embeddings) # docs 为Document列表 torch_gc() diff --git a/configs/model_config.py b/configs/model_config.py index 0f25092..086bfe0 100644 --- a/configs/model_config.py +++ b/configs/model_config.py @@ -119,10 +119,8 @@ USE_PTUNING_V2 = False # LLM running device LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - -VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store") - -UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content") +# 知识库默认存储路径 +KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base") # 基于上下文的prompt模版,请务必保留"{question}"和"{context}" PROMPT_TEMPLATE = """已知信息: @@ -139,10 +137,10 @@ SENTENCE_SIZE = 100 # 匹配后单段上下文长度 CHUNK_SIZE = 250 -# LLM input history length +# 传入LLM的历史记录长度 LLM_HISTORY_LEN = 3 -# return top-k text chunk from vector store +# 知识库检索时返回的匹配内容条数 VECTOR_SEARCH_TOP_K = 5 # 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准 diff --git a/content/samples/README.md b/knowledge_base/samples/content/README.md similarity index 84% rename from content/samples/README.md rename to knowledge_base/samples/content/README.md index 9a7afb4..efabb79 100644 --- a/content/samples/README.md +++ b/knowledge_base/samples/content/README.md @@ -1,23 +1,26 @@ -# 基于本地知识的 ChatGLM 应用实现 +# 基于本地知识库的 ChatGLM 等大语言模型应用实现 ## 介绍 🌍 [_READ THIS IN ENGLISH_](README_en.md) -🤖️ 一种利用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) + [langchain](https://github.com/hwchase17/langchain) 实现的基于本地知识的 ChatGLM 应用。增加 [clue-ai/ChatYuan](https://github.com/clue-ai/ChatYuan) 项目的模型 [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) 的支持。 +🤖️ 一种利用 [langchain](https://github.com/hwchase17/langchain) 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。 -💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai) 和 [AlexZhangji](https://github.com/AlexZhangji) 创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) 启发,建立了全部基于开源模型实现的本地知识问答应用。 +💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai) 和 [AlexZhangji](https://github.com/AlexZhangji) 创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) 启发,建立了全流程可使用开源模型实现的本地知识库问答应用。现已支持使用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) 等大语言模型直接接入,或通过 [fastchat](https://github.com/lm-sys/FastChat) api 形式接入 Vicuna, Alpaca, LLaMA, Koala, RWKV 等模型。 ✅ 本项目中 Embedding 默认选用的是 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main),LLM 默认选用的是 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)。依托上述模型,本项目可实现全部使用**开源**模型**离线私有部署**。 ⛓️ 本项目实现原理如下图所示,过程包括加载文件 -> 读取文本 -> 文本分割 -> 文本向量化 -> 问句向量化 -> 在文本向量中匹配出与问句向量最相似的`top k`个 -> 匹配出的文本作为上下文和问题一起添加到`prompt`中 -> 提交给`LLM`生成回答。 +📺 [原理介绍视频](https://www.bilibili.com/video/BV13M4y1e7cN/?share_source=copy_web&vd_source=e6c5aafe684f30fbe41925d61ca6d514) + ![实现原理图](img/langchain+chatglm.png) 从文档处理角度来看,实现流程如下: ![实现原理图2](img/langchain+chatglm2.png) + 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。 🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM) @@ -33,7 +36,7 @@ - ChatGLM-6B 模型硬件需求 注:如未将模型下载至本地,请执行前检查`$HOME/.cache/huggingface/`文件夹剩余空间,模型文件下载至本地需要 15 GB 存储空间。 - + 注:一些其它的可选启动项见[项目启动选项](docs/StartOption.md) 模型下载方法可参考 [常见问题](docs/FAQ.md) 中 Q8。 | **量化等级** | **最低 GPU 显存**(推理) | **最低 GPU 显存**(高效参数微调) | @@ -79,9 +82,10 @@ docker run --gpus all -d --name chatglm -p 7860:7860 -v ~/github/langchain-ChatG ### 软件需求 -本项目已在 Python 3.8 - 3.10,CUDA 11.7 环境下完成测试。已在 Windows、ARM 架构的 macOS、Linux 系统中完成测试。 +本项目已在 Python 3.8.1 - 3.10,CUDA 11.7 环境下完成测试。已在 Windows、ARM 架构的 macOS、Linux 系统中完成测试。 vue前端需要node18环境 + ### 从本地加载模型 请参考 [THUDM/ChatGLM-6B#从本地加载模型](https://github.com/THUDM/ChatGLM-6B#从本地加载模型) @@ -94,6 +98,8 @@ vue前端需要node18环境 在开始执行 Web UI 或命令行交互前,请先检查 [configs/model_config.py](configs/model_config.py) 中的各项模型参数设计是否符合需求。 +如需通过 fastchat 以 api 形式调用 llm,请参考 [fastchat 调用实现](docs/fastchat.md) + ### 3. 执行脚本体验 Web UI 或命令行交互 > 注:鉴于环境部署过程中可能遇到问题,建议首先测试命令行脚本。建议命令行脚本测试可正常运行后再运行 Web UI。 @@ -122,9 +128,17 @@ $ pnpm i $ npm run dev ``` -执行后效果如下图所示: +VUE 前端界面如下图所示: +1. `对话` 界面 +![](img/vue_0521_0.png) +2. `知识库问答` 界面 +![](img/vue_0521_1.png) +3. `Bing搜索` 界面 +![](img/vue_0521_2.png) + +WebUI 界面如下图所示: 1. `对话` Tab 界面 -![](img/webui_0510_0.png) +![](img/webui_0521_0.png) 2. `知识库测试 Beta` Tab 界面 ![](img/webui_0510_1.png) 3. `模型配置` Tab 界面 @@ -177,36 +191,44 @@ Web UI 可以实现如下功能: - [ ] Langchain 应用 - [x] 接入非结构化文档(已支持 md、pdf、docx、txt 文件格式) - - [ ] 搜索引擎与本地网页接入 + - [x] jpg 与 png 格式图片的 OCR 文字识别 + - [x] 搜索引擎接入 + - [ ] 本地网页接入 - [ ] 结构化数据接入(如 csv、Excel、SQL 等) - [ ] 知识图谱/图数据库接入 - [ ] Agent 实现 -- [ ] 增加更多 LLM 模型支持 +- [x] 增加更多 LLM 模型支持 - [x] [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b) - [x] [THUDM/chatglm-6b-int8](https://huggingface.co/THUDM/chatglm-6b-int8) - [x] [THUDM/chatglm-6b-int4](https://huggingface.co/THUDM/chatglm-6b-int4) - [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe) - [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) - [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft) -- [ ] 增加更多 Embedding 模型支持 + - [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm +- [x] 增加更多 Embedding 模型支持 - [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh) - [x] [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh) - [x] [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) - [x] [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese) + - [x] [moka-ai/m3e-small](https://huggingface.co/moka-ai/m3e-small) + - [x] [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base) - [ ] Web UI - - [x] 利用 gradio 实现 Web UI DEMO + - [x] 基于 gradio 实现 Web UI DEMO + - [x] 基于 streamlit 实现 Web UI DEMO - [x] 添加输出内容及错误提示 - [x] 引用标注 - [ ] 增加知识库管理 - [x] 选择知识库开始问答 - [x] 上传文件/文件夹至知识库 + - [x] 知识库测试 - [ ] 删除知识库中文件 - - [ ] 利用 streamlit 实现 Web UI Demo + - [x] 支持搜索引擎问答 - [ ] 增加 API 支持 - [x] 利用 fastapi 实现 API 部署方式 - [ ] 实现调用 API 的 Web UI Demo +- [x] VUE 前端 ## 项目交流群 -![二维码](img/qr_code_17.jpg) +![二维码](img/qr_code_30.jpg) 🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 diff --git a/content/samples/test.jpg b/knowledge_base/samples/content/test.jpg similarity index 100% rename from content/samples/test.jpg rename to knowledge_base/samples/content/test.jpg diff --git a/content/samples/test.pdf b/knowledge_base/samples/content/test.pdf similarity index 100% rename from content/samples/test.pdf rename to knowledge_base/samples/content/test.pdf diff --git a/content/samples/test.txt b/knowledge_base/samples/content/test.txt similarity index 100% rename from content/samples/test.txt rename to knowledge_base/samples/content/test.txt diff --git a/loader/image_loader.py b/loader/image_loader.py index 48b9d57..8008861 100644 --- a/loader/image_loader.py +++ b/loader/image_loader.py @@ -33,7 +33,7 @@ class UnstructuredPaddleImageLoader(UnstructuredFileLoader): if __name__ == "__main__": - filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.jpg") + filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.jpg") loader = UnstructuredPaddleImageLoader(filepath, mode="elements") docs = loader.load() for doc in docs: diff --git a/loader/pdf_loader.py b/loader/pdf_loader.py index e43169d..df3412b 100644 --- a/loader/pdf_loader.py +++ b/loader/pdf_loader.py @@ -49,7 +49,7 @@ class UnstructuredPaddlePDFLoader(UnstructuredFileLoader): if __name__ == "__main__": - filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf") + filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.pdf") loader = UnstructuredPaddlePDFLoader(filepath, mode="elements") docs = loader.load() for doc in docs: diff --git a/webui.py b/webui.py index d082d45..949ecb4 100644 --- a/webui.py +++ b/webui.py @@ -16,9 +16,9 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path def get_vs_list(): lst_default = ["新建知识库"] - if not os.path.exists(VS_ROOT_PATH): + if not os.path.exists(KB_ROOT_PATH): return lst_default - lst = os.listdir(VS_ROOT_PATH) + lst = os.listdir(KB_ROOT_PATH) if not lst: return lst_default lst.sort() @@ -141,14 +141,14 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): - vs_path = os.path.join(VS_ROOT_PATH, vs_id) + vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") filelist = [] if local_doc_qa.llm and local_doc_qa.embeddings: if isinstance(files, list): for file in files: filename = os.path.split(file.name)[-1] - shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename)) - filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename)) + shutil.move(file.name, os.path.join(KB_ROOT_PATH, vs_id, "content", filename)) + filelist.append(os.path.join(KB_ROOT_PATH, vs_id, "content", filename)) vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path, sentence_size) else: vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, @@ -168,7 +168,7 @@ def change_vs_name_input(vs_id, history): if vs_id == "新建知识库": return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history else: - vs_path = os.path.join(VS_ROOT_PATH, vs_id) + vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") if "index.faiss" in os.listdir(vs_path): file_status = f"已加载知识库{vs_id},请开始提问" else: @@ -220,11 +220,11 @@ def add_vs_name(vs_name, chatbot): visible=False), chatbot else: # 新建上传文件存储路径 - if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_name)): - os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_name)) + if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_name, "content")): + os.makedirs(os.path.join(KB_ROOT_PATH, vs_name, "content")) # 新建向量库存储路径 - if not os.path.exists(os.path.join(VS_ROOT_PATH, vs_name)): - os.makedirs(os.path.join(VS_ROOT_PATH, vs_name)) + if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_name, "vector_store")): + os.makedirs(os.path.join(KB_ROOT_PATH, vs_name, "vector_store")) vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """ chatbot = chatbot + [[None, vs_status]] return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update( @@ -234,12 +234,13 @@ def add_vs_name(vs_name, chatbot): # 自动化加载固定文件间中文件 def reinit_vector_store(vs_id, history): try: - shutil.rmtree(VS_ROOT_PATH) - vs_path = os.path.join(VS_ROOT_PATH, vs_id) + shutil.rmtree(os.path.join(KB_ROOT_PATH, vs_id, "vector_store")) + vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0, label="文本入库分句长度限制", interactive=True, visible=True) - vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(UPLOAD_ROOT_PATH, vs_path, sentence_size) + vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(os.path.join(KB_ROOT_PATH, vs_id, "content"), + vs_path, sentence_size) model_status = """知识库构建成功""" except Exception as e: logger.error(e) @@ -285,7 +286,7 @@ default_theme_args = dict( with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo: vs_path, file_status, model_status = gr.State( - os.path.join(VS_ROOT_PATH, get_vs_list()[0]) if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State( + os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State( model_status) gr.Markdown(webui_title) with gr.Tab("对话"): diff --git a/webui_st.py b/webui_st.py index eb48c15..6d1265e 100644 --- a/webui_st.py +++ b/webui_st.py @@ -20,9 +20,9 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path def get_vs_list(): lst_default = ["新建知识库"] - if not os.path.exists(VS_ROOT_PATH): + if not os.path.exists(KB_ROOT_PATH): return lst_default - lst = os.listdir(VS_ROOT_PATH) + lst = os.listdir(KB_ROOT_PATH) if not lst: return lst_default lst.sort() @@ -144,18 +144,18 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec' def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): - vs_path = os.path.join(VS_ROOT_PATH, vs_id) + vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") filelist = [] - if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)): - os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id)) + if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")): + os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content")) if local_doc_qa.llm and local_doc_qa.embeddings: if isinstance(files, list): for file in files: filename = os.path.split(file.name)[-1] shutil.move(file.name, os.path.join( - UPLOAD_ROOT_PATH, vs_id, filename)) + KB_ROOT_PATH, vs_id, "content", filename)) filelist.append(os.path.join( - UPLOAD_ROOT_PATH, vs_id, filename)) + KB_ROOT_PATH, vs_id, "content", filename)) vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store( filelist, vs_path, sentence_size) else: @@ -516,7 +516,7 @@ with st.form('my_form', clear_on_submit=True): last_response = output_messages() for history, _ in answer(q, vs_path=os.path.join( - VS_ROOT_PATH, vs_path), + KB_ROOT_PATH, vs_path, "vector_store"), history=[], mode=mode, score_threshold=score_threshold,