"""Knowledge-base bookkeeping in SQLite plus the abstract ``KBService`` base.

The module-level functions maintain two tables in the SQLite database located
at ``DB_ROOT_PATH``: ``knowledge_base`` (one row per knowledge base) and
``knowledge_files`` (one row per file added to a knowledge base).
``KBService`` combines that bookkeeping with vector-store hooks that concrete
subclasses implement.
"""
from abc import ABC, abstractmethod
import datetime
import os
import sqlite3
from contextlib import contextmanager
from functools import lru_cache
from typing import Iterator, List, Optional, Tuple

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document

from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K,
                                  embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL)
from server.knowledge_base.utils import (get_kb_path, get_doc_path)
from server.knowledge_base.knowledge_file import KnowledgeFile


class SupportedVSType:
    """Names of the vector-store backends referenced by this module."""

    FAISS = 'faiss'
    MILVUS = 'milvus'
    DEFAULT = 'default'


@contextmanager
def _db_cursor() -> Iterator[sqlite3.Cursor]:
    """Yield a cursor on the database at ``DB_ROOT_PATH``.

    The transaction is committed when the ``with`` body finishes normally and
    rolled back when it raises; the connection is closed in both cases.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        # A sqlite3 connection used as a context manager commits on success
        # and rolls back on exception, but does not close itself.
        with conn:
            yield conn.cursor()
    finally:
        conn.close()


def _timestamp() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def init_db():
    """Create the ``knowledge_base`` and ``knowledge_files`` tables if absent."""
    with _db_cursor() as c:
        c.execute('''CREATE TABLE if not exists knowledge_base
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      kb_name TEXT,
                      vs_type TEXT,
                      embed_model TEXT,
                      file_count INTEGER,
                      create_time DATETIME)
                  ''')
        c.execute('''CREATE TABLE if not exists knowledge_files
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      file_name TEXT,
                      file_ext TEXT,
                      kb_name TEXT,
                      document_loader_name TEXT,
                      text_splitter_name TEXT,
                      file_version INTEGER,
                      create_time DATETIME)
                  ''')


def list_kbs_from_db() -> List[str]:
    """Return the names of all knowledge bases whose ``file_count`` is > 0."""
    with _db_cursor() as c:
        c.execute('''SELECT kb_name
                     FROM knowledge_base
                     WHERE file_count>0 ''')
        return [row[0] for row in c.fetchall()]


def add_kb_to_db(kb_name, vs_type, embed_model) -> bool:
    """Insert a new ``knowledge_base`` row with a file count of zero.

    Args:
        kb_name: Name of the knowledge base.
        vs_type: Vector-store type recorded for the knowledge base.
        embed_model: Embedding model name recorded for the knowledge base.

    Returns:
        Always ``True``; database errors propagate as exceptions.
    """
    with _db_cursor() as c:
        # Values are bound as parameters so that names containing quotes can
        # neither break the statement nor inject SQL.
        c.execute(
            """INSERT INTO knowledge_base
               (kb_name, vs_type, embed_model, file_count, create_time)
               VALUES (?, ?, ?, 0, ?)""",
            (kb_name, vs_type, embed_model, _timestamp()),
        )
    return True


def kb_exists(kb_name) -> bool:
    """Return ``True`` if a ``knowledge_base`` row named *kb_name* exists."""
    with _db_cursor() as c:
        c.execute('''SELECT COUNT(*)
                     FROM knowledge_base
                     WHERE kb_name=? ''', (kb_name,))
        return bool(c.fetchone()[0])


def load_kb_from_db(kb_name) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Look up a knowledge base by name.

    Returns:
        ``(kb_name, vs_type, embed_model)`` taken from the first matching row,
        or ``(None, None, None)`` when no row matches.
    """
    with _db_cursor() as c:
        c.execute('''SELECT kb_name, vs_type, embed_model
                     FROM knowledge_base
                     WHERE kb_name=? ''', (kb_name,))
        row = c.fetchone()
    if row:
        name, vs_type, embed_model = row
        return name, vs_type, embed_model
    return None, None, None


def delete_kb_from_db(kb_name) -> bool:
    """Delete a knowledge base row and all of its ``knowledge_files`` rows.

    Returns:
        Always ``True``; database errors propagate as exceptions.
    """
    with _db_cursor() as c:
        c.execute('''DELETE
                     FROM knowledge_base
                     WHERE kb_name=? ''', (kb_name,))
        c.execute("""DELETE
                     FROM knowledge_files
                     WHERE kb_name=? """, (kb_name,))
    return True


def list_docs_from_db(kb_name) -> List[str]:
    """Return the file names recorded for the knowledge base *kb_name*."""
    with _db_cursor() as c:
        c.execute('''SELECT file_name
                     FROM knowledge_files
                     WHERE kb_name=? ''', (kb_name,))
        return [row[0] for row in c.fetchall()]


def add_doc_to_db(kb_file: KnowledgeFile) -> bool:
    """Record that *kb_file* was added to its knowledge base.

    If a row with the same file name and knowledge base already exists, its
    ``file_version`` is incremented. Otherwise a new row is inserted with
    version 0 and the owning knowledge base's ``file_count`` is incremented.

    Attribute values are bound as SQL parameters, so an attribute that is
    ``None`` is stored as NULL.

    Returns:
        Always ``True``; database errors propagate as exceptions.
    """
    with _db_cursor() as c:
        c.execute(
            """SELECT 1 FROM knowledge_files
               WHERE file_name=? AND kb_name=? """,
            (kb_file.filename, kb_file.kb_name),
        )
        if c.fetchone() is not None:
            c.execute(
                """UPDATE knowledge_files
                   SET file_version = file_version + 1
                   WHERE file_name=? AND kb_name=? """,
                (kb_file.filename, kb_file.kb_name),
            )
        else:
            c.execute(
                """INSERT INTO knowledge_files
                   (file_name, file_ext, kb_name, document_loader_name,
                    text_splitter_name, file_version, create_time)
                   VALUES (?, ?, ?, ?, ?, 0, ?)""",
                (kb_file.filename, kb_file.ext, kb_file.kb_name,
                 kb_file.document_loader_name, kb_file.text_splitter_name,
                 _timestamp()),
            )
            # Only a newly inserted file changes the number of files; a
            # re-added file merely gets a new version.
            c.execute(
                """UPDATE knowledge_base
                   SET file_count = file_count + 1
                   WHERE kb_name=? """,
                (kb_file.kb_name,),
            )
    return True


def delete_file_from_db(kb_file: KnowledgeFile) -> bool:
    """Remove the record of *kb_file* from its knowledge base.

    The owning knowledge base's ``file_count`` is decremented only when a
    ``knowledge_files`` row was actually deleted.

    Returns:
        Always ``True``; database errors propagate as exceptions.
    """
    with _db_cursor() as c:
        c.execute(
            """DELETE FROM knowledge_files
               WHERE file_name=? AND kb_name=? """,
            (kb_file.filename, kb_file.kb_name),
        )
        # rowcount is the number of rows the DELETE removed; skipping the
        # decrement when nothing matched keeps file_count from drifting.
        if c.rowcount > 0:
            c.execute(
                """UPDATE knowledge_base
                   SET file_count = file_count - 1
                   WHERE kb_name=? """,
                (kb_file.kb_name,),
            )
    return True


def doc_exists(kb_file: KnowledgeFile) -> bool:
    """Return ``True`` if *kb_file* is recorded in ``knowledge_files``."""
    with _db_cursor() as c:
        c.execute(
            '''SELECT COUNT(*)
               FROM knowledge_files
               WHERE file_name=? AND kb_name=? ''',
            (kb_file.filename, kb_file.kb_name),
        )
        return bool(c.fetchone()[0])


@lru_cache(1)
def load_embeddings(model: str, device: str):
    """Build a ``HuggingFaceEmbeddings`` for *model* on *device*.

    *model* is a key of ``embedding_model_dict``. Only the most recent
    ``(model, device)`` combination is cached (``lru_cache(1)``).
    """
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
                                       model_kwargs={'device': device})
    return embeddings


class KBService(ABC):
    """Abstract base for a knowledge base backed by a vector store.

    The public methods combine the SQLite bookkeeping functions of this module
    with the ``do_*`` hooks, which subclasses implement for their vector store.
    """

    def __init__(self,
                 knowledge_base_name: str,
                 vector_store_type: str = "faiss",
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        """Store the configuration, resolve paths and call ``do_init``.

        Args:
            knowledge_base_name: Name of the knowledge base.
            vector_store_type: Vector-store type recorded for this instance.
            embed_model: Key of ``embedding_model_dict`` to embed with.
        """
        self.kb_name = knowledge_base_name
        # NOTE(review): this instance attribute shadows the abstract
        # ``vs_type`` method declared below; kept as-is because subclasses
        # may rely on it -- confirm before changing.
        self.vs_type = vector_store_type
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        self.do_init()

    def create_kb(self):
        """Create the knowledge base.

        Creates the document directory if it is missing, runs the subclass
        hook ``do_create_kb`` and records the knowledge base in the database.

        Returns:
            The result of ``add_kb_to_db``.
        """
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)
        self.do_create_kb()
        status = add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
        return status

    def clear_vs(self):
        """Clear the vector store by delegating to ``do_clear_vs``."""
        self.do_clear_vs()

    def drop_kb(self):
        """Delete the knowledge base.

        Runs the subclass hook ``do_remove_kb`` and then removes the
        knowledge base and its file records from the database.

        Returns:
            The result of ``delete_kb_from_db``.
        """
        self.do_remove_kb()
        status = delete_kb_from_db(self.kb_name)
        return status

    def add_doc(self, kb_file: KnowledgeFile):
        """Add a file to the knowledge base.

        Converts the file to documents with ``kb_file.file2text()``, hands
        them to ``do_add_doc`` together with the embeddings, and records the
        file in the database.

        Returns:
            The result of ``add_doc_to_db``.
        """
        docs = kb_file.file2text()
        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
        self.do_add_doc(docs, embeddings)
        status = add_doc_to_db(kb_file)
        return status

    def delete_doc(self, kb_file: KnowledgeFile):
        """Delete a file from the knowledge base.

        Removes the file at ``kb_file.filepath`` if it exists, runs the
        subclass hook ``do_delete`` and removes the file's database record.

        Returns:
            The result of ``delete_file_from_db``.
        """
        if os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        self.do_delete(kb_file)
        status = delete_file_from_db(kb_file)
        return status

    def exist_doc(self, file_name: str):
        """Return whether *file_name* is recorded for this knowledge base."""
        return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
                                        filename=file_name))

    def list_docs(self):
        """Return the file names recorded for this knowledge base."""
        return list_docs_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    embedding_device: str = EMBEDDING_DEVICE,
                    ):
        """Search the knowledge base for *query*.

        Args:
            query: The search text.
            top_k: Passed through to ``do_search``.
            embedding_device: Device on which the embedding model is loaded.

        Returns:
            Whatever the subclass hook ``do_search`` returns.
        """
        embeddings = load_embeddings(self.embed_model, embedding_device)
        docs = self.do_search(query, top_k, embeddings)
        return docs

    @abstractmethod
    def do_create_kb(self):
        """Create the knowledge base; subclasses implement their own logic."""
        pass

    @staticmethod
    def list_kbs_type():
        """Return the keys of ``kbs_config`` as a list."""
        return list(kbs_config.keys())

    @classmethod
    def list_kbs(cls):
        """Return the names of knowledge bases that contain files."""
        return list_kbs_from_db()

    @classmethod
    def exists(cls,
               knowledge_base_name: str):
        """Return whether *knowledge_base_name* is recorded in the database."""
        return kb_exists(knowledge_base_name)

    @abstractmethod
    def vs_type(self) -> str:
        # Must be overridden for a subclass to be instantiable; see the
        # NOTE(review) in __init__ about the instance attribute of the
        # same name.
        pass

    @abstractmethod
    def do_init(self):
        """Subclass initialisation hook, called at the end of ``__init__``."""
        pass

    @abstractmethod
    def do_remove_kb(self):
        """Delete the knowledge base; subclasses implement their own logic."""
        pass

    @abstractmethod
    def do_search(self,
                  query: str,
                  top_k: int,
                  embeddings: Embeddings,
                  ) -> List[Document]:
        """Search the knowledge base; subclasses implement their own logic."""
        pass

    @abstractmethod
    def do_add_doc(self,
                   docs: List[Document],
                   embeddings: Embeddings):
        """Add documents to the knowledge base; subclasses implement their own logic."""
        pass

    @abstractmethod
    def do_delete(self,
                  kb_file: KnowledgeFile):
        """Delete a file's documents from the knowledge base; subclasses implement their own logic."""
        pass

    @abstractmethod
    def do_clear_vs(self):
        """Delete all vectors from the knowledge base; subclasses implement their own logic."""
        pass