From 18d31f5116e633bd868609497a7b88f6591aba25 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sun, 6 Aug 2023 23:43:54 +0800 Subject: [PATCH] add KBService and KBServiceFactory class --- configs/model_config.py.example | 10 + server/knowledge_base/kb_service/__init__.py | 0 server/knowledge_base/kb_service/base.py | 335 ++++++++++++++++++ .../kb_service/default_kb_service.py | 27 ++ .../kb_service/faiss_kb_service.py | 125 +++++++ .../kb_service/milvus_kb_service.py | 65 ++++ server/knowledge_base/knowledge_base.py | 19 +- .../knowledge_base/knowledge_base_factory.py | 36 ++ server/knowledge_base/utils.py | 15 + 9 files changed, 615 insertions(+), 17 deletions(-) create mode 100644 server/knowledge_base/kb_service/__init__.py create mode 100644 server/knowledge_base/kb_service/base.py create mode 100644 server/knowledge_base/kb_service/default_kb_service.py create mode 100644 server/knowledge_base/kb_service/faiss_kb_service.py create mode 100644 server/knowledge_base/kb_service/milvus_kb_service.py create mode 100644 server/knowledge_base/knowledge_base_factory.py diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 6eb522d..ecbabac 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -277,3 +277,13 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search" # 此外,如果是在服务器上,报Failed to establish a new connection: [Errno 110] Connection timed out # 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG BING_SUBSCRIPTION_KEY = "" + +kbs_config = { + "faiss": { + + }, + "milvus": { + "milvus_host": "192.168.50.128", + "milvus_port": 19530 + } +} \ No newline at end of file diff --git a/server/knowledge_base/kb_service/__init__.py b/server/knowledge_base/kb_service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py new file mode 100644 index 0000000..006ca81 --- /dev/null +++ b/server/knowledge_base/kb_service/base.py @@ -0,0 +1,335 @@ +from abc import ABC, abstractmethod + +import os +import sqlite3 +from functools import lru_cache + +from langchain.embeddings import HuggingFaceEmbeddings +from langchain.embeddings.base import Embeddings +from langchain.docstore.document import Document + +from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K, + embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL) +import datetime +from server.knowledge_base.utils import (get_kb_path, get_doc_path) +from server.knowledge_base.knowledge_file import KnowledgeFile +from typing import List + + +class SupportedVSType: + FAISS = 'faiss' + MILVUS = 'milvus' + DEFAULT = 'default' + + +def init_db(): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + c.execute('''CREATE TABLE if not exists knowledge_base + (id INTEGER PRIMARY KEY AUTOINCREMENT, + kb_name TEXT, + vs_type TEXT, + embed_model TEXT, + file_count INTEGER, + create_time DATETIME) ''') + c.execute('''CREATE TABLE if not exists knowledge_files + (id INTEGER PRIMARY KEY AUTOINCREMENT, + file_name TEXT, + file_ext TEXT, + kb_name TEXT, + document_loader_name TEXT, + text_splitter_name TEXT, + file_version INTEGER, + create_time DATETIME) ''') + conn.commit() + conn.close() + + +def list_kbs_from_db(): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + c.execute(f'''SELECT kb_name + FROM knowledge_base + WHERE file_count>0 ''') + kbs = [i[0] for i in c.fetchall() if i] + conn.commit() + conn.close() + return kbs + + +def add_kb_to_db(kb_name, vs_type, embed_model): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + # Insert a row of data + c.execute(f"""INSERT INTO knowledge_base + (kb_name, vs_type, embed_model, file_count, create_time) + VALUES + ('{kb_name}','{vs_type}','{embed_model}', + 0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""") + conn.commit() + conn.close() + return True + + +def kb_exists(kb_name): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + c.execute(f'''SELECT COUNT(*) + FROM knowledge_base + WHERE kb_name="{kb_name}" ''') + status = True if c.fetchone()[0] else False + conn.commit() + conn.close() + return status + + +def load_kb_from_db(kb_name): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + c.execute(f'''SELECT kb_name, vs_type, embed_model + FROM knowledge_base + WHERE kb_name="{kb_name}" ''') + resp = c.fetchone() + if resp: + kb_name, vs_type, embed_model = resp + else: + kb_name, vs_type, embed_model = None, None, None + conn.commit() + conn.close() + return kb_name, vs_type, embed_model + + +def delete_kb_from_db(kb_name): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + c.execute(f'''DELETE + FROM knowledge_base + WHERE kb_name="{kb_name}" ''') + c.execute(f"""DELETE + FROM knowledge_files + WHERE kb_name="{kb_name}" + """) + conn.commit() + conn.close() + return True + + +def list_docs_from_db(kb_name): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + c.execute(f'''SELECT file_name + FROM knowledge_files + WHERE kb_name="{kb_name}" ''') + kbs = [i[0] for i in c.fetchall() if i] + conn.commit() + conn.close() + return kbs + + +def add_doc_to_db(kb_file: KnowledgeFile): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + # Insert a row of data + c.execute( + f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """) + record_exist = c.fetchone() + if record_exist is not None: + c.execute(f"""UPDATE knowledge_files + SET file_version = file_version + 1 + WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" + """) + else: + c.execute(f"""INSERT INTO knowledge_files + (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time) + VALUES + ('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}', + '{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""") + c.execute(f"""UPDATE knowledge_base + SET file_count = file_count + 1 + WHERE kb_name="{kb_file.kb_name}" + """) + conn.commit() + conn.close() + return True + + +def delete_file_from_db(kb_file: KnowledgeFile): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + # Insert a row of data + c.execute(f"""DELETE + FROM knowledge_files + WHERE file_name="{kb_file.filename}" + AND kb_name="{kb_file.kb_name}" + """) + c.execute(f"""UPDATE knowledge_base + SET file_count = file_count - 1 + WHERE kb_name="{kb_file.kb_name}" + """) + conn.commit() + conn.close() + return True + + +def doc_exists(kb_file: KnowledgeFile): + conn = sqlite3.connect(DB_ROOT_PATH) + c = conn.cursor() + c.execute(f'''SELECT COUNT(*) + FROM knowledge_files + WHERE file_name="{kb_file.filename}" + AND kb_name="{kb_file.kb_name}" ''') + status = True if c.fetchone()[0] else False + conn.commit() + conn.close() + return status + + +@lru_cache(1) +def load_embeddings(model: str, device: str): + embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], + model_kwargs={'device': device}) + return embeddings + + +class KBService(ABC): + + def __init__(self, + knowledge_base_name: str, + vector_store_type: str = "faiss", + embed_model: str = EMBEDDING_MODEL, + ): + self.kb_name = knowledge_base_name + self.vs_type = vector_store_type + self.embed_model = embed_model + self.kb_path = get_kb_path(self.kb_name) + self.doc_path = get_doc_path(self.kb_name) + self.do_init() + + def create_kb(self): + """ + 创建知识库 + """ + if not os.path.exists(self.doc_path): + os.makedirs(self.doc_path) + self.do_create_kb() + status = add_kb_to_db(self.kb_name, self.vs_type, self.embed_model) + return status + + def clear_vs(self): + """ + 用知识库中已上传文件重建向量库 + """ + self.do_clear_vs() + + def drop_kb(self): + """ + 删除知识库 + """ + self.do_remove_kb() + status = delete_kb_from_db(self.kb_name) + return status + + def add_doc(self, kb_file: KnowledgeFile): + """ + 向知识库添加文件 + """ + docs = kb_file.file2text() + embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE) + self.do_add_doc(docs, embeddings) + status = add_doc_to_db(kb_file) + return status + + def delete_doc(self, kb_file: KnowledgeFile): + """ + 从知识库删除文件 + """ + if os.path.exists(kb_file.filepath): + os.remove(kb_file.filepath) + self.do_delete(kb_file) + status = delete_file_from_db(kb_file) + return status + + def exist_doc(self, file_name: str): + return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, + filename=file_name)) + + def list_docs(self): + return list_docs_from_db(self.kb_name) + + def search_docs(self, + query: str, + top_k: int = VECTOR_SEARCH_TOP_K, + embedding_device: str = EMBEDDING_DEVICE, ): + embeddings = load_embeddings(self.embed_model, embedding_device) + docs = self.do_search(query, top_k, embeddings) + return docs + + @abstractmethod + def do_create_kb(self): + """ + 创建知识库子类实自己逻辑 + """ + pass + + @staticmethod + def list_kbs_type(): + return list(kbs_config.keys()) + + @classmethod + def list_kbs(cls): + return list_kbs_from_db() + + @classmethod + def exists(cls, + knowledge_base_name: str): + return kb_exists(knowledge_base_name) + + @abstractmethod + def vs_type(self) -> str: + pass + + @abstractmethod + def do_init(self): + pass + + @abstractmethod + def do_remove_kb(self): + """ + 删除知识库子类实自己逻辑 + """ + pass + + @abstractmethod + def do_search(self, + query: str, + top_k: int, + embeddings: Embeddings, + ) -> List[Document]: + """ + 搜索知识库子类实自己逻辑 + """ + pass + + @abstractmethod + def do_add_doc(self, + docs: List[Document], + embeddings: Embeddings): + """ + 向知识库添加文档子类实自己逻辑 + """ + pass + + @abstractmethod + def do_delete(self, + kb_file: KnowledgeFile): + """ + 从知识库删除文档子类实自己逻辑 + """ + pass + + @abstractmethod + def do_clear_vs(self): + """ + 从知识库删除全部向量子类实自己逻辑 + """ + pass diff --git a/server/knowledge_base/kb_service/default_kb_service.py b/server/knowledge_base/kb_service/default_kb_service.py new file mode 100644 index 0000000..3a6e0a5 --- /dev/null +++ b/server/knowledge_base/kb_service/default_kb_service.py @@ -0,0 +1,27 @@ +from server.knowledge_base.kb_service.base import KBService + + +class DefaultKBService(KBService): + def vs_type(self) -> str: + return "default" + + def do_create_kbs(self): + pass + + def do_init(self): + pass + + def do_remove_kbs(self): + pass + + def do_search(self): + pass + + def do_insert_multi_knowledge(self): + pass + + def do_insert_one_knowledge(self): + pass + + def do_delete(self): + pass diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py new file mode 100644 index 0000000..3141754 --- /dev/null +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -0,0 +1,125 @@ +import os +import shutil + +from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_DEVICE +from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings +from functools import lru_cache +from server.knowledge_base.utils import get_vs_path +from server.knowledge_base.knowledge_file import KnowledgeFile +from langchain.vectorstores import FAISS +from langchain.embeddings.base import Embeddings +from typing import List +from langchain.docstore.document import Document +from server.utils import torch_gc +import numpy as np + +_VECTOR_STORE_TICKS = {} + + +@lru_cache(CACHED_VS_NUM) +def load_vector_store( + knowledge_base_name: str, + embeddings: Embeddings, + tick: int, # tick will be changed by upload_doc etc. and make cache refreshed. +): + print(f"loading vector store in '{knowledge_base_name}'.") + vs_path = get_vs_path(knowledge_base_name) + search_index = FAISS.load_local(vs_path, embeddings) + return search_index + + +def refresh_vs_cache(kb_name: str): + """ + make vector store cache refreshed when next loading + """ + _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 + + +def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]): + overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values()) + if not overlapping: + raise ValueError("ids do not exist in the current object") + _reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()} + index_to_delete = [_reversed_index[i] for i in ids] + vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64)) + for _id in index_to_delete: + del vector_store.index_to_docstore_id[_id] + # Remove items from docstore. + overlapping2 = set(ids).intersection(vector_store.docstore._dict) + if not overlapping2: + raise ValueError(f"Tried to delete ids that does not exist: {ids}") + for _id in ids: + vector_store.docstore._dict.pop(_id) + return vector_store + + +class FaissKBService(KBService): + vs_path: str + kb_path: str + + def vs_type(self) -> str: + return SupportedVSType.FAISS + + @staticmethod + def get_vs_path(knowledge_base_name: str): + return os.path.join(FaissKBService.get_kb_path(knowledge_base_name), "vector_store") + + @staticmethod + def get_kb_path(knowledge_base_name: str): + return os.path.join(KB_ROOT_PATH, knowledge_base_name) + + def do_init(self): + self.kb_path = FaissKBService.get_kb_path(self.kb_name) + self.vs_path = FaissKBService.get_vs_path(self.kb_name) + + def do_create_kb(self): + if not os.path.exists(self.vs_path): + os.makedirs(self.vs_path) + + def do_remove_kb(self): + shutil.rmtree(self.kb_path) + + def do_search(self, + query: str, + top_k: int, + embeddings: Embeddings, + ) -> List[Document]: + search_index = load_vector_store(self.kb_name, + embeddings, + _VECTOR_STORE_TICKS.get(self.kb_name)) + docs = search_index.similarity_search(query, k=top_k) + return docs + + def do_add_doc(self, + docs: List[Document], + embeddings: Embeddings): + if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): + vector_store = FAISS.load_local(self.vs_path, embeddings) + vector_store.add_documents(docs) + torch_gc() + else: + if not os.path.exists(self.vs_path): + os.makedirs(self.vs_path) + vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表 + torch_gc() + vector_store.save_local(self.vs_path) + refresh_vs_cache(self.kb_name) + + def do_delete(self, + kb_file: KnowledgeFile): + embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE) + if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): + vector_store = FAISS.load_local(self.vs_path, embeddings) + ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] + if len(ids) == 0: + return None + print(len(ids)) + vector_store = delete_doc_from_faiss(vector_store, ids) + vector_store.save_local(self.vs_path) + refresh_vs_cache(self.kb_name) + return True + else: + return None + + def do_clear_vs(self): + shutil.rmtree(self.vs_path) diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py new file mode 100644 index 0000000..5d43633 --- /dev/null +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -0,0 +1,65 @@ +from pymilvus import ( + connections, + utility, + FieldSchema, + CollectionSchema, + DataType, + Collection, +) + +from server.knowledge_base.kb_service.base import KBService + + +def get_collection(milvus_name): + return Collection(milvus_name) + + +def search(milvus_name, content, limit=3): + search_params = { + "metric_type": "L2", + "params": {"nprobe": 10}, + } + c = get_collection(milvus_name) + return c.search(content, "embeddings", search_params, limit=limit, output_fields=["random"]) + + +class MilvusKBService(): + milvus_host: str + milvus_port: int + dim: int + + def __init__(self, knowledge_base_name: str, vector_store_type: str, milvus_host="localhost", milvus_port=19530, + dim=8): + + super().__init__(knowledge_base_name, vector_store_type) + self.milvus_host = milvus_host + self.milvus_port = milvus_port + self.dim = dim + + def connect(self): + connections.connect("default", host=self.milvus_host, port=self.milvus_port) + + def create_collection(self, milvus_name): + fields = [ + FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=False), + FieldSchema(name="content", dtype=DataType.STRING), + FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.dim) + ] + schema = CollectionSchema(fields) + collection = Collection(milvus_name, schema) + index = { + "index_type": "IVF_FLAT", + "metric_type": "L2", + "params": {"nlist": 128}, + } + collection.create_index("embeddings", index) + collection.load() + return collection + + def insert_collection(self, milvus_name, content=[]): + get_collection(milvus_name).insert(dataset) + + +if __name__ == '__main__': + milvusService = MilvusService(milvus_host='192.168.50.128') + milvusService.insert_collection(test,dataset) diff --git a/server/knowledge_base/knowledge_base.py b/server/knowledge_base/knowledge_base.py index b4d5311..10ccd9d 100644 --- a/server/knowledge_base/knowledge_base.py +++ b/server/knowledge_base/knowledge_base.py @@ -5,33 +5,18 @@ import shutil from langchain.vectorstores import FAISS from langchain.embeddings.huggingface import HuggingFaceEmbeddings from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE, - KB_ROOT_PATH, DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM) + DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM) from server.utils import torch_gc from functools import lru_cache from server.knowledge_base.knowledge_file import KnowledgeFile from typing import List import numpy as np +from server.knowledge_base.utils import (get_kb_path, get_doc_path, get_vs_path) SUPPORTED_VS_TYPES = ["faiss", "milvus"] _VECTOR_STORE_TICKS = {} - -def get_kb_path(knowledge_base_name: str): - return os.path.join(KB_ROOT_PATH, knowledge_base_name) - - -def get_doc_path(knowledge_base_name: str): - return os.path.join(get_kb_path(knowledge_base_name), "content") - - -def get_vs_path(knowledge_base_name: str): - return os.path.join(get_kb_path(knowledge_base_name), "vector_store") - - -def get_file_path(knowledge_base_name: str, doc_name: str): - return os.path.join(get_doc_path(knowledge_base_name), doc_name) - @lru_cache(1) def load_embeddings(model: str, device: str): embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], diff --git a/server/knowledge_base/knowledge_base_factory.py b/server/knowledge_base/knowledge_base_factory.py new file mode 100644 index 0000000..793074e --- /dev/null +++ b/server/knowledge_base/knowledge_base_factory.py @@ -0,0 +1,36 @@ +from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db +from server.knowledge_base.kb_service.default_kb_service import DefaultKBService + + +class KBServiceFactory: + + @staticmethod + def get_service(kb_name: str, + vector_store_type: SupportedVSType + ) -> KBService: + if SupportedVSType.FAISS == vector_store_type: + from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService + return FaissKBService(kb_name) + elif SupportedVSType.DEFAULT == vector_store_type: + return DefaultKBService(kb_name) + + @staticmethod + def get_service_by_name(kb_name: str + ) -> KBService: + kb_name, vs_type = load_kb_from_db(kb_name) + return KBServiceFactory.get_service(kb_name, vs_type) + + @staticmethod + def get_default(): + return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT) + + +if __name__ == '__main__': + KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS) + init_db() + KBService.create_kbs() + KBService = KBServiceFactory.get_default() + print(KBService.list_kbs()) + KBService = KBServiceFactory.get_service_by_name("test") + print(KBService.list_docs()) + KBService.drop_kbs() diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 5773f71..215ab36 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,5 +1,20 @@ +import os.path +from configs.model_config import KB_ROOT_PATH + def validate_kb_name(knowledge_base_id: str) -> bool: # 检查是否包含预期外的字符或路径攻击关键字 if "../" in knowledge_base_id: return False return True + +def get_kb_path(knowledge_base_name: str): + return os.path.join(KB_ROOT_PATH, knowledge_base_name) + +def get_doc_path(knowledge_base_name: str): + return os.path.join(get_kb_path(knowledge_base_name), "content") + +def get_vs_path(knowledge_base_name: str): + return os.path.join(get_kb_path(knowledge_base_name), "vector_store") + +def get_file_path(knowledge_base_name: str, doc_name: str): + return os.path.join(get_doc_path(knowledge_base_name), doc_name) \ No newline at end of file