diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index d2c3e4a..65957b1 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -4,10 +4,10 @@ import shutil from fastapi import File, Form, UploadFile from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path, - get_file_path, file2text, docs2vs, - refresh_vs_cache, get_vs_path, ) + get_file_path, refresh_vs_cache, get_vs_path) from fastapi.responses import StreamingResponse import json +from server.knowledge_base.knowledge_file import KnowledgeFile async def list_docs(knowledge_base_name: str): @@ -57,9 +57,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), except Exception as e: return BaseResponse(code=500, msg=f"{file.filename} 文件上传失败,报错信息为: {e}") - filepath = get_file_path(knowledge_base_name, file.filename) - docs = file2text(filepath) - docs2vs(docs, knowledge_base_name) + kb_file = KnowledgeFile(filename=file.filename, + knowledge_base_name=knowledge_base_name) + kb_file.file2text() + kb_file.docs2vs() return BaseResponse(code=200, msg=f"成功上传文件 {file.filename}") @@ -116,10 +117,11 @@ async def recreate_vector_store(knowledge_base_name: str): docs = (await list_docs(kb)).data for i, filename in enumerate(docs): - filepath = get_file_path(kb, filename) - print(f"processing {filepath} to vector store.") - docs = file2text(filepath) - docs2vs(docs, kb) + kb_file = KnowledgeFile(filename=filename, + knowledge_base_name=kb) + print(f"processing {get_file_path(kb, filename)} to vector store.") + kb_file.file2text() + kb_file.docs2vs() yield json.dumps({ "total": len(docs), "finished": i + 1, diff --git a/server/knowledge_base/knowledge_file.py b/server/knowledge_base/knowledge_file.py new file mode 100644 index 0000000..52a5699 --- /dev/null +++ b/server/knowledge_base/knowledge_file.py @@ -0,0 +1,48 @@ +import os.path +from 
server.knowledge_base.utils import (get_file_path, get_vs_path, + refresh_vs_cache, load_embeddings) +from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE) +from langchain.vectorstores import FAISS +from server.utils import torch_gc + + +class KnowledgeFile: + def __init__( + self, + filename: str, + knowledge_base_name: str + ): + self.knowledge_base_name = knowledge_base_name + self.knowledge_base_type = "faiss" + self.filename = filename + self.ext = os.path.splitext(filename)[-1] + self.filepath = get_file_path(knowledge_base_name, filename) + self.docs = None + + def file2text(self): + # NOTE(review): the placeholder `if self.ext in []:` branches could never match, + # leaving `loader` unbound (NameError on every call); fall back to + # UnstructuredFileLoader for every extension until per-extension loaders are implemented. + from langchain.document_loaders import UnstructuredFileLoader + loader = UnstructuredFileLoader(self.filepath) + + from langchain.text_splitter import CharacterTextSplitter + text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200) + self.docs = loader.load_and_split(text_splitter) + return True + + def docs2vs(self): + vs_path = get_vs_path(self.knowledge_base_name) + embeddings = load_embeddings(EMBEDDING_MODEL, EMBEDDING_DEVICE)  # load_embeddings indexes embedding_model_dict itself + if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): + vector_store = FAISS.load_local(vs_path, embeddings) + vector_store.add_documents(self.docs) + torch_gc() + else: + if not os.path.exists(vs_path): + os.makedirs(vs_path) + vector_store = FAISS.from_documents(self.docs, embeddings) # docs is a list of Document + torch_gc() + vector_store.save_local(vs_path) + refresh_vs_cache(self.knowledge_base_name) + return True \ No newline at end of file diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index dd0c49f..896a8ad 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -3,7 +3,6 @@ from typing import List from server.utils import torch_gc from configs.model_config import KB_ROOT_PATH from langchain.vectorstores import FAISS -from langchain.schema import Document 
from langchain.embeddings.huggingface import HuggingFaceEmbeddings from configs.model_config import (CACHED_VS_NUM, VECTOR_SEARCH_TOP_K, embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE) @@ -36,34 +35,6 @@ def validate_kb_name(knowledge_base_id: str) -> bool: return True -def file2text(filepath): - # TODO: 替换处理方式 - from langchain.document_loaders import UnstructuredFileLoader - loader = UnstructuredFileLoader(filepath) - - from langchain.text_splitter import CharacterTextSplitter - text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200) - docs = loader.load_and_split(text_splitter) - return docs - -def docs2vs( - docs: List[Document], - knowledge_base_name: str): - vs_path = get_vs_path(knowledge_base_name) - embeddings = load_embeddings(embedding_model_dict[EMBEDDING_MODEL], EMBEDDING_DEVICE) - if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): - vector_store = FAISS.load_local(vs_path, embeddings) - vector_store.add_documents(docs) - torch_gc() - else: - if not os.path.exists(vs_path): - os.makedirs(vs_path) - vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表 - torch_gc() - vector_store.save_local(vs_path) - refresh_vs_cache(knowledge_base_name) - - @lru_cache(1) def load_embeddings(model: str, device: str): embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], @@ -106,7 +77,3 @@ def refresh_vs_cache(kb_name: str): """ _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 - -if __name__ == "__main__": - filepath = "/Users/liuqian/PycharmProjects/chatchat/knowledge_base/123/content/test.txt" - docs = file2text(filepath)