update api/kb_doc_api and knowledge base management ui:
1. add update_doc to api which can udpate vector store from existed content file 2. add parameter `delete_content` to delete_doc api. user can decide whether delete local content file when delete doc. 3. fix bug in ApiReqeust.upload_doc 4. support listing docs existed in local folder bu not in db
This commit is contained in:
parent
25280e0cea
commit
c7b91bfaf1
|
|
@ -94,11 +94,11 @@ def create_app():
|
|||
summary="删除知识库内的文件"
|
||||
)(delete_doc)
|
||||
|
||||
# app.post("/knowledge_base/update_doc",
|
||||
# tags=["Knowledge Base Management"],
|
||||
# response_model=BaseResponse,
|
||||
# summary="上传文件到知识库,并删除另一个文件"
|
||||
# )(update_doc)
|
||||
app.post("/knowledge_base/update_doc",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="更新现有文件到知识库"
|
||||
)(update_doc)
|
||||
|
||||
app.post("/knowledge_base/recreate_vector_store",
|
||||
tags=["Knowledge Base Management"],
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import os
|
||||
import urllib
|
||||
from fastapi import File, Form, UploadFile
|
||||
from fastapi import File, Form, Body, UploadFile
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
from server.knowledge_base.utils import (validate_kb_name)
|
||||
from server.knowledge_base.utils import (get_file_path, validate_kb_name)
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder
|
||||
|
|
@ -58,8 +58,9 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
|||
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
||||
|
||||
|
||||
async def delete_doc(knowledge_base_name: str,
|
||||
doc_name: str,
|
||||
async def delete_doc(knowledge_base_name: str = Body(...),
|
||||
doc_name: str = Body(...),
|
||||
delete_content: bool = Body(...),
|
||||
):
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
|
@ -73,14 +74,33 @@ async def delete_doc(knowledge_base_name: str,
|
|||
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
||||
kb_file = KnowledgeFile(filename=doc_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file)
|
||||
kb.delete_doc(kb_file, delete_content)
|
||||
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
|
||||
# return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败")
|
||||
|
||||
|
||||
async def update_doc():
|
||||
# TODO: 替换文件
|
||||
pass
|
||||
async def update_doc(
|
||||
knowledge_base_name: str = Body(...),
|
||||
file_name: str = Body(...),
|
||||
):
|
||||
'''
|
||||
更新知识库文档
|
||||
'''
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
||||
if os.path.exists(kb_file.filepath):
|
||||
kb.update_doc(kb_file)
|
||||
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
||||
else:
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
|
||||
|
||||
|
||||
async def download_doc():
|
||||
|
|
@ -89,9 +109,9 @@ async def download_doc():
|
|||
|
||||
|
||||
async def recreate_vector_store(
|
||||
knowledge_base_name: str,
|
||||
allow_empty_kb: bool = True,
|
||||
vs_type: str = "faiss",
|
||||
knowledge_base_name: str = Body(...),
|
||||
allow_empty_kb: bool = Body(True),
|
||||
vs_type: str = Body("faiss"),
|
||||
):
|
||||
'''
|
||||
recreate vector store from the content.
|
||||
|
|
|
|||
|
|
@ -65,21 +65,32 @@ class KBService(ABC):
|
|||
向知识库添加文件
|
||||
"""
|
||||
docs = kb_file.file2text()
|
||||
if docs:
|
||||
embeddings = self._load_embeddings()
|
||||
self.do_add_doc(docs, embeddings)
|
||||
status = add_doc_to_db(kb_file)
|
||||
else:
|
||||
status = False
|
||||
return status
|
||||
|
||||
def delete_doc(self, kb_file: KnowledgeFile):
|
||||
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False):
|
||||
"""
|
||||
从知识库删除文件
|
||||
"""
|
||||
if os.path.exists(kb_file.filepath):
|
||||
if delete_content and os.path.exists(kb_file.filepath):
|
||||
os.remove(kb_file.filepath)
|
||||
self.do_delete_doc(kb_file)
|
||||
status = delete_file_from_db(kb_file)
|
||||
return status
|
||||
|
||||
def update_doc(self, kb_file: KnowledgeFile):
|
||||
"""
|
||||
使用content中的文件更新向量库
|
||||
"""
|
||||
if os.path.exists(kb_file.filepath):
|
||||
self.delete_doc(kb_file)
|
||||
return self.add_doc(kb_file)
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
|
||||
filename=file_name))
|
||||
|
|
|
|||
|
|
@ -136,3 +136,13 @@ class FaissKBService(KBService):
|
|||
def do_clear_vs(self):
|
||||
shutil.rmtree(self.vs_path)
|
||||
os.makedirs(self.vs_path)
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
if super().exist_doc(file_name):
|
||||
return "in_db"
|
||||
|
||||
content_path = os.path.join(self.kb_path, "content")
|
||||
if os.path.isfile(os.path.join(content_path, file_name)):
|
||||
return "in_folder"
|
||||
else:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import os
|
|||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
|
||||
from functools import lru_cache
|
||||
import langchain.document_loaders
|
||||
import sys
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ def config_aggrid(
|
|||
gb = GridOptionsBuilder.from_dataframe(df)
|
||||
gb.configure_column("No", width=50)
|
||||
for k, v in titles.items():
|
||||
gb.configure_column(k, v, maxWidth=100)
|
||||
gb.configure_column(k, v, maxWidth=100, wrapHeaderText=True)
|
||||
gb.configure_selection(selection_mode, use_checkbox, pre_selected_rows=[0])
|
||||
return gb
|
||||
|
||||
|
|
@ -149,7 +149,6 @@ def knowledge_base_page(api: ApiRequest):
|
|||
files = st.file_uploader("上传知识文件",
|
||||
["docx", "txt", "md", "csv", "xlsx", "pdf"],
|
||||
accept_multiple_files=True,
|
||||
key="files",
|
||||
)
|
||||
if st.button(
|
||||
"添加文件到知识库",
|
||||
|
|
@ -199,7 +198,7 @@ def knowledge_base_page(api: ApiRequest):
|
|||
cols = st.columns(3)
|
||||
selected_rows = doc_grid.get("selected_rows", [])
|
||||
|
||||
cols = st.columns([2, 3, 2])
|
||||
cols = st.columns(4)
|
||||
if selected_rows:
|
||||
file_name = selected_rows[0]["file_name"]
|
||||
file_path = get_file_path(kb, file_name)
|
||||
|
|
@ -207,9 +206,20 @@ def knowledge_base_page(api: ApiRequest):
|
|||
cols[0].download_button("下载选中文档", fp, file_name=file_name)
|
||||
else:
|
||||
cols[0].download_button("下载选中文档", "", disabled=True)
|
||||
if cols[2].button("删除选中文档!", type="primary"):
|
||||
|
||||
if cols[1].button("入库", disabled=len(selected_rows)==0):
|
||||
for row in selected_rows:
|
||||
ret = api.delete_kb_doc(kb, row["file_name"])
|
||||
api.update_kb_doc(kb, row["file_name"])
|
||||
st.experimental_rerun()
|
||||
|
||||
if cols[2].button("出库", disabled=len(selected_rows)==0):
|
||||
for row in selected_rows:
|
||||
api.delete_kb_doc(kb, row["file_name"])
|
||||
st.experimental_rerun()
|
||||
|
||||
if cols[3].button("删除选中文档!", type="primary"):
|
||||
for row in selected_rows:
|
||||
ret = api.delete_kb_doc(kb, row["file_name"], True)
|
||||
st.toast(ret["msg"])
|
||||
st.experimental_rerun()
|
||||
|
||||
|
|
|
|||
|
|
@ -397,9 +397,11 @@ class ApiRequest:
|
|||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
if isinstance(file, bytes):
|
||||
if isinstance(file, bytes): # raw bytes
|
||||
file = BytesIO(file)
|
||||
else:
|
||||
elif hasattr(file, "read"): # a file io like object
|
||||
filename = filename or file.name
|
||||
else: # a local path
|
||||
file = Path(file).absolute().open("rb")
|
||||
filename = filename or file.name
|
||||
|
||||
|
|
@ -410,6 +412,7 @@ class ApiRequest:
|
|||
|
||||
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
||||
temp_file.write(file.read())
|
||||
temp_file.seek(0)
|
||||
response = run_async(upload_doc(
|
||||
UploadFile(temp_file, filename=filename),
|
||||
knowledge_base_name,
|
||||
|
|
@ -428,6 +431,7 @@ class ApiRequest:
|
|||
self,
|
||||
knowledge_base_name: str,
|
||||
doc_name: str,
|
||||
delete_content: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
|
|
@ -438,11 +442,34 @@ class ApiRequest:
|
|||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import delete_doc
|
||||
response = run_async(delete_doc(knowledge_base_name, doc_name))
|
||||
response = run_async(delete_doc(knowledge_base_name, doc_name, delete_content))
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.delete(
|
||||
"/knowledge_base/delete_doc",
|
||||
json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name, "delete_content": delete_content},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def update_kb_doc(
|
||||
self,
|
||||
knowledge_base_name: str,
|
||||
doc_name: str,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/update_doc接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import update_doc
|
||||
response = run_async(update_doc(knowledge_base_name, doc_name))
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.delete(
|
||||
"/knowledge_base/update_doc",
|
||||
json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name},
|
||||
)
|
||||
return response.json()
|
||||
|
|
|
|||
Loading…
Reference in New Issue