update api/kb_doc_api and knowledge base management ui:

1. add update_doc to api, which can update the vector store from an existing
   content file
2. add parameter `delete_content` to delete_doc api. user can decide
   whether to delete the local content file when deleting a doc.
3. fix bug in ApiRequest.upload_doc
4. support listing docs that exist in the local folder but not in the db
This commit is contained in:
liunux4odoo 2023-08-09 16:52:04 +08:00
parent 25280e0cea
commit c7b91bfaf1
7 changed files with 109 additions and 30 deletions

View File

@ -94,11 +94,11 @@ def create_app():
summary="删除知识库内的文件" summary="删除知识库内的文件"
)(delete_doc) )(delete_doc)
# app.post("/knowledge_base/update_doc", app.post("/knowledge_base/update_doc",
# tags=["Knowledge Base Management"], tags=["Knowledge Base Management"],
# response_model=BaseResponse, response_model=BaseResponse,
# summary="上传文件到知识库,并删除另一个文件" summary="更新现有文件到知识库"
# )(update_doc) )(update_doc)
app.post("/knowledge_base/recreate_vector_store", app.post("/knowledge_base/recreate_vector_store",
tags=["Knowledge Base Management"], tags=["Knowledge Base Management"],

View File

@ -1,8 +1,8 @@
import os import os
import urllib import urllib
from fastapi import File, Form, UploadFile from fastapi import File, Form, Body, UploadFile
from server.utils import BaseResponse, ListResponse from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import (validate_kb_name) from server.knowledge_base.utils import (get_file_path, validate_kb_name)
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import json import json
from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder
@ -58,8 +58,9 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
async def delete_doc(knowledge_base_name: str, async def delete_doc(knowledge_base_name: str = Body(...),
doc_name: str, doc_name: str = Body(...),
delete_content: bool = Body(...),
): ):
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -73,14 +74,33 @@ async def delete_doc(knowledge_base_name: str,
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
kb_file = KnowledgeFile(filename=doc_name, kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name) knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file) kb.delete_doc(kb_file, delete_content)
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功") return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
# return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败") # return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
async def update_doc(): async def update_doc(
# TODO: 替换文件 knowledge_base_name: str = Body(...),
pass file_name: str = Body(...),
):
'''
更新知识库文档
'''
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
else:
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
async def download_doc(): async def download_doc():
@ -89,9 +109,9 @@ async def download_doc():
async def recreate_vector_store( async def recreate_vector_store(
knowledge_base_name: str, knowledge_base_name: str = Body(...),
allow_empty_kb: bool = True, allow_empty_kb: bool = Body(True),
vs_type: str = "faiss", vs_type: str = Body("faiss"),
): ):
''' '''
recreate vector store from the content. recreate vector store from the content.

View File

@ -65,21 +65,32 @@ class KBService(ABC):
向知识库添加文件 向知识库添加文件
""" """
docs = kb_file.file2text() docs = kb_file.file2text()
if docs:
embeddings = self._load_embeddings() embeddings = self._load_embeddings()
self.do_add_doc(docs, embeddings) self.do_add_doc(docs, embeddings)
status = add_doc_to_db(kb_file) status = add_doc_to_db(kb_file)
else:
status = False
return status return status
def delete_doc(self, kb_file: KnowledgeFile): def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False):
""" """
从知识库删除文件 从知识库删除文件
""" """
if os.path.exists(kb_file.filepath): if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath) os.remove(kb_file.filepath)
self.do_delete_doc(kb_file) self.do_delete_doc(kb_file)
status = delete_file_from_db(kb_file) status = delete_file_from_db(kb_file)
return status return status
def update_doc(self, kb_file: KnowledgeFile):
"""
使用content中的文件更新向量库
"""
if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file)
return self.add_doc(kb_file)
def exist_doc(self, file_name: str): def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
filename=file_name)) filename=file_name))

View File

@ -136,3 +136,13 @@ class FaissKBService(KBService):
def do_clear_vs(self): def do_clear_vs(self):
shutil.rmtree(self.vs_path) shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path) os.makedirs(self.vs_path)
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):
return "in_db"
content_path = os.path.join(self.kb_path, "content")
if os.path.isfile(os.path.join(content_path, file_name)):
return "in_folder"
else:
return False

View File

@ -3,6 +3,7 @@ import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config) from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
from functools import lru_cache from functools import lru_cache
import langchain.document_loaders
import sys import sys

View File

@ -88,7 +88,7 @@ def config_aggrid(
gb = GridOptionsBuilder.from_dataframe(df) gb = GridOptionsBuilder.from_dataframe(df)
gb.configure_column("No", width=50) gb.configure_column("No", width=50)
for k, v in titles.items(): for k, v in titles.items():
gb.configure_column(k, v, maxWidth=100) gb.configure_column(k, v, maxWidth=100, wrapHeaderText=True)
gb.configure_selection(selection_mode, use_checkbox, pre_selected_rows=[0]) gb.configure_selection(selection_mode, use_checkbox, pre_selected_rows=[0])
return gb return gb
@ -149,7 +149,6 @@ def knowledge_base_page(api: ApiRequest):
files = st.file_uploader("上传知识文件", files = st.file_uploader("上传知识文件",
["docx", "txt", "md", "csv", "xlsx", "pdf"], ["docx", "txt", "md", "csv", "xlsx", "pdf"],
accept_multiple_files=True, accept_multiple_files=True,
key="files",
) )
if st.button( if st.button(
"添加文件到知识库", "添加文件到知识库",
@ -199,7 +198,7 @@ def knowledge_base_page(api: ApiRequest):
cols = st.columns(3) cols = st.columns(3)
selected_rows = doc_grid.get("selected_rows", []) selected_rows = doc_grid.get("selected_rows", [])
cols = st.columns([2, 3, 2]) cols = st.columns(4)
if selected_rows: if selected_rows:
file_name = selected_rows[0]["file_name"] file_name = selected_rows[0]["file_name"]
file_path = get_file_path(kb, file_name) file_path = get_file_path(kb, file_name)
@ -207,9 +206,20 @@ def knowledge_base_page(api: ApiRequest):
cols[0].download_button("下载选中文档", fp, file_name=file_name) cols[0].download_button("下载选中文档", fp, file_name=file_name)
else: else:
cols[0].download_button("下载选中文档", "", disabled=True) cols[0].download_button("下载选中文档", "", disabled=True)
if cols[2].button("删除选中文档!", type="primary"):
if cols[1].button("入库", disabled=len(selected_rows)==0):
for row in selected_rows: for row in selected_rows:
ret = api.delete_kb_doc(kb, row["file_name"]) api.update_kb_doc(kb, row["file_name"])
st.experimental_rerun()
if cols[2].button("出库", disabled=len(selected_rows)==0):
for row in selected_rows:
api.delete_kb_doc(kb, row["file_name"])
st.experimental_rerun()
if cols[3].button("删除选中文档!", type="primary"):
for row in selected_rows:
ret = api.delete_kb_doc(kb, row["file_name"], True)
st.toast(ret["msg"]) st.toast(ret["msg"])
st.experimental_rerun() st.experimental_rerun()

View File

@ -397,9 +397,11 @@ class ApiRequest:
if no_remote_api is None: if no_remote_api is None:
no_remote_api = self.no_remote_api no_remote_api = self.no_remote_api
if isinstance(file, bytes): if isinstance(file, bytes): # raw bytes
file = BytesIO(file) file = BytesIO(file)
else: elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
file = Path(file).absolute().open("rb") file = Path(file).absolute().open("rb")
filename = filename or file.name filename = filename or file.name
@ -410,6 +412,7 @@ class ApiRequest:
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
temp_file.write(file.read()) temp_file.write(file.read())
temp_file.seek(0)
response = run_async(upload_doc( response = run_async(upload_doc(
UploadFile(temp_file, filename=filename), UploadFile(temp_file, filename=filename),
knowledge_base_name, knowledge_base_name,
@ -428,6 +431,7 @@ class ApiRequest:
self, self,
knowledge_base_name: str, knowledge_base_name: str,
doc_name: str, doc_name: str,
delete_content: bool = False,
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
''' '''
@ -438,11 +442,34 @@ class ApiRequest:
if no_remote_api: if no_remote_api:
from server.knowledge_base.kb_doc_api import delete_doc from server.knowledge_base.kb_doc_api import delete_doc
response = run_async(delete_doc(knowledge_base_name, doc_name)) response = run_async(delete_doc(knowledge_base_name, doc_name, delete_content))
return response.dict() return response.dict()
else: else:
response = self.delete( response = self.delete(
"/knowledge_base/delete_doc", "/knowledge_base/delete_doc",
json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name, "delete_content": delete_content},
)
return response.json()
def update_kb_doc(
self,
knowledge_base_name: str,
doc_name: str,
no_remote_api: bool = None,
):
'''
对应api.py/knowledge_base/update_doc接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
from server.knowledge_base.kb_doc_api import update_doc
response = run_async(update_doc(knowledge_base_name, doc_name))
return response.dict()
else:
response = self.delete(
"/knowledge_base/update_doc",
json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name}, json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name},
) )
return response.json() return response.json()