add knowledge_base folder and move vector_store and content inside
This commit is contained in:
parent
585b4427ba
commit
a887df1715
111
api.py
111
api.py
|
|
@ -15,7 +15,7 @@ from typing_extensions import Annotated
|
|||
from starlette.responses import RedirectResponse
|
||||
|
||||
from chains.local_doc_qa import LocalDocQA
|
||||
from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
|
||||
from configs.model_config import (KB_ROOT_PATH, EMBEDDING_DEVICE,
|
||||
EMBEDDING_MODEL, NLTK_DATA_PATH,
|
||||
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
|
||||
import models.shared as shared
|
||||
|
|
@ -80,15 +80,15 @@ class ChatMessage(BaseModel):
|
|||
|
||||
|
||||
def get_folder_path(local_doc_id: str):
|
||||
return os.path.join(UPLOAD_ROOT_PATH, local_doc_id)
|
||||
return os.path.join(KB_ROOT_PATH, local_doc_id, "content")
|
||||
|
||||
|
||||
def get_vs_path(local_doc_id: str):
|
||||
return os.path.join(VS_ROOT_PATH, local_doc_id)
|
||||
return os.path.join(KB_ROOT_PATH, local_doc_id, "vector_store")
|
||||
|
||||
|
||||
def get_file_path(local_doc_id: str, doc_name: str):
|
||||
return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
|
||||
return os.path.join(KB_ROOT_PATH, local_doc_id, "content", doc_name)
|
||||
|
||||
|
||||
async def upload_file(
|
||||
|
|
@ -147,64 +147,76 @@ async def upload_files(
|
|||
return BaseResponse(code=500, msg=file_status)
|
||||
|
||||
|
||||
async def list_kbs():
|
||||
# Get List of Knowledge Base
|
||||
if not os.path.exists(KB_ROOT_PATH):
|
||||
all_doc_ids = []
|
||||
else:
|
||||
all_doc_ids = [
|
||||
folder
|
||||
for folder in os.listdir(KB_ROOT_PATH)
|
||||
if os.path.isdir(os.path.join(KB_ROOT_PATH, folder))
|
||||
and os.path.exists(os.path.join(KB_ROOT_PATH, folder, "vector_store", "index.faiss"))
|
||||
]
|
||||
|
||||
return ListDocsResponse(data=all_doc_ids)
|
||||
|
||||
|
||||
async def list_docs(
|
||||
knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1")
|
||||
):
|
||||
if knowledge_base_id:
|
||||
local_doc_folder = get_folder_path(knowledge_base_id)
|
||||
if not os.path.exists(local_doc_folder):
|
||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
||||
all_doc_names = [
|
||||
doc
|
||||
for doc in os.listdir(local_doc_folder)
|
||||
if os.path.isfile(os.path.join(local_doc_folder, doc))
|
||||
]
|
||||
return ListDocsResponse(data=all_doc_names)
|
||||
else:
|
||||
if not os.path.exists(UPLOAD_ROOT_PATH):
|
||||
all_doc_ids = []
|
||||
else:
|
||||
all_doc_ids = [
|
||||
folder
|
||||
for folder in os.listdir(UPLOAD_ROOT_PATH)
|
||||
if os.path.isdir(os.path.join(UPLOAD_ROOT_PATH, folder))
|
||||
]
|
||||
local_doc_folder = get_folder_path(knowledge_base_id)
|
||||
if not os.path.exists(local_doc_folder):
|
||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
||||
all_doc_names = [
|
||||
doc
|
||||
for doc in os.listdir(local_doc_folder)
|
||||
if os.path.isfile(os.path.join(local_doc_folder, doc))
|
||||
]
|
||||
return ListDocsResponse(data=all_doc_names)
|
||||
|
||||
return ListDocsResponse(data=all_doc_ids)
|
||||
|
||||
async def delete_kbs(
|
||||
knowledge_base_id: str = Query(...,
|
||||
description="Knowledge Base Name",
|
||||
example="kb1"),
|
||||
):
|
||||
# TODO: 确认是否支持批量删除知识库
|
||||
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
||||
shutil.rmtree(get_folder_path(knowledge_base_id))
|
||||
return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
|
||||
|
||||
|
||||
async def delete_docs(
|
||||
knowledge_base_id: str = Query(...,
|
||||
description="Knowledge Base Name",
|
||||
example="kb1"),
|
||||
doc_name: Optional[str] = Query(
|
||||
doc_name: str = Query(
|
||||
None, description="doc name", example="doc_name_1.pdf"
|
||||
),
|
||||
):
|
||||
# TODO: 确认是否支持批量删除文件
|
||||
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||
if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, knowledge_base_id)):
|
||||
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
||||
if doc_name:
|
||||
doc_path = get_file_path(knowledge_base_id, doc_name)
|
||||
if os.path.exists(doc_path):
|
||||
os.remove(doc_path)
|
||||
doc_path = get_file_path(knowledge_base_id, doc_name)
|
||||
if os.path.exists(doc_path):
|
||||
os.remove(doc_path)
|
||||
|
||||
# 删除上传的文件后重新生成知识库(FAISS)内的数据
|
||||
remain_docs = await list_docs(knowledge_base_id)
|
||||
if len(remain_docs.data) == 0:
|
||||
shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
|
||||
else:
|
||||
local_doc_qa.init_knowledge_vector_store(
|
||||
get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id)
|
||||
)
|
||||
|
||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||
# 删除上传的文件后重新生成知识库(FAISS)内的数据
|
||||
# TODO: 删除向量库中对应文件
|
||||
remain_docs = await list_docs(knowledge_base_id)
|
||||
if len(remain_docs.data) == 0:
|
||||
shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
|
||||
else:
|
||||
return BaseResponse(code=1, msg=f"document {doc_name} not found")
|
||||
|
||||
local_doc_qa.init_knowledge_vector_store(
|
||||
get_folder_path(knowledge_base_id), get_vs_path(knowledge_base_id)
|
||||
)
|
||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||
else:
|
||||
shutil.rmtree(get_folder_path(knowledge_base_id))
|
||||
return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
|
||||
return BaseResponse(code=1, msg=f"document {doc_name} not found")
|
||||
|
||||
|
||||
async def local_doc_chat(
|
||||
|
|
@ -221,7 +233,7 @@ async def local_doc_chat(
|
|||
],
|
||||
),
|
||||
):
|
||||
vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
|
||||
vs_path = get_vs_path(knowledge_base_id)
|
||||
if not os.path.exists(vs_path):
|
||||
# return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found")
|
||||
return ChatMessage(
|
||||
|
|
@ -278,6 +290,7 @@ async def bing_search_chat(
|
|||
source_documents=source_documents,
|
||||
)
|
||||
|
||||
|
||||
async def chat(
|
||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||
history: List[List[str]] = Body(
|
||||
|
|
@ -310,8 +323,9 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
|||
turn = 1
|
||||
while True:
|
||||
input_json = await websocket.receive_json()
|
||||
question, history, knowledge_base_id = input_json["question"], input_json["history"], input_json["knowledge_base_id"]
|
||||
vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
|
||||
question, history, knowledge_base_id = input_json["question"], input_json["history"], input_json[
|
||||
"knowledge_base_id"]
|
||||
vs_path = get_vs_path(knowledge_base_id)
|
||||
|
||||
if not os.path.exists(vs_path):
|
||||
await websocket.send_json({"error": f"Knowledge base {knowledge_base_id} not found"})
|
||||
|
|
@ -351,9 +365,6 @@ async def document():
|
|||
return RedirectResponse(url="/docs")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def api_start(host, port):
|
||||
global app
|
||||
global local_doc_qa
|
||||
|
|
|
|||
|
|
@ -187,8 +187,9 @@ class LocalDocQA:
|
|||
torch_gc()
|
||||
else:
|
||||
if not vs_path:
|
||||
vs_path = os.path.join(VS_ROOT_PATH,
|
||||
f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
|
||||
vs_path = os.path.join(KB_ROOT_PATH,
|
||||
f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""",
|
||||
"vector_store")
|
||||
vector_store = MyFAISS.from_documents(docs, self.embeddings) # docs 为Document列表
|
||||
torch_gc()
|
||||
|
||||
|
|
|
|||
|
|
@ -119,10 +119,8 @@ USE_PTUNING_V2 = False
|
|||
# LLM running device
|
||||
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
|
||||
VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store")
|
||||
|
||||
UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content")
|
||||
# 知识库默认存储路径
|
||||
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
|
||||
|
||||
# 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
|
||||
PROMPT_TEMPLATE = """已知信息:
|
||||
|
|
@ -139,10 +137,10 @@ SENTENCE_SIZE = 100
|
|||
# 匹配后单段上下文长度
|
||||
CHUNK_SIZE = 250
|
||||
|
||||
# LLM input history length
|
||||
# 传入LLM的历史记录长度
|
||||
LLM_HISTORY_LEN = 3
|
||||
|
||||
# return top-k text chunk from vector store
|
||||
# 知识库检索时返回的匹配内容条数
|
||||
VECTOR_SEARCH_TOP_K = 5
|
||||
|
||||
# 知识检索内容相关度 Score, 数值范围约为0-1100,如果为0,则不生效,经测试设置为小于500时,匹配结果更精准
|
||||
|
|
|
|||
|
|
@ -1,23 +1,26 @@
|
|||
# 基于本地知识的 ChatGLM 应用实现
|
||||
# 基于本地知识库的 ChatGLM 等大语言模型应用实现
|
||||
|
||||
## 介绍
|
||||
|
||||
🌍 [_READ THIS IN ENGLISH_](README_en.md)
|
||||
|
||||
🤖️ 一种利用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) + [langchain](https://github.com/hwchase17/langchain) 实现的基于本地知识的 ChatGLM 应用。增加 [clue-ai/ChatYuan](https://github.com/clue-ai/ChatYuan) 项目的模型 [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) 的支持。
|
||||
🤖️ 一种利用 [langchain](https://github.com/hwchase17/langchain) 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。
|
||||
|
||||
💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai) 和 [AlexZhangji](https://github.com/AlexZhangji) 创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) 启发,建立了全部基于开源模型实现的本地知识问答应用。
|
||||
💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai) 和 [AlexZhangji](https://github.com/AlexZhangji) 创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) 启发,建立了全流程可使用开源模型实现的本地知识库问答应用。现已支持使用 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) 等大语言模型直接接入,或通过 [fastchat](https://github.com/lm-sys/FastChat) api 形式接入 Vicuna, Alpaca, LLaMA, Koala, RWKV 等模型。
|
||||
|
||||
✅ 本项目中 Embedding 默认选用的是 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main),LLM 默认选用的是 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)。依托上述模型,本项目可实现全部使用**开源**模型**离线私有部署**。
|
||||
|
||||
⛓️ 本项目实现原理如下图所示,过程包括加载文件 -> 读取文本 -> 文本分割 -> 文本向量化 -> 问句向量化 -> 在文本向量中匹配出与问句向量最相似的`top k`个 -> 匹配出的文本作为上下文和问题一起添加到`prompt`中 -> 提交给`LLM`生成回答。
|
||||
|
||||
📺 [原理介绍视频](https://www.bilibili.com/video/BV13M4y1e7cN/?share_source=copy_web&vd_source=e6c5aafe684f30fbe41925d61ca6d514)
|
||||
|
||||

|
||||
|
||||
从文档处理角度来看,实现流程如下:
|
||||
|
||||

|
||||
|
||||
|
||||
🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
|
||||
|
||||
🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
|
||||
|
|
@ -33,7 +36,7 @@
|
|||
- ChatGLM-6B 模型硬件需求
|
||||
|
||||
注:如未将模型下载至本地,请执行前检查`$HOME/.cache/huggingface/`文件夹剩余空间,模型文件下载至本地需要 15 GB 存储空间。
|
||||
|
||||
注:一些其它的可选启动项见[项目启动选项](docs/StartOption.md)
|
||||
模型下载方法可参考 [常见问题](docs/FAQ.md) 中 Q8。
|
||||
|
||||
| **量化等级** | **最低 GPU 显存**(推理) | **最低 GPU 显存**(高效参数微调) |
|
||||
|
|
@ -79,9 +82,10 @@ docker run --gpus all -d --name chatglm -p 7860:7860 -v ~/github/langchain-ChatG
|
|||
|
||||
### 软件需求
|
||||
|
||||
本项目已在 Python 3.8 - 3.10,CUDA 11.7 环境下完成测试。已在 Windows、ARM 架构的 macOS、Linux 系统中完成测试。
|
||||
本项目已在 Python 3.8.1 - 3.10,CUDA 11.7 环境下完成测试。已在 Windows、ARM 架构的 macOS、Linux 系统中完成测试。
|
||||
|
||||
vue前端需要node18环境
|
||||
|
||||
### 从本地加载模型
|
||||
|
||||
请参考 [THUDM/ChatGLM-6B#从本地加载模型](https://github.com/THUDM/ChatGLM-6B#从本地加载模型)
|
||||
|
|
@ -94,6 +98,8 @@ vue前端需要node18环境
|
|||
|
||||
在开始执行 Web UI 或命令行交互前,请先检查 [configs/model_config.py](configs/model_config.py) 中的各项模型参数设计是否符合需求。
|
||||
|
||||
如需通过 fastchat 以 api 形式调用 llm,请参考 [fastchat 调用实现](docs/fastchat.md)
|
||||
|
||||
### 3. 执行脚本体验 Web UI 或命令行交互
|
||||
|
||||
> 注:鉴于环境部署过程中可能遇到问题,建议首先测试命令行脚本。建议命令行脚本测试可正常运行后再运行 Web UI。
|
||||
|
|
@ -122,9 +128,17 @@ $ pnpm i
|
|||
$ npm run dev
|
||||
```
|
||||
|
||||
执行后效果如下图所示:
|
||||
VUE 前端界面如下图所示:
|
||||
1. `对话` 界面
|
||||

|
||||
2. `知识库问答` 界面
|
||||

|
||||
3. `Bing搜索` 界面
|
||||

|
||||
|
||||
WebUI 界面如下图所示:
|
||||
1. `对话` Tab 界面
|
||||

|
||||

|
||||
2. `知识库测试 Beta` Tab 界面
|
||||

|
||||
3. `模型配置` Tab 界面
|
||||
|
|
@ -177,36 +191,44 @@ Web UI 可以实现如下功能:
|
|||
|
||||
- [ ] Langchain 应用
|
||||
- [x] 接入非结构化文档(已支持 md、pdf、docx、txt 文件格式)
|
||||
- [ ] 搜索引擎与本地网页接入
|
||||
- [x] jpg 与 png 格式图片的 OCR 文字识别
|
||||
- [x] 搜索引擎接入
|
||||
- [ ] 本地网页接入
|
||||
- [ ] 结构化数据接入(如 csv、Excel、SQL 等)
|
||||
- [ ] 知识图谱/图数据库接入
|
||||
- [ ] Agent 实现
|
||||
- [ ] 增加更多 LLM 模型支持
|
||||
- [x] 增加更多 LLM 模型支持
|
||||
- [x] [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
|
||||
- [x] [THUDM/chatglm-6b-int8](https://huggingface.co/THUDM/chatglm-6b-int8)
|
||||
- [x] [THUDM/chatglm-6b-int4](https://huggingface.co/THUDM/chatglm-6b-int4)
|
||||
- [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe)
|
||||
- [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2)
|
||||
- [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft)
|
||||
- [ ] 增加更多 Embedding 模型支持
|
||||
- [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm
|
||||
- [x] 增加更多 Embedding 模型支持
|
||||
- [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
|
||||
- [x] [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
|
||||
- [x] [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese)
|
||||
- [x] [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
|
||||
- [x] [moka-ai/m3e-small](https://huggingface.co/moka-ai/m3e-small)
|
||||
- [x] [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base)
|
||||
- [ ] Web UI
|
||||
- [x] 利用 gradio 实现 Web UI DEMO
|
||||
- [x] 基于 gradio 实现 Web UI DEMO
|
||||
- [x] 基于 streamlit 实现 Web UI DEMO
|
||||
- [x] 添加输出内容及错误提示
|
||||
- [x] 引用标注
|
||||
- [ ] 增加知识库管理
|
||||
- [x] 选择知识库开始问答
|
||||
- [x] 上传文件/文件夹至知识库
|
||||
- [x] 知识库测试
|
||||
- [ ] 删除知识库中文件
|
||||
- [ ] 利用 streamlit 实现 Web UI Demo
|
||||
- [x] 支持搜索引擎问答
|
||||
- [ ] 增加 API 支持
|
||||
- [x] 利用 fastapi 实现 API 部署方式
|
||||
- [ ] 实现调用 API 的 Web UI Demo
|
||||
- [x] VUE 前端
|
||||
|
||||
## 项目交流群
|
||||

|
||||

|
||||
|
||||
🎉 langchain-ChatGLM 项目交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
||||
|
Before Width: | Height: | Size: 7.9 KiB After Width: | Height: | Size: 7.9 KiB |
|
|
@ -33,7 +33,7 @@ class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.jpg")
|
||||
filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.jpg")
|
||||
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
|
||||
docs = loader.load()
|
||||
for doc in docs:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf")
|
||||
filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.pdf")
|
||||
loader = UnstructuredPaddlePDFLoader(filepath, mode="elements")
|
||||
docs = loader.load()
|
||||
for doc in docs:
|
||||
|
|
|
|||
29
webui.py
29
webui.py
|
|
@ -16,9 +16,9 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
|||
|
||||
def get_vs_list():
|
||||
lst_default = ["新建知识库"]
|
||||
if not os.path.exists(VS_ROOT_PATH):
|
||||
if not os.path.exists(KB_ROOT_PATH):
|
||||
return lst_default
|
||||
lst = os.listdir(VS_ROOT_PATH)
|
||||
lst = os.listdir(KB_ROOT_PATH)
|
||||
if not lst:
|
||||
return lst_default
|
||||
lst.sort()
|
||||
|
|
@ -141,14 +141,14 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
|
|||
|
||||
|
||||
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
||||
vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
||||
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
||||
filelist = []
|
||||
if local_doc_qa.llm and local_doc_qa.embeddings:
|
||||
if isinstance(files, list):
|
||||
for file in files:
|
||||
filename = os.path.split(file.name)[-1]
|
||||
shutil.move(file.name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
|
||||
filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
|
||||
shutil.move(file.name, os.path.join(KB_ROOT_PATH, vs_id, "content", filename))
|
||||
filelist.append(os.path.join(KB_ROOT_PATH, vs_id, "content", filename))
|
||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path, sentence_size)
|
||||
else:
|
||||
vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
|
||||
|
|
@ -168,7 +168,7 @@ def change_vs_name_input(vs_id, history):
|
|||
if vs_id == "新建知识库":
|
||||
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
|
||||
else:
|
||||
vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
||||
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
||||
if "index.faiss" in os.listdir(vs_path):
|
||||
file_status = f"已加载知识库{vs_id},请开始提问"
|
||||
else:
|
||||
|
|
@ -220,11 +220,11 @@ def add_vs_name(vs_name, chatbot):
|
|||
visible=False), chatbot
|
||||
else:
|
||||
# 新建上传文件存储路径
|
||||
if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_name)):
|
||||
os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_name))
|
||||
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_name, "content")):
|
||||
os.makedirs(os.path.join(KB_ROOT_PATH, vs_name, "content"))
|
||||
# 新建向量库存储路径
|
||||
if not os.path.exists(os.path.join(VS_ROOT_PATH, vs_name)):
|
||||
os.makedirs(os.path.join(VS_ROOT_PATH, vs_name))
|
||||
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_name, "vector_store")):
|
||||
os.makedirs(os.path.join(KB_ROOT_PATH, vs_name, "vector_store"))
|
||||
vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
|
||||
chatbot = chatbot + [[None, vs_status]]
|
||||
return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update(
|
||||
|
|
@ -234,12 +234,13 @@ def add_vs_name(vs_name, chatbot):
|
|||
# 自动化加载固定文件间中文件
|
||||
def reinit_vector_store(vs_id, history):
|
||||
try:
|
||||
shutil.rmtree(VS_ROOT_PATH)
|
||||
vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
||||
shutil.rmtree(os.path.join(KB_ROOT_PATH, vs_id, "vector_store"))
|
||||
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
||||
sentence_size = gr.Number(value=SENTENCE_SIZE, precision=0,
|
||||
label="文本入库分句长度限制",
|
||||
interactive=True, visible=True)
|
||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(UPLOAD_ROOT_PATH, vs_path, sentence_size)
|
||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(os.path.join(KB_ROOT_PATH, vs_id, "content"),
|
||||
vs_path, sentence_size)
|
||||
model_status = """知识库构建成功"""
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
|
@ -285,7 +286,7 @@ default_theme_args = dict(
|
|||
|
||||
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
|
||||
vs_path, file_status, model_status = gr.State(
|
||||
os.path.join(VS_ROOT_PATH, get_vs_list()[0]) if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
|
||||
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
|
||||
model_status)
|
||||
gr.Markdown(webui_title)
|
||||
with gr.Tab("对话"):
|
||||
|
|
|
|||
16
webui_st.py
16
webui_st.py
|
|
@ -20,9 +20,9 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
|||
|
||||
def get_vs_list():
|
||||
lst_default = ["新建知识库"]
|
||||
if not os.path.exists(VS_ROOT_PATH):
|
||||
if not os.path.exists(KB_ROOT_PATH):
|
||||
return lst_default
|
||||
lst = os.listdir(VS_ROOT_PATH)
|
||||
lst = os.listdir(KB_ROOT_PATH)
|
||||
if not lst:
|
||||
return lst_default
|
||||
lst.sort()
|
||||
|
|
@ -144,18 +144,18 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
|
|||
|
||||
|
||||
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
||||
vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
||||
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
||||
filelist = []
|
||||
if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
|
||||
os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
|
||||
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
|
||||
os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
|
||||
if local_doc_qa.llm and local_doc_qa.embeddings:
|
||||
if isinstance(files, list):
|
||||
for file in files:
|
||||
filename = os.path.split(file.name)[-1]
|
||||
shutil.move(file.name, os.path.join(
|
||||
UPLOAD_ROOT_PATH, vs_id, filename))
|
||||
KB_ROOT_PATH, vs_id, "content", filename))
|
||||
filelist.append(os.path.join(
|
||||
UPLOAD_ROOT_PATH, vs_id, filename))
|
||||
KB_ROOT_PATH, vs_id, "content", filename))
|
||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(
|
||||
filelist, vs_path, sentence_size)
|
||||
else:
|
||||
|
|
@ -516,7 +516,7 @@ with st.form('my_form', clear_on_submit=True):
|
|||
last_response = output_messages()
|
||||
for history, _ in answer(q,
|
||||
vs_path=os.path.join(
|
||||
VS_ROOT_PATH, vs_path),
|
||||
KB_ROOT_PATH, vs_path, "vector_store"),
|
||||
history=[],
|
||||
mode=mode,
|
||||
score_threshold=score_threshold,
|
||||
|
|
|
|||
Loading…
Reference in New Issue