add KnowledgeFile class
parent 775d46fecf
commit 590367a5b5
@@ -4,10 +4,10 @@ import shutil
 from fastapi import File, Form, UploadFile
 from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
-                                         get_file_path, file2text, docs2vs,
-                                         refresh_vs_cache, get_vs_path, )
+                                         get_file_path, refresh_vs_cache, get_vs_path)
 from fastapi.responses import StreamingResponse
 import json
+from server.knowledge_base.knowledge_file import KnowledgeFile


 async def list_docs(knowledge_base_name: str):
@@ -57,9 +57,10 @@ async def upload_doc(file: UploadFile = File(description="uploaded file"),
     except Exception as e:
         return BaseResponse(code=500, msg=f"Failed to upload {file.filename}; error: {e}")

-    filepath = get_file_path(knowledge_base_name, file.filename)
-    docs = file2text(filepath)
-    docs2vs(docs, knowledge_base_name)
+    kb_file = KnowledgeFile(filename=file.filename,
+                            knowledge_base_name=knowledge_base_name)
+    kb_file.file2text()
+    kb_file.docs2vs()

     return BaseResponse(code=200, msg=f"Successfully uploaded {file.filename}")
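For context, upload_doc follows FastAPI's standard multipart-upload pattern (an UploadFile received via File, extra fields via Form). A minimal self-contained sketch of that pattern; the route path and response body here are assumptions, not this project's actual route:

    # Minimal FastAPI multipart-upload sketch mirroring upload_doc's shape;
    # the route path "/knowledge_base/upload_doc" is an assumption.
    from fastapi import FastAPI, File, Form, UploadFile

    app = FastAPI()

    @app.post("/knowledge_base/upload_doc")
    async def upload_doc(file: UploadFile = File(description="uploaded file"),
                         knowledge_base_name: str = Form(...)):
        content = await file.read()  # raw bytes of the uploaded file
        return {"code": 200, "msg": f"received {file.filename} ({len(content)} bytes)"}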
@@ -116,10 +117,11 @@ async def recreate_vector_store(knowledge_base_name: str):

         docs = (await list_docs(kb)).data
         for i, filename in enumerate(docs):
-            filepath = get_file_path(kb, filename)
-            print(f"processing {filepath} to vector store.")
-            docs = file2text(filepath)
-            docs2vs(docs, kb)
+            kb_file = KnowledgeFile(filename=filename,
+                                    knowledge_base_name=kb)
+            print(f"processing {get_file_path(kb, filename)} to vector store.")
+            kb_file.file2text()
+            kb_file.docs2vs()
             yield json.dumps({
                 "total": len(docs),
                 "finished": i + 1,
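recreate_vector_store reports progress through a StreamingResponse, yielding one JSON object per processed file with "total" and "finished" counts. A minimal client sketch; the host, route, query parameter, and the assumption that each progress object arrives as its own line/chunk are all guesses, only the payload shape comes from the diff:

    # Client sketch for the streaming progress endpoint; URL and
    # parameter name are hypothetical.
    import json
    import requests

    resp = requests.get(
        "http://localhost:7861/knowledge_base/recreate_vector_store",  # hypothetical
        params={"knowledge_base_name": "samples"},                     # placeholder kb name
        stream=True,
    )
    for chunk in resp.iter_lines():
        if chunk:
            progress = json.loads(chunk)
            print(f"{progress['finished']}/{progress['total']} files embedded")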
server/knowledge_base/knowledge_file.py (new file)
@@ -0,0 +1,48 @@
+import os.path
+from server.knowledge_base.utils import (get_file_path, get_vs_path,
+                                         refresh_vs_cache, load_embeddings)
+from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
+from langchain.vectorstores import FAISS
+from server.utils import torch_gc
+
+
+class KnowledgeFile:
+    def __init__(
+            self,
+            filename: str,
+            knowledge_base_name: str
+    ):
+        self.knowledge_base_name = knowledge_base_name
+        self.knowledge_base_type = "faiss"
+        self.filename = filename
+        self.ext = os.path.splitext(filename)[-1]
+        self.filepath = get_file_path(knowledge_base_name, filename)
+        self.docs = None
+
+    def file2text(self):
+        if self.ext in []:
+            from langchain.document_loaders import UnstructuredFileLoader
+            loader = UnstructuredFileLoader(self.filepath)
+        elif self.ext in []:
+            pass
+
+        from langchain.text_splitter import CharacterTextSplitter
+        text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
+        self.docs = loader.load_and_split(text_splitter)
+        return True
+
+    def docs2vs(self):
+        vs_path = get_vs_path(self.knowledge_base_name)
+        embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE)
+        if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
+            vector_store = FAISS.load_local(vs_path, embeddings)
+            vector_store.add_documents(self.docs)
+            torch_gc()
+        else:
+            if not os.path.exists(vs_path):
+                os.makedirs(vs_path)
+            vector_store = FAISS.from_documents(self.docs, embeddings)  # docs is a list of Documents
+            torch_gc()
+            vector_store.save_local(vs_path)
+        refresh_vs_cache(self.knowledge_base_name)
+        return True
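Note that the extension dispatch in file2text is left as a placeholder in this commit: both membership lists are empty, so loader stays unbound for any input and the method would raise at load_and_split. One way the method body might later be filled in; the extension lists and the CSV branch are pure assumptions, not part of this commit:

    # Hypothetical completion of KnowledgeFile.file2text; the extension
    # lists and loaders chosen per branch are assumptions.
    def file2text(self):
        if self.ext in [".txt", ".md", ".pdf", ".docx"]:
            from langchain.document_loaders import UnstructuredFileLoader
            loader = UnstructuredFileLoader(self.filepath)
        elif self.ext in [".csv"]:
            from langchain.document_loaders import CSVLoader
            loader = CSVLoader(self.filepath)
        else:
            raise ValueError(f"unsupported file extension: {self.ext}")

        from langchain.text_splitter import CharacterTextSplitter
        text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
        self.docs = loader.load_and_split(text_splitter)
        return True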
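Taken together, the class collapses the old free-function pipeline (file2text then docs2vs) into two method calls. A usage sketch; "samples" and "test.txt" are placeholder names, and the file is assumed to already exist under the knowledge base's content directory:

    # Usage sketch for the new class.
    from server.knowledge_base.knowledge_file import KnowledgeFile

    kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
    kb_file.file2text()   # load and split the document into kb_file.docs
    kb_file.docs2vs()     # embed kb_file.docs into the FAISS store and refresh the cache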
server/knowledge_base/utils.py
@@ -3,7 +3,6 @@ from typing import List
 from server.utils import torch_gc
 from configs.model_config import KB_ROOT_PATH
 from langchain.vectorstores import FAISS
-from langchain.schema import Document
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from configs.model_config import (CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
                                   embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
@@ -36,34 +35,6 @@ def validate_kb_name(knowledge_base_id: str) -> bool:
     return True


-def file2text(filepath):
-    # TODO: replace this processing approach
-    from langchain.document_loaders import UnstructuredFileLoader
-    loader = UnstructuredFileLoader(filepath)
-
-    from langchain.text_splitter import CharacterTextSplitter
-    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
-    docs = loader.load_and_split(text_splitter)
-    return docs
-
-def docs2vs(
-        docs: List[Document],
-        knowledge_base_name: str):
-    vs_path = get_vs_path(knowledge_base_name)
-    embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE)
-    if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
-        vector_store = FAISS.load_local(vs_path, embeddings)
-        vector_store.add_documents(docs)
-        torch_gc()
-    else:
-        if not os.path.exists(vs_path):
-            os.makedirs(vs_path)
-        vector_store = FAISS.from_documents(docs, embeddings)  # docs is a list of Documents
-        torch_gc()
-        vector_store.save_local(vs_path)
-    refresh_vs_cache(knowledge_base_name)
-
-
 @lru_cache(1)
 def load_embeddings(model: str, device: str):
     embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
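The @lru_cache(1) on load_embeddings, visible in the context lines above, means repeated calls with the same (model, device) arguments reuse one loaded HuggingFaceEmbeddings instance instead of reloading the model from disk. A standalone illustration of the same caching pattern; the function and model names here are placeholders:

    # lru_cache(1) keeps only the most recent argument tuple cached.
    from functools import lru_cache

    @lru_cache(1)
    def load_expensive_resource(model: str, device: str) -> str:
        print(f"loading {model} on {device}")  # runs only on a cache miss
        return f"{model}@{device}"

    load_expensive_resource("m3e-base", "cpu")   # miss: prints
    load_expensive_resource("m3e-base", "cpu")   # hit: no print
    load_expensive_resource("m3e-base", "cuda")  # miss: evicts the cpu entry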
@@ -106,7 +77,3 @@ def refresh_vs_cache(kb_name: str):
     """
     _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1

-
-if __name__ == "__main__":
-    filepath = "/Users/liuqian/PycharmProjects/chatchat/knowledge_base/123/content/test.txt"
-    docs = file2text(filepath)
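refresh_vs_cache invalidates cached vector stores by bumping a per-knowledge-base counter in _VECTOR_STORE_TICKS: a reader that cached a store at tick N can compare N against the current tick and reload when it is stale. A minimal sketch of that pattern; the writer side mirrors the diff, the reader side is an assumption about how the tick is consumed:

    # Tick-based cache invalidation, as suggested by refresh_vs_cache.
    _VECTOR_STORE_TICKS = {}

    def refresh_vs_cache(kb_name: str):
        _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1

    _cached = {}  # kb_name -> (tick, vector_store); this reader side is hypothetical

    def get_vector_store(kb_name: str):
        tick = _VECTOR_STORE_TICKS.get(kb_name, 0)
        cached = _cached.get(kb_name)
        if cached is None or cached[0] != tick:
            vector_store = f"loaded store for {kb_name}"  # stand-in for FAISS.load_local
            _cached[kb_name] = (tick, vector_store)
        return _cached[kb_name][1]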