diff --git a/configs/config.py b/configs/config.py new file mode 100644 index 0000000..9f1f45a --- /dev/null +++ b/configs/config.py @@ -0,0 +1,12 @@ +SQLALCHEMY_DATABASE_URI = "sqlite:///./langchain_chat_glm.db" +kbs_config = { + "faiss": { + }, + "milvus": { + "host": "127.0.0.1", + "port": "19530", + "user": "", + "password": "", + "secure": False, + } +} diff --git a/configs/model_config.py.example b/configs/model_config.py.example index b73ccd9..6f096b6 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -278,15 +278,3 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search" # 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG BING_SUBSCRIPTION_KEY = "" -kbs_config = { - "faiss": { - }, - "milvus": { - "host": "127.0.0.1", - "port": "19530", - "user": "", - "password": "", - "secure": False, - } -} -DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") \ No newline at end of file diff --git a/server/db/__init__.py b/server/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/db/base.py b/server/db/base.py new file mode 100644 index 0000000..9a74a72 --- /dev/null +++ b/server/db/base.py @@ -0,0 +1,12 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +from configs.config import SQLALCHEMY_DATABASE_URI + +engine = create_engine(SQLALCHEMY_DATABASE_URI) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() + diff --git a/server/db/models/__init__.py b/server/db/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/db/models/base.py b/server/db/models/base.py new file mode 100644 index 0000000..706464f --- /dev/null +++ b/server/db/models/base.py @@ -0,0 +1,13 @@ +from datetime import datetime +from sqlalchemy import Column, DateTime, String, Integer + + +class BaseModel: + """ + 基础模型 + """ + id = Column(Integer, primary_key=True, index=True, comment="主键ID") + create_time = Column(DateTime, default=datetime.utcnow, comment="创建时间") + update_time = Column(DateTime, default=None, onupdate=datetime.utcnow, comment="更新时间") + create_by = Column(String, default=None, comment="创建者") + update_by = Column(String, default=None, comment="更新者") diff --git a/server/db/models/knowledge_base_model.py b/server/db/models/knowledge_base_model.py new file mode 100644 index 0000000..6f9e9ca --- /dev/null +++ b/server/db/models/knowledge_base_model.py @@ -0,0 +1,19 @@ +from sqlalchemy import Column, Integer, String, DateTime + +from server.db.base import Base + + +class KnowledgeBaseModel(Base): + """ + 知识库模型 + """ + __tablename__ = 'knowledge_base' + id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID') + kb_name = Column(String, comment='知识库名称') + vs_type = Column(String, comment='嵌入模型类型') + embed_model = Column(String, comment='嵌入模型名称') + file_count = Column(Integer, comment='文件数量') + create_time = Column(DateTime, comment='创建时间') + + def __repr__(self): + return f"" diff --git a/server/db/models/knowledge_file_model.py b/server/db/models/knowledge_file_model.py new file mode 100644 index 0000000..b6798cc --- /dev/null +++ b/server/db/models/knowledge_file_model.py @@ -0,0 +1,21 @@ +from sqlalchemy import Column, Integer, String, DateTime + +from server.db.base import Base + + +class KnowledgeFileModel(Base): + """ + 知识文件模型 + """ + __tablename__ = 'knowledge_file' + id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID') + file_name = Column(String, comment='文件名') + file_ext = Column(String, comment='文件扩展名') + kb_name = Column(String, comment='所属知识库名称') + document_loader_name = Column(String, comment='文档加载器名称') + text_splitter_name = Column(String, comment='文本分割器名称') + file_version = Column(Integer, comment='文件版本') + create_time = Column(DateTime, comment='创建时间') + + def __repr__(self): + return f"" diff --git a/server/db/repository/__init__.py b/server/db/repository/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/db/repository/knowledge_base_repository.py b/server/db/repository/knowledge_base_repository.py new file mode 100644 index 0000000..aa4df4f --- /dev/null +++ b/server/db/repository/knowledge_base_repository.py @@ -0,0 +1,42 @@ +from server.db.models.knowledge_base_model import KnowledgeBaseModel +from server.db.session import with_session + + +@with_session +def add_kb_to_db(session, kb_name, vs_type, embed_model): + # 创建知识库实例 + kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model) + session.add(kb) + return True + + +@with_session +def list_kbs_from_db(session): + kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > 0).all() + kbs = [kb[0] for kb in kbs] + return kbs + + +@with_session +def kb_exists(session, kb_name): + kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() + status = True if kb else False + return status + + +@with_session +def load_kb_from_db(session, kb_name): + kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() + if kb: + kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model + else: + kb_name, vs_type, embed_model = None, None, None + return kb_name, vs_type, embed_model + + +@with_session +def delete_kb_from_db(session, kb_name): + kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() + if kb: + session.delete(kb) + return True diff --git a/server/db/repository/knowledge_file_repository.py b/server/db/repository/knowledge_file_repository.py new file mode 100644 index 0000000..ba6eb04 --- /dev/null +++ b/server/db/repository/knowledge_file_repository.py @@ -0,0 +1,49 @@ +from server.db.models.knowledge_base_model import KnowledgeBaseModel +from server.db.models.knowledge_file_model import KnowledgeFileModel +from server.db.session import with_session +from server.knowledge_base.knowledge_file import KnowledgeFile + + +@with_session +def list_docs_from_db(session, kb_name): + files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all() + docs = [f.file_name for f in files] + return docs + + +@with_session +def add_doc_to_db(session, kb_file: KnowledgeFile): + kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() + if kb: + # 如果已经存在该文件,则更新文件版本号 + existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename, + kb_name=kb_file.kb_name).first() + if existing_file: + existing_file.file_version += 1 + # 否则,添加新文件 + else: + session.add(kb_file) + kb.file_count += 1 + return True + + +@with_session +def delete_file_from_db(session, kb_file: KnowledgeFile): + existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename, + kb_name=kb_file.kb_name).first() + if existing_file: + session.delete(existing_file) + session.commit() + + kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() + if kb: + kb.file_count -= 1 + session.commit() + return True + + +@with_session +def doc_exists(session, kb_file: KnowledgeFile): + existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename, + kb_name=kb_file.kb_name).first() + return True if existing_file else False diff --git a/server/db/session.py b/server/db/session.py new file mode 100644 index 0000000..9ac50e6 --- /dev/null +++ b/server/db/session.py @@ -0,0 +1,48 @@ +from functools import wraps + +from sqlalchemy.orm import sessionmaker +from contextlib import contextmanager + +from server.db.base import engine, SessionLocal + + +@contextmanager +def session_scope(): + """上下文管理器用于自动获取 Session, 避免错误""" + session = SessionLocal() + try: + yield session + session.commit() + except: + session.rollback() + raise + finally: + session.close() + + +def with_session(f): + @wraps(f) + def wrapper(*args, **kwargs): + with session_scope() as session: + try: + result = f(session, *args, **kwargs) + session.commit() + return result + except: + session.rollback() + raise + + return wrapper + + +def get_db() -> SessionLocal: + db = SessionLocal() + try: + yield db + finally: + db.close() + + +def get_db0() -> SessionLocal: + db = SessionLocal() + return db diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 89f3c3e..11318af 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -8,9 +8,13 @@ from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document -from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K, +from configs.config import kbs_config +from configs.model_config import (VECTOR_SEARCH_TOP_K, embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL) -import datetime + +from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists +from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \ + list_docs_from_db from server.knowledge_base.utils import (get_kb_path, get_doc_path) from server.knowledge_base.knowledge_file import KnowledgeFile from typing import List @@ -22,171 +26,11 @@ class SupportedVSType: DEFAULT = 'default' -def init_db(): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute('''CREATE TABLE if not exists knowledge_base - (id INTEGER PRIMARY KEY AUTOINCREMENT, - kb_name TEXT, - vs_type TEXT, - embed_model TEXT, - file_count INTEGER, - create_time DATETIME) ''') - c.execute('''CREATE TABLE if not exists knowledge_files - (id INTEGER PRIMARY KEY AUTOINCREMENT, - file_name TEXT, - file_ext TEXT, - kb_name TEXT, - document_loader_name TEXT, - text_splitter_name TEXT, - file_version INTEGER, - create_time DATETIME) ''') - conn.commit() - conn.close() - - -def list_kbs_from_db(): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute(f'''SELECT kb_name - FROM knowledge_base - WHERE file_count>0 ''') - kbs = [i[0] for i in c.fetchall() if i] - conn.commit() - conn.close() - return kbs - - -def add_kb_to_db(kb_name, vs_type, embed_model): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - # Insert a row of data - c.execute(f"""INSERT INTO knowledge_base - (kb_name, vs_type, embed_model, file_count, create_time) - VALUES - ('{kb_name}','{vs_type}','{embed_model}', - 0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""") - conn.commit() - conn.close() - return True - - -def kb_exists(kb_name): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute(f'''SELECT COUNT(*) - FROM knowledge_base - WHERE kb_name="{kb_name}" ''') - status = True if c.fetchone()[0] else False - conn.commit() - conn.close() - return status - - -def load_kb_from_db(kb_name): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute(f'''SELECT kb_name, vs_type, embed_model - FROM knowledge_base - WHERE kb_name="{kb_name}" ''') - resp = c.fetchone() - if resp: - kb_name, vs_type, embed_model = resp - else: - kb_name, vs_type, embed_model = None, None, None - conn.commit() - conn.close() - return kb_name, vs_type, embed_model - - -def delete_kb_from_db(kb_name): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute(f'''DELETE - FROM knowledge_base - WHERE kb_name="{kb_name}" ''') - c.execute(f"""DELETE - FROM knowledge_files - WHERE kb_name="{kb_name}" - """) - conn.commit() - conn.close() - return True - - -def list_docs_from_db(kb_name): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute(f'''SELECT file_name - FROM knowledge_files - WHERE kb_name="{kb_name}" ''') - kbs = [i[0] for i in c.fetchall() if i] - conn.commit() - conn.close() - return kbs - def list_docs_from_folder(kb_name: str): doc_path = get_doc_path(kb_name) return [file for file in os.listdir(doc_path) if os.path.isfile(os.path.join(doc_path, file))] -def add_doc_to_db(kb_file: KnowledgeFile): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - # Insert a row of data - c.execute( - f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """) - record_exist = c.fetchone() - if record_exist is not None: - c.execute(f"""UPDATE knowledge_files - SET file_version = file_version + 1 - WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" - """) - else: - c.execute(f"""INSERT INTO knowledge_files - (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time) - VALUES - ('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}', - '{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""") - c.execute(f"""UPDATE knowledge_base - SET file_count = file_count + 1 - WHERE kb_name="{kb_file.kb_name}" - """) - conn.commit() - conn.close() - return True - - -def delete_file_from_db(kb_file: KnowledgeFile): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - # Insert a row of data - c.execute(f"""DELETE - FROM knowledge_files - WHERE file_name="{kb_file.filename}" - AND kb_name="{kb_file.kb_name}" - """) - c.execute(f"""UPDATE knowledge_base - SET file_count = file_count - 1 - WHERE kb_name="{kb_file.kb_name}" - """) - conn.commit() - conn.close() - return True - - -def doc_exists(kb_file: KnowledgeFile): - conn = sqlite3.connect(DB_ROOT_PATH) - c = conn.cursor() - c.execute(f'''SELECT COUNT(*) - FROM knowledge_files - WHERE file_name="{kb_file.filename}" - AND kb_name="{kb_file.kb_name}" ''') - status = True if c.fetchone()[0] else False - conn.commit() - conn.close() - return status - @lru_cache(1) def load_embeddings(model: str, device: str): @@ -323,7 +167,7 @@ class KBService(ABC): @abstractmethod def do_delete_doc(self, - kb_file: KnowledgeFile): + kb_file: KnowledgeFile): """ 从知识库删除文档子类实自己逻辑 """ diff --git a/server/knowledge_base/kb_service/default_kb_service.py b/server/knowledge_base/kb_service/default_kb_service.py index fb31859..57a5992 100644 --- a/server/knowledge_base/kb_service/default_kb_service.py +++ b/server/knowledge_base/kb_service/default_kb_service.py @@ -1,7 +1,24 @@ +from typing import List + +from langchain.embeddings.base import Embeddings +from langchain.schema import Document + from server.knowledge_base.kb_service.base import KBService class DefaultKBService(KBService): + def do_create_kb(self): + pass + + def do_drop_kb(self): + pass + + def do_add_doc(self, docs: List[Document], embeddings: Embeddings): + pass + + def do_clear_vs(self): + pass + def vs_type(self) -> str: return "default" diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 7dcceb4..e4d3787 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -4,9 +4,10 @@ from langchain.embeddings.base import Embeddings from langchain.schema import Document from langchain.vectorstores import Milvus -from configs.model_config import EMBEDDING_DEVICE, kbs_config +from configs.config import kbs_config +from configs.model_config import EMBEDDING_DEVICE from server.knowledge_base import KnowledgeFile -from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings, add_doc_to_db +from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings class MilvusKBService(KBService): @@ -55,6 +56,7 @@ class MilvusKBService(KBService): """ docs = kb_file.file2text() self.milvus.add_documents(docs) + from server.db.repository.knowledge_file_repository import add_doc_to_db status = add_doc_to_db(kb_file) return status @@ -72,8 +74,11 @@ class MilvusKBService(KBService): if __name__ == '__main__': + # 测试建表使用 + from server.db.base import Base, engine + Base.metadata.create_all(bind=engine) milvusService = MilvusKBService("test") - # milvusService.add_doc(KnowledgeFile("test.pdf", "test")) - # milvusService.delete_doc(KnowledgeFile("test.pdf", "test")) + milvusService.add_doc(KnowledgeFile("test.pdf", "test")) + milvusService.delete_doc(KnowledgeFile("test.pdf", "test")) milvusService.do_drop_kb() print(milvusService.search_docs("测试")) diff --git a/server/knowledge_base/knowledge_base_factory.py b/server/knowledge_base/knowledge_base_factory.py index 7d61748..b45b6d6 100644 --- a/server/knowledge_base/knowledge_base_factory.py +++ b/server/knowledge_base/knowledge_base_factory.py @@ -1,6 +1,6 @@ -from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db +from server.knowledge_base.kb_service.base import KBService, SupportedVSType +from server.db.repository.knowledge_base_repository import load_kb_from_db from server.knowledge_base.kb_service.default_kb_service import DefaultKBService -from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService class KBServiceFactory: @@ -21,7 +21,7 @@ class KBServiceFactory: @staticmethod def get_service_by_name(kb_name: str ) -> KBService: - kb_name, vs_type = load_kb_from_db(kb_name) + kb_name, vs_type, _ = load_kb_from_db(kb_name) return KBServiceFactory.get_service(kb_name, vs_type) @staticmethod @@ -30,8 +30,10 @@ class KBServiceFactory: if __name__ == '__main__': + # 测试建表使用 + # from server.db.base import Base, engine + # Base.metadata.create_all(bind=engine) KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS) - init_db() KBService.create_kb() KBService = KBServiceFactory.get_default() print(KBService.list_kbs())