update requirements.txt and kb_doc_api.py
This commit is contained in:
parent
3318cef751
commit
7bfbe18011
|
|
@ -14,7 +14,7 @@ uvicorn~=0.23.1
|
|||
starlette~=0.27.0
|
||||
numpy~=1.24.4
|
||||
pydantic~=1.10.11
|
||||
unstructured[local-inference]
|
||||
unstructured[all-docs]
|
||||
|
||||
streamlit>=1.25.0
|
||||
streamlit-option-menu
|
||||
|
|
|
|||
|
|
@ -1,14 +1,11 @@
|
|||
import os
|
||||
import urllib
|
||||
import shutil
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.vectorstores import FAISS
|
||||
from fastapi import File, Form, UploadFile
|
||||
from server.utils import BaseResponse, ListResponse, torch_gc
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
|
||||
get_vs_path, get_file_path, file2text)
|
||||
from configs.model_config import embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE
|
||||
from server.knowledge_base.utils import load_embeddings, refresh_vs_cache
|
||||
get_file_path, file2text, docs2vs,
|
||||
refresh_vs_cache, )
|
||||
|
||||
|
||||
async def list_docs(knowledge_base_name: str):
|
||||
|
|
@ -39,7 +36,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
|||
|
||||
saved_path = get_doc_path(knowledge_base_name)
|
||||
if not os.path.exists(saved_path):
|
||||
return BaseResponse(code=404, msg="未找到知识库 {knowledge_base_name}")
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
file_content = await file.read() # 读取上传文件的内容
|
||||
|
||||
|
|
@ -48,32 +45,17 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
|||
file_status = f"文件 {file.filename} 已存在。"
|
||||
return BaseResponse(code=404, msg=file_status)
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
try:
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
except Exception as e:
|
||||
return BaseResponse(code=500, msg=f"{file.filename} 文件上传失败,报错信息为: {e}")
|
||||
|
||||
vs_path = get_vs_path(knowledge_base_name)
|
||||
# TODO: 重写知识库生成/添加逻辑
|
||||
filepath = get_file_path(knowledge_base_name, file.filename)
|
||||
docs = file2text(filepath)
|
||||
loaded_files = [file]
|
||||
embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE)
|
||||
if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
|
||||
vector_store = FAISS.load_local(vs_path, embeddings)
|
||||
vector_store.add_documents(docs)
|
||||
torch_gc()
|
||||
else:
|
||||
if not os.path.exists(vs_path):
|
||||
os.makedirs(vs_path)
|
||||
vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表
|
||||
torch_gc()
|
||||
vector_store.save_local(vs_path)
|
||||
if len(loaded_files) > 0:
|
||||
file_status = f"成功上传文件 {file.filename}"
|
||||
refresh_vs_cache(knowledge_base_name)
|
||||
return BaseResponse(code=200, msg=file_status)
|
||||
else:
|
||||
file_status = f"上传文件 {file.filename} 失败"
|
||||
return BaseResponse(code=500, msg=file_status)
|
||||
docs2vs(docs, knowledge_base_name)
|
||||
|
||||
return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}")
|
||||
|
||||
|
||||
async def delete_doc(knowledge_base_name: str,
|
||||
|
|
@ -84,24 +66,24 @@ async def delete_doc(knowledge_base_name: str,
|
|||
|
||||
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
|
||||
if not os.path.exists(get_kb_path(knowledge_base_name)):
|
||||
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
doc_path = get_file_path(knowledge_base_name, doc_name)
|
||||
if os.path.exists(doc_path):
|
||||
os.remove(doc_path)
|
||||
remain_docs = await list_docs(knowledge_base_name)
|
||||
if len(remain_docs.data) == 0:
|
||||
shutil.rmtree(get_kb_path(knowledge_base_name), ignore_errors=True)
|
||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||
return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功")
|
||||
else:
|
||||
# TODO: 重写从向量库中删除文件
|
||||
status = "" # local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_name))
|
||||
if "success" in status:
|
||||
refresh_vs_cache(knowledge_base_name)
|
||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||
return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功")
|
||||
else:
|
||||
return BaseResponse(code=500, msg=f"document {doc_name} delete fail")
|
||||
return BaseResponse(code=500, msg=f"{doc_name} 文件删除失败")
|
||||
else:
|
||||
return BaseResponse(code=404, msg=f"document {doc_name} not found")
|
||||
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
||||
|
||||
|
||||
async def update_doc():
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import os
|
||||
from typing import List
|
||||
from server.utils import torch_gc
|
||||
from configs.model_config import KB_ROOT_PATH
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
|
|
@ -43,6 +45,23 @@ def file2text(filepath):
|
|||
docs = loader.load_and_split(text_splitter)
|
||||
return docs
|
||||
|
||||
def docs2vs(
|
||||
docs: List[Document],
|
||||
knowledge_base_name: str):
|
||||
vs_path = get_vs_path(knowledge_base_name)
|
||||
embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE)
|
||||
if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
|
||||
vector_store = FAISS.load_local(vs_path, embeddings)
|
||||
vector_store.add_documents(docs)
|
||||
torch_gc()
|
||||
else:
|
||||
if not os.path.exists(vs_path):
|
||||
os.makedirs(vs_path)
|
||||
vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表
|
||||
torch_gc()
|
||||
vector_store.save_local(vs_path)
|
||||
refresh_vs_cache(knowledge_base_name)
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def load_embeddings(model: str, device: str):
|
||||
|
|
|
|||
Loading…
Reference in New Issue