update requirements.txt and kb_doc_api.py

imClumsyPanda 2023-08-04 15:12:14 +08:00
parent 3318cef751
commit 7bfbe18011
3 changed files with 37 additions and 36 deletions

requirements.txt

@@ -14,7 +14,7 @@ uvicorn~=0.23.1
 starlette~=0.27.0
 numpy~=1.24.4
 pydantic~=1.10.11
-unstructured[local-inference]
+unstructured[all-docs]
 streamlit>=1.25.0
 streamlit-option-menu

server/knowledge_base/kb_doc_api.py

@@ -1,14 +1,11 @@
 import os
 import urllib
 import shutil
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
 from fastapi import File, Form, UploadFile
-from server.utils import BaseResponse, ListResponse, torch_gc
-from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
-                                         get_vs_path, get_file_path, file2text)
-from configs.model_config import embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE
-from server.knowledge_base.utils import load_embeddings, refresh_vs_cache
+from server.utils import BaseResponse, ListResponse
+from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
+                                         get_file_path, file2text, docs2vs,
+                                         refresh_vs_cache, )


 async def list_docs(knowledge_base_name: str):
@@ -39,7 +36,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
     saved_path = get_doc_path(knowledge_base_name)
     if not os.path.exists(saved_path):
-        return BaseResponse(code=404, msg="未找到知识库 {knowledge_base_name}")
+        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

     file_content = await file.read()  # 读取上传文件的内容
@@ -48,32 +45,17 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
         file_status = f"文件 {file.filename} 已存在。"
         return BaseResponse(code=404, msg=file_status)

-    with open(file_path, "wb") as f:
-        f.write(file_content)
+    try:
+        with open(file_path, "wb") as f:
+            f.write(file_content)
+    except Exception as e:
+        return BaseResponse(code=500, msg=f"{file.filename} 文件上传失败,报错信息为: {e}")

-    vs_path = get_vs_path(knowledge_base_name)
-    # TODO: 重写知识库生成/添加逻辑
     filepath = get_file_path(knowledge_base_name, file.filename)
     docs = file2text(filepath)
-    loaded_files = [file]
-    embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE)
-    if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
-        vector_store = FAISS.load_local(vs_path, embeddings)
-        vector_store.add_documents(docs)
-        torch_gc()
-    else:
-        if not os.path.exists(vs_path):
-            os.makedirs(vs_path)
-        vector_store = FAISS.from_documents(docs, embeddings)  # docs 为Document列表
-        torch_gc()
-        vector_store.save_local(vs_path)
-    if len(loaded_files) > 0:
-        file_status = f"成功上传文件 {file.filename}"
-        refresh_vs_cache(knowledge_base_name)
-        return BaseResponse(code=200, msg=file_status)
-    else:
-        file_status = f"上传文件 {file.filename} 失败"
-        return BaseResponse(code=500, msg=file_status)
+    docs2vs(docs, knowledge_base_name)
+
+    return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}")


 async def delete_doc(knowledge_base_name: str,
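For reference, a minimal client-side sketch of exercising the rewritten upload path. The host, port, and route prefix are assumptions about how the project's api.py registers this handler and are not shown in this diff; only the knowledge_base_name field and the file upload mirror the upload_doc parameters above, and the file and knowledge base names are placeholders.

# Hypothetical client call for upload_doc; URL, port, and route are assumptions,
# the two request fields mirror the handler's parameters shown in the diff above.
import requests

with open("test.txt", "rb") as fp:
    resp = requests.post(
        "http://127.0.0.1:7861/knowledge_base/upload_doc",  # assumed API address and route
        data={"knowledge_base_name": "samples"},            # form field expected by upload_doc
        files={"file": fp},                                  # UploadFile parameter
    )
print(resp.json())  # e.g. {"code": 200, "msg": "成功上传文件 test.txt"} on success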
@@ -84,24 +66,24 @@
     knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
     if not os.path.exists(get_kb_path(knowledge_base_name)):
-        return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_name} not found")
+        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

     doc_path = get_file_path(knowledge_base_name, doc_name)
     if os.path.exists(doc_path):
         os.remove(doc_path)
         remain_docs = await list_docs(knowledge_base_name)
         if len(remain_docs.data) == 0:
             shutil.rmtree(get_kb_path(knowledge_base_name), ignore_errors=True)
-            return BaseResponse(code=200, msg=f"document {doc_name} delete success")
+            return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功")
         else:
             # TODO: 重写从向量库中删除文件
             status = ""  # local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_name))
             if "success" in status:
                 refresh_vs_cache(knowledge_base_name)
-                return BaseResponse(code=200, msg=f"document {doc_name} delete success")
+                return BaseResponse(code=200, msg=f"{doc_name} 文件删除成功")
             else:
-                return BaseResponse(code=500, msg=f"document {doc_name} delete fail")
+                return BaseResponse(code=500, msg=f"{doc_name} 文件删除失败")
     else:
-        return BaseResponse(code=404, msg=f"document {doc_name} not found")
+        return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")


 async def update_doc():

server/knowledge_base/utils.py

@@ -1,4 +1,6 @@
 import os
+from typing import List
+from server.utils import torch_gc
 from configs.model_config import KB_ROOT_PATH
 from langchain.vectorstores import FAISS
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
@@ -43,6 +45,23 @@ def file2text(filepath):
     docs = loader.load_and_split(text_splitter)
     return docs


+def docs2vs(
+        docs: List[Document],
+        knowledge_base_name: str):
+    vs_path = get_vs_path(knowledge_base_name)
+    embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE)
+    if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
+        vector_store = FAISS.load_local(vs_path, embeddings)
+        vector_store.add_documents(docs)
+        torch_gc()
+    else:
+        if not os.path.exists(vs_path):
+            os.makedirs(vs_path)
+        vector_store = FAISS.from_documents(docs, embeddings)  # docs 为Document列表
+        torch_gc()
+        vector_store.save_local(vs_path)
+    refresh_vs_cache(knowledge_base_name)
+
+
 @lru_cache(1)
 def load_embeddings(model: str, device: str):
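A minimal sketch of the server-side flow that upload_doc now delegates to, as introduced by this commit: file2text splits the saved file into Document chunks, then docs2vs loads or creates the knowledge base's FAISS store, adds the chunks, and refreshes the vector-store cache. The knowledge base name "samples" and file name "test.txt" are placeholders.

from server.knowledge_base.utils import get_file_path, file2text, docs2vs

kb_name = "samples"                            # placeholder knowledge base name
filepath = get_file_path(kb_name, "test.txt")  # file already saved under the KB's content folder
docs = file2text(filepath)                     # split the file into Document chunks
docs2vs(docs, kb_name)                         # add to / create the FAISS store, then refresh_vs_cache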