add KnowledgeFile class

This commit is contained in:
imClumsyPanda 2023-08-04 23:54:26 +08:00
parent 775d46fecf
commit 590367a5b5
3 changed files with 59 additions and 42 deletions

View File

@ -4,10 +4,10 @@ import shutil
from fastapi import File, Form, UploadFile
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
get_file_path, file2text, docs2vs,
refresh_vs_cache, get_vs_path, )
get_file_path, refresh_vs_cache, get_vs_path)
from fastapi.responses import StreamingResponse
import json
from server.knowledge_base.knowledge_file import KnowledgeFile
async def list_docs(knowledge_base_name: str):
@ -57,9 +57,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
except Exception as e:
return BaseResponse(code=500, msg=f"{file.filename} 文件上传失败,报错信息为: {e}")
filepath = get_file_path(knowledge_base_name, file.filename)
docs = file2text(filepath)
docs2vs(docs, knowledge_base_name)
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
kb_file.file2text()
kb_file.docs2vs()
return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}")
@ -116,10 +117,11 @@ async def recreate_vector_store(knowledge_base_name: str):
docs = (await list_docs(kb)).data
for i, filename in enumerate(docs):
filepath = get_file_path(kb, filename)
print(f"processing {filepath} to vector store.")
docs = file2text(filepath)
docs2vs(docs, kb)
kb_file = KnowledgeFile(filename=filename,
knowledge_base_name=kb)
print(f"processing {get_file_path(kb, filename)} to vector store.")
kb_file.file2text()
kb_file.docs2vs()
yield json.dumps({
"total": len(docs),
"finished": i + 1,

View File

@ -0,0 +1,48 @@
import os.path
from server.knowledge_base.utils import (get_file_path, get_vs_path,
refresh_vs_cache, load_embeddings)
from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
from langchain.vectorstores import FAISS
from server.utils import torch_gc
class KnowledgeFile:
def __init__(
self,
filename: str,
knowledge_base_name: str
):
self.knowledge_base_name = knowledge_base_name
self.knowledge_base_type = "faiss"
self.filename = filename
self.ext = os.path.splitext(filename)[-1]
self.filepath = get_file_path(knowledge_base_name, filename)
self.docs = None
def file2text(self):
if self.ext in []:
from langchain.document_loaders import UnstructuredFileLoader
loader = UnstructuredFileLoader(self.filepath)
elif self.ext in []:
pass
from langchain.text_splitter import CharacterTextSplitter
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
self.docs = loader.load_and_split(text_splitter)
return True
def docs2vs(self):
vs_path = get_vs_path(self.knowledge_base_name)
embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE)
if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
vector_store = FAISS.load_local(vs_path, embeddings)
vector_store.add_documents(self.docs)
torch_gc()
else:
if not os.path.exists(vs_path):
os.makedirs(vs_path)
vector_store = FAISS.from_documents(self.docs, embeddings) # docs 为Document列表
torch_gc()
vector_store.save_local(vs_path)
refresh_vs_cache(self.knowledge_base_name)
return True

View File

@ -3,7 +3,6 @@ from typing import List
from server.utils import torch_gc
from configs.model_config import KB_ROOT_PATH
from langchain.vectorstores import FAISS
from langchain.schema import Document
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
@ -36,34 +35,6 @@ def validate_kb_name(knowledge_base_id: str) -> bool:
return True
def file2text(filepath):
# TODO: 替换处理方式
from langchain.document_loaders import UnstructuredFileLoader
loader = UnstructuredFileLoader(filepath)
from langchain.text_splitter import CharacterTextSplitter
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
docs = loader.load_and_split(text_splitter)
return docs
def docs2vs(
docs: List[Document],
knowledge_base_name: str):
vs_path = get_vs_path(knowledge_base_name)
embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE)
if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
vector_store = FAISS.load_local(vs_path, embeddings)
vector_store.add_documents(docs)
torch_gc()
else:
if not os.path.exists(vs_path):
os.makedirs(vs_path)
vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表
torch_gc()
vector_store.save_local(vs_path)
refresh_vs_cache(knowledge_base_name)
@lru_cache(1)
def load_embeddings(model: str, device: str):
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
@ -106,7 +77,3 @@ def refresh_vs_cache(kb_name: str):
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
if __name__ == "__main__":
filepath = "/Users/liuqian/PycharmProjects/chatchat/knowledge_base/123/content/test.txt"
docs = file2text(filepath)