2023-07-27 23:22:07 +08:00
|
|
|
import os
|
|
|
|
|
import urllib
import urllib.parse
|
|
|
|
|
import shutil
|
|
|
|
|
from fastapi import File, Form, UploadFile
|
2023-08-04 15:12:14 +08:00
|
|
|
from server.utils import BaseResponse, ListResponse
|
2023-07-27 23:22:07 +08:00
|
|
|
from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
|
2023-08-04 23:54:26 +08:00
|
|
|
get_file_path, refresh_vs_cache, get_vs_path)
|
2023-08-04 20:26:14 +08:00
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
import json
|
2023-08-05 22:57:19 +08:00
|
|
|
from server.knowledge_base.knowledge_file import KnowledgeFile
|
|
|
|
|
from server.knowledge_base.knowledge_base import KnowledgeBase
|
2023-07-27 23:22:07 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def list_docs(knowledge_base_name: str):
    """List the document names stored in a knowledge base.

    Args:
        knowledge_base_name: Knowledge-base name. It may be percent-encoded;
            it is decoded with ``urllib.parse.unquote`` before use.

    Returns:
        ListResponse: code 403 if the decoded name fails ``validate_kb_name``,
        code 404 if the knowledge-base directory does not exist, otherwise
        ``data`` holds the names returned by ``KnowledgeBase.list_docs()``.
    """
    # Decode BEFORE validating so the validator sees exactly the name that is
    # later used to build filesystem paths. Validating the encoded form would
    # let percent-encoded sequences slip past whatever checks
    # validate_kb_name performs.
    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    if not validate_kb_name(knowledge_base_name):
        return ListResponse(code=403, msg="Don't attack me", data=[])

    kb_path = get_kb_path(knowledge_base_name)
    if not os.path.exists(kb_path):
        return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])

    all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
    return ListResponse(data=all_doc_names)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def upload_doc(file: UploadFile = File(description="上传文件"),
                     knowledge_base_name: str = Form(..., description="知识库名称", example="kb1"),
                     override: bool = Form(False, description="覆盖已有文件", example=False),
                     ):
    """Save an uploaded file into a knowledge base's document folder and index it.

    Args:
        file: The uploaded file. Only the final path component of its
            client-supplied filename is used as the stored file name.
        knowledge_base_name: Name of the target knowledge base.
        override: When False, an existing file with the same name AND the same
            byte size is treated as a duplicate and the upload is rejected.

    Returns:
        BaseResponse: 403 for an invalid knowledge-base name or file name;
        404 if the knowledge base does not exist or the file is a duplicate;
        500 if writing the file to disk fails; 200 on success.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    saved_path = get_doc_path(knowledge_base_name)
    if not os.path.exists(saved_path):
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    # The filename is supplied by the client. Strip any directory components
    # so a name such as "../../x" cannot write outside the document folder,
    # and reject names that are empty or refer to a directory.
    filename = os.path.basename(file.filename or "")
    if filename in ("", ".", ".."):
        return BaseResponse(code=403, msg="Don't attack me")

    file_content = await file.read()  # read the whole upload into memory

    file_path = os.path.join(saved_path, filename)
    # Duplicate heuristic: same name and same size. Note that a same-named
    # file of a DIFFERENT size is overwritten even when override is False.
    if (os.path.exists(file_path)
            and not override
            and os.path.getsize(file_path) == len(file_content)):
        file_status = f"文件 {filename} 已存在。"
        return BaseResponse(code=404, msg=file_status)

    try:
        with open(file_path, "wb") as f:
            f.write(file_content)
    except Exception as e:
        # API boundary: report any write failure to the client instead of a 500 traceback.
        return BaseResponse(code=500, msg=f"{filename} 文件上传失败,报错信息为: {e}")

    kb_file = KnowledgeFile(filename=filename,
                            knowledge_base_name=knowledge_base_name)
    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
    kb.add_doc(kb_file)

    return BaseResponse(code=200, msg=f"成功上传文件 {filename}")
|
2023-07-27 23:22:07 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def delete_doc(knowledge_base_name: str,
                     doc_name: str,
                     ):
    """Delete a document file from a knowledge base.

    If no documents remain afterwards (according to ``list_docs``), the whole
    knowledge-base directory is removed.

    Args:
        knowledge_base_name: Knowledge-base name. It may be percent-encoded;
            it is decoded with ``urllib.parse.unquote`` before use.
        doc_name: Plain file name of the document; must not contain any
            directory components.

    Returns:
        BaseResponse: 403 for an invalid knowledge-base or document name;
        404 if the knowledge base or the document does not exist;
        200 if the last document was deleted;
        otherwise 500 (see the stub note below).
    """
    # Decode BEFORE validating so the validator checks the name that is
    # actually used to build paths.
    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    # doc_name is turned into a path that is passed to os.remove. Only accept
    # a plain file name so it cannot point outside the knowledge base.
    if doc_name in ("", ".", "..") or os.path.basename(doc_name) != doc_name:
        return BaseResponse(code=403, msg="Don't attack me")

    if not os.path.exists(get_kb_path(knowledge_base_name)):
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    doc_path = get_file_path(knowledge_base_name, doc_name)
    if not os.path.exists(doc_path):
        return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")

    os.remove(doc_path)

    remain_docs = await list_docs(knowledge_base_name)
    if len(remain_docs.data) == 0:
        shutil.rmtree(get_kb_path(knowledge_base_name), ignore_errors=True)
        return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功")

    # TODO: reimplement removing the file from the vector store (previously
    # local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_name))).
    # Until then `status` is always empty, so this path always reports 500
    # even though the file itself has already been removed from disk.
    status = ""
    if "success" in status:
        refresh_vs_cache(knowledge_base_name)
        return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功")
    return BaseResponse(code=500, msg=f"{doc_name} 文件删除失败")
|
2023-07-27 23:22:07 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def update_doc():
    """Replace an existing document in a knowledge base.

    Not implemented yet: awaiting this coroutine does nothing and yields None.
    A real implementation will also need to refresh the vector-store cache
    for the affected knowledge base (``refresh_vs_cache``).
    """
    # TODO: replace the file.
    return None
|
|
|
|
|
|
|
|
|
|
async def download_doc():
    """Download a document from a knowledge base.

    Not implemented yet: awaiting this coroutine does nothing and yields None.
    """
    # TODO: download the file.
    return None
|
2023-08-04 20:26:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def recreate_vector_store(knowledge_base_name: str):
    """Rebuild a knowledge base's vector store from the documents it lists.

    This is useful when users copy files into the content folder directly
    instead of uploading them through the API.

    The existing vector-store directory is deleted and recreated, then every
    document returned by ``list_docs`` is added again. Progress is streamed as
    one JSON object per processed document with keys ``total``, ``finished``
    and ``doc``.

    Args:
        knowledge_base_name: Name of the knowledge base to rebuild.

    Returns:
        BaseResponse with code 403 for an invalid name, or code 404 if the
        knowledge base does not exist; otherwise a StreamingResponse.
    """
    # Validate before touching the filesystem: the generator below deletes a
    # directory derived from this caller-supplied name.
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")
    if not os.path.exists(get_kb_path(knowledge_base_name)):
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    async def output(kb_name):
        vs_path = get_vs_path(kb_name)
        if os.path.isdir(vs_path):
            shutil.rmtree(vs_path)
        os.makedirs(vs_path, exist_ok=True)
        print(f"start to recreate vectore in {vs_path}")

        docs = (await list_docs(kb_name)).data
        # Load once instead of once per document (list_docs above already
        # calls KnowledgeBase.load with the same argument).
        # NOTE(review): assumes add_doc does not rely on a freshly loaded
        # instance per document — confirm.
        kb = KnowledgeBase.load(knowledge_base_name=kb_name)
        for i, filename in enumerate(docs):
            kb_file = KnowledgeFile(filename=filename,
                                    knowledge_base_name=kb_name)
            print(f"processing {kb_file.filepath} to vector store.")
            kb.add_doc(kb_file)
            yield json.dumps({
                "total": len(docs),
                "finished": i + 1,
                "doc": filename,
            })

    return StreamingResponse(output(knowledge_base_name), media_type="text/event-stream")
|