Feat (#1951)
* 知识库支持子目录(不包括temp和tmp开头的目录),文件相对路径总长度不可超过255 * init_database.py 增加 --import-db 参数,在版本升级时,如果 info.db 表结构发生变化,但向量库无需重建,可以在重建数据库后,使用本参数从旧的数据库中导入信息
This commit is contained in:
parent
d8e15b57ba
commit
554122f60e
|
|
@ -1,6 +1,7 @@
|
|||
import sys
|
||||
sys.path.append(".")
|
||||
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
|
||||
from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
|
||||
folder2db, prune_db_docs, prune_folder_files)
|
||||
from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
|
||||
import nltk
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
|
|
@ -28,6 +29,10 @@ if __name__ == "__main__":
|
|||
action="store_true",
|
||||
help=("drop the database tables before recreate vector stores")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--import-db",
|
||||
help="import tables from specified sqlite database"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-u",
|
||||
"--update-in-db",
|
||||
|
|
@ -97,6 +102,8 @@ if __name__ == "__main__":
|
|||
if args.recreate_vs:
|
||||
print("recreating all vector stores")
|
||||
folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
|
||||
elif args.import_db:
|
||||
import_from_db(args.import_db)
|
||||
elif args.update_in_db:
|
||||
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
|
||||
elif args.increament:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import SQLALCHEMY_DATABASE_URI
|
||||
|
|
@ -13,4 +13,4 @@ engine = create_engine(
|
|||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
Base: DeclarativeMeta = declarative_base()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from functools import wraps
|
||||
from contextlib import contextmanager
|
||||
from server.db.base import SessionLocal
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_scope():
|
||||
def session_scope() -> Session:
|
||||
"""上下文管理器用于自动获取 Session, 避免错误"""
|
||||
session = SessionLocal()
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -5,10 +5,12 @@ from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
|
|||
list_files_from_folder,files2docs_in_thread,
|
||||
KnowledgeFile,)
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.db.repository.knowledge_file_repository import add_file_to_db
|
||||
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
|
||||
from server.db.base import Base, engine
|
||||
from server.db.session import session_scope
|
||||
import os
|
||||
from typing import Literal, Any, List
|
||||
from dateutil.parser import parse
|
||||
from typing import Literal, List
|
||||
|
||||
|
||||
def create_tables():
|
||||
|
|
@ -20,6 +22,45 @@ def reset_tables():
|
|||
create_tables()
|
||||
|
||||
|
||||
def import_from_db(
        sqlite_path: str = None,
        # csv_path: str = None,
) -> bool:
    """
    Import rows from a backup SQLite database into the current info.db.

    Intended for version upgrades where the info.db schema changed but the
    knowledge base and vector stores did not, so nothing has to be
    re-vectorized. A table is imported only when its name matches the table
    of a mapped model, and only the columns present on both sides are copied.
    Only SQLite backups are supported.

    Args:
        sqlite_path: path of the backup SQLite database file.

    Returns:
        True if every matching table was processed without error, False if the
        backup file is missing or any error occurred (the error is printed).
    """
    import os
    import sqlite3 as sql
    from contextlib import closing
    from pprint import pprint

    models = list(Base.registry.mappers)

    try:
        # sqlite3.connect() silently creates an empty database when the path
        # does not exist, which would report success while importing nothing.
        if not sqlite_path or not os.path.isfile(sqlite_path):
            raise FileNotFoundError(sqlite_path)

        # A sqlite3 connection used as a context manager only manages the
        # transaction; closing() guarantees the handle is released on errors.
        with closing(sql.connect(sqlite_path)) as con:
            con.row_factory = sql.Row
            cur = con.cursor()
            tables = [x["name"] for x in cur.execute("select name from sqlite_master where type='table'").fetchall()]
            for model in models:
                table = model.local_table.fullname
                if table not in tables:
                    continue
                print(f"processing table: {table}")
                with session_scope() as session:
                    for row in cur.execute(f"select * from {table}").fetchall():
                        # Keep only the columns the current model still defines.
                        data = {k: row[k] for k in row.keys() if k in model.columns}
                        # sqlite returns the timestamp as text; a NULL value
                        # must stay None instead of being handed to parse().
                        if data.get("create_time") is not None:
                            data["create_time"] = parse(data["create_time"])
                        pprint(data)
                        session.add(model.class_(**data))
        return True
    except Exception as e:
        print(f"无法读取备份数据库:{sqlite_path}。错误信息:{e}")
        return False
|
||||
|
||||
|
||||
def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
|
||||
kb_files = []
|
||||
for file in files:
|
||||
|
|
|
|||
Loading…
Reference in New Issue