From 554122f60ea8444b6c3b3ce1eae87fb2231803bd Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Thu, 2 Nov 2023 14:46:39 +0800 Subject: [PATCH] Feat (#1951) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 知识库支持子目录(不包括temp和tmp开头的目录),文件相对路径总长度不可超过255 * init_database.py 增加 --import-db 参数,在版本升级时,如果 info.db 表结构发生变化,但向量库无需重建,可以在重建数据库后,使用本参数从旧的数据库中导入信息 --- init_database.py | 9 ++++++- server/db/base.py | 4 +-- server/db/session.py | 3 ++- server/knowledge_base/migrate.py | 45 ++++++++++++++++++++++++++++++-- 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/init_database.py b/init_database.py index 0b15b1e..9af7747 100644 --- a/init_database.py +++ b/init_database.py @@ -1,6 +1,7 @@ import sys sys.path.append(".") -from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files +from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db, + folder2db, prune_db_docs, prune_folder_files) from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -28,6 +29,10 @@ if __name__ == "__main__": action="store_true", help=("drop the database tables before recreate vector stores") ) + parser.add_argument( + "--import-db", + help="import tables from specified sqlite database" + ) parser.add_argument( "-u", "--update-in-db", @@ -97,6 +102,8 @@ if __name__ == "__main__": if args.recreate_vs: print("recreating all vector stores") folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model) + elif args.import_db: + import_from_db(args.import_db) elif args.update_in_db: folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model) elif args.increament: diff --git a/server/db/base.py b/server/db/base.py index ae42ac0..e4fe754 100644 --- a/server/db/base.py +++ b/server/db/base.py @@ -1,5 +1,5 @@ from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta from sqlalchemy.orm import sessionmaker from configs import SQLALCHEMY_DATABASE_URI @@ -13,4 +13,4 @@ engine = create_engine( SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base = declarative_base() +Base: DeclarativeMeta = declarative_base() diff --git a/server/db/session.py b/server/db/session.py index 51ec55d..0f461d5 100644 --- a/server/db/session.py +++ b/server/db/session.py @@ -1,10 +1,11 @@ from functools import wraps from contextlib import contextmanager from server.db.base import SessionLocal +from sqlalchemy.orm import Session @contextmanager -def session_scope(): +def session_scope() -> Session: """上下文管理器用于自动获取 Session, 避免错误""" session = SessionLocal() try: diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index d0fa620..e8ebc06 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -5,10 +5,12 @@ from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder, list_files_from_folder,files2docs_in_thread, KnowledgeFile,) from server.knowledge_base.kb_service.base import KBServiceFactory -from server.db.repository.knowledge_file_repository import add_file_to_db +from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported from server.db.base import Base, engine +from server.db.session import session_scope import os -from typing import Literal, Any, List +from dateutil.parser import parse +from typing import Literal, List def create_tables(): @@ -20,6 +22,45 @@ def reset_tables(): create_tables() +def import_from_db( + sqlite_path: str = None, + # csv_path: str = None, +) -> bool: + ''' + 在知识库与向量库无变化的情况下,从备份数据库中导入数据到 info.db。 + 适用于版本升级时,info.db 结构变化,但无需重新向量化的情况。 + 请确保两边数据库表名一致,需要导入的字段名一致 + 当前仅支持 sqlite + ''' + import sqlite3 as sql + from pprint import pprint + + models = list(Base.registry.mappers) + + try: + con = sql.connect(sqlite_path) + con.row_factory = sql.Row + cur = con.cursor() + tables = [x["name"] for x in cur.execute("select name from sqlite_master where type='table'").fetchall()] + for model in models: + table = model.local_table.fullname + if table not in tables: + continue + print(f"processing table: {table}") + with session_scope() as session: + for row in cur.execute(f"select * from {table}").fetchall(): + data = {k: row[k] for k in row.keys() if k in model.columns} + if "create_time" in data: + data["create_time"] = parse(data["create_time"]) + pprint(data) + session.add(model.class_(**data)) + con.close() + return True + except Exception as e: + print(f"无法读取备份数据库:{sqlite_path}。错误信息:{e}") + return False + + def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]: kb_files = [] for file in files: