# Langchain-Chatchat/server/knowledge_base/kb_doc_api.py
#
# (source-listing metadata: 113 lines, 4.6 KiB, Python)

import os
import urllib
import shutil
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from fastapi import File, Form, UploadFile
from server.utils import BaseResponse, ListResponse, torch_gc
from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
get_vs_path, get_file_path, file2text)
from configs.model_config import embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE
async def list_docs(knowledge_base_name: str):
    """List the document files stored in a knowledge base.

    Args:
        knowledge_base_name: Name of the knowledge base. It may arrive
            URL-encoded and is decoded before any validation or path lookup.

    Returns:
        ListResponse whose ``data`` holds the names of the regular files in
        the knowledge base's document folder (empty when that folder does not
        exist). Returns code 403 when the name fails validation and code 404
        when the knowledge base directory does not exist, both with empty
        ``data``.
    """
    # Local import: a bare ``import urllib`` does not guarantee that the
    # ``urllib.parse`` submodule has been loaded.
    from urllib.parse import unquote

    # Decode BEFORE validating. Validating the still-encoded form could let an
    # encoded traversal sequence (e.g. "..%2F") pass and only turn into "../"
    # after decoding, assuming validate_kb_name inspects the decoded text.
    knowledge_base_name = unquote(knowledge_base_name)
    if not validate_kb_name(knowledge_base_name):
        return ListResponse(code=403, msg="Don't attack me", data=[])

    kb_path = get_kb_path(knowledge_base_name)
    local_doc_folder = get_doc_path(knowledge_base_name)
    if not os.path.exists(kb_path):
        return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])

    # A knowledge base without a document folder simply has no documents.
    if not os.path.exists(local_doc_folder):
        return ListResponse(data=[])

    # Only regular files count as documents; sub-directories are skipped.
    all_doc_names = [
        doc
        for doc in os.listdir(local_doc_folder)
        if os.path.isfile(os.path.join(local_doc_folder, doc))
    ]
    return ListResponse(data=all_doc_names)
async def upload_doc(file: UploadFile = File(description="上传文件"),
                     knowledge_base_name: str = Form(..., description="知识库名称", example="kb1"),
                     ):
    """Save an uploaded file into a knowledge base and index it with FAISS.

    The file is written into the knowledge base's document folder. It is then
    converted to documents with ``file2text`` and embedded with
    ``HuggingFaceEmbeddings``. The resulting vectors are appended to an
    existing FAISS index in the vector-store folder, or a new index is
    created there.

    Args:
        file: The uploaded file.
        knowledge_base_name: Target knowledge base name.

    Returns:
        BaseResponse with one of these codes:

        - 403: the knowledge base name or the file name is rejected.
        - 404: the document folder does not exist, or a file with the same
          name and byte size is already stored.
        - 500: converting or indexing the file raised an error.
        - 200: the file was saved and indexed.
    """
    import logging

    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    # The client controls the filename. Keep only its final path component so
    # it cannot escape the document folder (e.g. "../../x"). Names that would
    # resolve to the folder itself are rejected instead of crashing on open().
    filename = os.path.basename(file.filename or "")
    if filename in ("", ".", ".."):
        return BaseResponse(code=403, msg="Don't attack me")

    saved_path = get_doc_path(knowledge_base_name)
    if not os.path.exists(saved_path):
        # The f-prefix was missing before, so the literal placeholder
        # "{knowledge_base_name}" was sent to the client.
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    file_content = await file.read()
    file_path = os.path.join(saved_path, filename)
    # Same name and same size is treated as "already uploaded". Code 404 is
    # kept for compatibility with existing clients.
    if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
        return BaseResponse(code=404, msg=f"文件 {filename} 已存在。")

    with open(file_path, "wb") as f:
        f.write(file_content)

    # TODO: rewrite the knowledge-base build/append logic.
    try:
        docs = file2text(get_file_path(knowledge_base_name, filename))
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
                                           model_kwargs={'device': EMBEDDING_DEVICE})
        vs_path = get_vs_path(knowledge_base_name)
        if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
            # Append to the existing index.
            vector_store = FAISS.load_local(vs_path, embeddings)
            vector_store.add_documents(docs)
        else:
            os.makedirs(vs_path, exist_ok=True)
            vector_store = FAISS.from_documents(docs, embeddings)  # docs is a list of Document
        torch_gc()
        vector_store.save_local(vs_path)
    except Exception:
        # Endpoint boundary: log the cause and report failure. Before, the
        # failure branch was unreachable and errors escaped as raw exceptions.
        logging.getLogger(__name__).exception(
            "Failed to index uploaded file %s into knowledge base %s",
            filename, knowledge_base_name)
        return BaseResponse(code=500, msg=f"上传文件 {filename} 失败")

    return BaseResponse(code=200, msg=f"成功上传文件 {filename}")
async def delete_doc(knowledge_base_name: str,
                     doc_name: str,
                     ):
    """Delete one document file from a knowledge base.

    Args:
        knowledge_base_name: Knowledge base name. It may be URL-encoded and is
            decoded before validation.
        doc_name: Bare file name of the document to delete. Names containing
            a path component are rejected.

    Returns:
        BaseResponse with one of these codes:

        - 403: a name is rejected.
        - 404: the knowledge base or the document does not exist.
        - 200: the deleted file was the last document. In that case the whole
          knowledge base directory is removed as well.
        - 500: removal from the vector store is still a TODO, so this is
          returned whenever other documents remain, even though the file
          itself has been removed from disk.
    """
    # Local import: a bare ``import urllib`` does not guarantee that the
    # ``urllib.parse`` submodule has been loaded.
    from urllib.parse import unquote

    # Decode BEFORE validating, so an encoded traversal sequence cannot slip
    # past validate_kb_name and only appear after decoding.
    knowledge_base_name = unquote(knowledge_base_name)
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    # doc_name is joined onto the document folder and passed to os.remove.
    # Anything with a path component could delete files outside the
    # knowledge base.
    if doc_name in ("", ".", "..") or os.path.basename(doc_name) != doc_name:
        return BaseResponse(code=403, msg="Don't attack me")

    kb_path = get_kb_path(knowledge_base_name)
    if not os.path.exists(kb_path):
        return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")

    doc_path = get_file_path(knowledge_base_name, doc_name)
    if not os.path.isfile(doc_path):
        return BaseResponse(code=404, msg=f"document {doc_name} not found")
    os.remove(doc_path)

    # Count the remaining documents here instead of calling list_docs, which
    # would URL-decode the already-decoded name a second time.
    doc_folder = get_doc_path(knowledge_base_name)
    has_remaining = os.path.isdir(doc_folder) and any(
        os.path.isfile(os.path.join(doc_folder, doc)) for doc in os.listdir(doc_folder)
    )
    if not has_remaining:
        # The last document is gone, so drop the whole knowledge base
        # directory (and whatever it contains besides the documents).
        shutil.rmtree(kb_path, ignore_errors=True)
        return BaseResponse(code=200, msg=f"document {doc_name} delete success")

    # TODO: remove this file's vectors from the vector store (previously
    # sketched as local_doc_qa.delete_file_from_vector_store(doc_path, vs_path)).
    # Until then status stays empty, so this path always reports failure.
    status = ""
    if "success" in status:
        return BaseResponse(code=200, msg=f"document {doc_name} delete success")
    return BaseResponse(code=500, msg=f"document {doc_name} delete fail")
async def update_doc():
    """Replace an existing document in a knowledge base.

    Not implemented yet (TODO): the coroutine does nothing and resolves
    to None.
    """
    return None
async def download_doc():
    """Download a document from a knowledge base.

    Not implemented yet (TODO): the coroutine does nothing and resolves
    to None.
    """
    return None