Langchain-Chatchat/server/knowledge_base/kb_doc_api.py

150 lines
5.8 KiB
Python
Raw Normal View History

2023-07-27 23:22:07 +08:00
import os
import urllib
from fastapi import File, Form, Body, UploadFile
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import (get_file_path, validate_kb_name)
from fastapi.responses import StreamingResponse
import json
from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.kb_service.base import SupportedVSType
from typing import Union
2023-07-27 23:22:07 +08:00
async def list_docs(knowledge_base_name: str):
    """List the document names stored in a knowledge base.

    The name may arrive percent-encoded, so it is decoded first and only then
    validated: validating the raw string would let encoded sequences slip past
    ``validate_kb_name`` and reach ``get_service_by_name`` in decoded form.

    Returns:
        ``ListResponse`` with code 403 for a rejected name, 404 when no
        knowledge base of that name exists, otherwise ``data`` set to the
        result of ``kb.list_docs()``.
    """
    # Decode before validating so the check sees the name actually used below.
    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    if not validate_kb_name(knowledge_base_name):
        return ListResponse(code=403, msg="Don't attack me", data=[])

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])

    all_doc_names = kb.list_docs()
    return ListResponse(data=all_doc_names)
async def upload_doc(file: UploadFile = File(description="上传文件"),
                     knowledge_base_name: str = Form(..., description="知识库名称", example="kb1"),
                     override: bool = Form(False, description="覆盖已有文件", example=False),
                     ):
    """Save an uploaded file into a knowledge base and add it to the index.

    Args:
        file: The uploaded file. Its client-supplied ``filename`` becomes the
            stored document name.
        knowledge_base_name: Name of the target knowledge base.
        override: When False and a file with the same name and the same size
            already exists, nothing is written and an "already exists"
            response is returned. (A same-named file of a different size is
            currently overwritten regardless; see the TODO below.)

    Returns:
        ``BaseResponse`` with code 403 for a rejected knowledge-base name or
        file name, 404 when the knowledge base is missing or the file already
        exists (same size, ``override`` False), 500 when writing the file
        fails, and 200 after the file is written and passed to ``kb.add_doc``.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    # The file name comes straight from the multipart request, so it is
    # untrusted (and may be missing). Accept only a bare file name, never a
    # path component, so the stored path cannot point outside the knowledge
    # base's folder (assumes KnowledgeFile joins the name onto that folder).
    filename = file.filename
    if (not filename
            or filename in (".", "..")
            or os.path.basename(filename) != filename):
        return BaseResponse(code=403, msg="Don't attack me")

    file_content = await file.read()  # read the whole upload into memory

    kb_file = KnowledgeFile(filename=filename,
                            knowledge_base_name=knowledge_base_name)
    if (os.path.exists(kb_file.filepath)
            and not override
            and os.path.getsize(kb_file.filepath) == len(file_content)):
        # TODO: decide how to handle an existing file whose size differs;
        # it currently falls through and is overwritten.
        file_status = f"文件 {kb_file.filename} 已存在。"
        return BaseResponse(code=404, msg=file_status)

    try:
        with open(kb_file.filepath, "wb") as f:
            f.write(file_content)
    except Exception as e:
        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")

    kb.add_doc(kb_file)
    return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
2023-07-27 23:22:07 +08:00
async def delete_doc(knowledge_base_name: str = Body(...),
                     doc_name: str = Body(...),
                     delete_content: bool = Body(...),
                     ):
    """Remove a document from a knowledge base.

    Args:
        knowledge_base_name: Knowledge base name. It is percent-decoded before
            it is validated and used.
        doc_name: Document name. It must be reported as present by
            ``kb.exist_doc``.
        delete_content: Forwarded to ``kb.delete_doc`` (presumably controls
            whether the file on disk is removed as well; confirm in the KB
            service implementation).

    Returns:
        ``BaseResponse`` with code 403 for a rejected name, 404 when the
        knowledge base or the document is missing, otherwise 200 after
        ``kb.delete_doc`` returns.
    """
    # Decode before validating so the check sees the name actually used below.
    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    if not kb.exist_doc(doc_name):
        return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")

    kb_file = KnowledgeFile(filename=doc_name,
                            knowledge_base_name=knowledge_base_name)
    kb.delete_doc(kb_file, delete_content)
    return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
2023-07-27 23:22:07 +08:00
async def update_doc(
        knowledge_base_name: str = Body(...),
        file_name: str = Body(...),
):
    """Re-index an existing document of a knowledge base.

    The document file must already exist at ``kb_file.filepath``. It is then
    handed to ``kb.update_doc``.

    Returns:
        ``BaseResponse`` with code 403 for a rejected knowledge-base name or
        file name, 404 when the knowledge base or the file is missing,
        otherwise 200 after ``kb.update_doc`` returns.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    # file_name comes from the request body and is untrusted: accept only a
    # bare file name so an arbitrary file outside the knowledge base's folder
    # cannot be read and indexed (assumes KnowledgeFile joins the name onto
    # that folder).
    if (not file_name
            or file_name in (".", "..")
            or os.path.basename(file_name) != file_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    kb_file = KnowledgeFile(filename=file_name,
                            knowledge_base_name=knowledge_base_name)
    if not os.path.exists(kb_file.filepath):
        # Report a missing file as "file not found", the same wording delete_doc uses.
        return BaseResponse(code=404, msg=f"未找到文件 {kb_file.filename}")

    kb.update_doc(kb_file)
    return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
2023-07-27 23:22:07 +08:00
2023-08-06 18:32:10 +08:00
2023-07-27 23:22:07 +08:00
async def download_doc():
    """Placeholder endpoint for downloading a knowledge-base document.

    Not implemented yet: awaiting it performs no work and yields ``None``.
    """
    # TODO: implement file download.
    return None
async def recreate_vector_store(
        knowledge_base_name: str = Body(...),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body("faiss"),
):
    '''
    Recreate a knowledge base's vector store from the files in its content folder.

    This is useful when users copy files into the content folder directly
    instead of uploading them through the network.
    By default, get_service_by_name only returns knowledge bases that are in
    info.db and contain documents. Set allow_empty_kb to True to also rebuild
    an empty knowledge base, i.e. one that is not in info.db or has no
    documents; it is then obtained via KBServiceFactory.get_service with vs_type.

    Returns:
        A BaseResponse with code 403 for a rejected name, or 404 when the
        knowledge base is missing and allow_empty_kb is False. Otherwise a
        StreamingResponse that, for each file in the folder, emits a JSON
        progress object ({"total", "finished", "doc"}) and then adds that
        file to the knowledge base.
    '''
    # The name is used to list a folder on disk, so validate it like the
    # other endpoints in this module do.
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        if not allow_empty_kb:
            return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type)

    async def output(kb):
        # Runs lazily as the client consumes the stream: (re)create the KB,
        # wipe its vector store, then re-add every file found in the folder.
        kb.create_kb()
        kb.clear_vs()
        print(f"start to recreate vector store of {kb.kb_name}")
        docs = list_docs_from_folder(knowledge_base_name)
        for i, filename in enumerate(docs):
            # Progress is reported before the file is processed, so
            # "finished" counts the files completed so far.
            yield json.dumps({
                "total": len(docs),
                "finished": i,
                "doc": filename,
            })
            try:
                kb_file = KnowledgeFile(filename=filename,
                                        knowledge_base_name=kb.kb_name)
                print(f"processing {kb_file.filepath} to vector store.")
                kb.add_doc(kb_file)
            except ValueError as e:
                # Presumably raised for files KnowledgeFile/add_doc cannot
                # handle (confirm); log it and continue with the next file.
                print(e)

    return StreamingResponse(output(kb), media_type="text/event-stream")