Feat (#1951)
* 知识库支持子目录(不包括temp和tmp开头的目录),文件相对路径总长度不可超过255 * init_database.py 增加 --import-db 参数,在版本升级时,如果 info.db 表结构发生变化,但向量库无需重建,可以在重建数据库后,使用本参数从旧的数据库中导入信息
This commit is contained in:
parent
d8e15b57ba
commit
554122f60e
|
|
@ -1,6 +1,7 @@
|
||||||
import sys
|
import sys
|
||||||
sys.path.append(".")
|
sys.path.append(".")
|
||||||
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
|
from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
|
||||||
|
folder2db, prune_db_docs, prune_folder_files)
|
||||||
from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
|
from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
|
||||||
import nltk
|
import nltk
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
@ -28,6 +29,10 @@ if __name__ == "__main__":
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help=("drop the database tables before recreate vector stores")
|
help=("drop the database tables before recreate vector stores")
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--import-db",
|
||||||
|
help="import tables from specified sqlite database"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-u",
|
"-u",
|
||||||
"--update-in-db",
|
"--update-in-db",
|
||||||
|
|
@ -97,6 +102,8 @@ if __name__ == "__main__":
|
||||||
if args.recreate_vs:
|
if args.recreate_vs:
|
||||||
print("recreating all vector stores")
|
print("recreating all vector stores")
|
||||||
folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
|
folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
|
||||||
|
elif args.import_db:
|
||||||
|
import_from_db(args.import_db)
|
||||||
elif args.update_in_db:
|
elif args.update_in_db:
|
||||||
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
|
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
|
||||||
elif args.increament:
|
elif args.increament:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from configs import SQLALCHEMY_DATABASE_URI
|
from configs import SQLALCHEMY_DATABASE_URI
|
||||||
|
|
@ -13,4 +13,4 @@ engine = create_engine(
|
||||||
|
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
Base = declarative_base()
|
Base: DeclarativeMeta = declarative_base()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from server.db.base import SessionLocal
|
from server.db.base import SessionLocal
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def session_scope():
|
def session_scope() -> Session:
|
||||||
"""上下文管理器用于自动获取 Session, 避免错误"""
|
"""上下文管理器用于自动获取 Session, 避免错误"""
|
||||||
session = SessionLocal()
|
session = SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,12 @@ from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
|
||||||
list_files_from_folder,files2docs_in_thread,
|
list_files_from_folder,files2docs_in_thread,
|
||||||
KnowledgeFile,)
|
KnowledgeFile,)
|
||||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
from server.db.repository.knowledge_file_repository import add_file_to_db
|
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
|
||||||
from server.db.base import Base, engine
|
from server.db.base import Base, engine
|
||||||
|
from server.db.session import session_scope
|
||||||
import os
|
import os
|
||||||
from typing import Literal, Any, List
|
from dateutil.parser import parse
|
||||||
|
from typing import Literal, List
|
||||||
|
|
||||||
|
|
||||||
def create_tables():
|
def create_tables():
|
||||||
|
|
@ -20,6 +22,45 @@ def reset_tables():
|
||||||
create_tables()
|
create_tables()
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_db(
|
||||||
|
sqlite_path: str = None,
|
||||||
|
# csv_path: str = None,
|
||||||
|
) -> bool:
|
||||||
|
'''
|
||||||
|
在知识库与向量库无变化的情况下,从备份数据库中导入数据到 info.db。
|
||||||
|
适用于版本升级时,info.db 结构变化,但无需重新向量化的情况。
|
||||||
|
请确保两边数据库表名一致,需要导入的字段名一致
|
||||||
|
当前仅支持 sqlite
|
||||||
|
'''
|
||||||
|
import sqlite3 as sql
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
models = list(Base.registry.mappers)
|
||||||
|
|
||||||
|
try:
|
||||||
|
con = sql.connect(sqlite_path)
|
||||||
|
con.row_factory = sql.Row
|
||||||
|
cur = con.cursor()
|
||||||
|
tables = [x["name"] for x in cur.execute("select name from sqlite_master where type='table'").fetchall()]
|
||||||
|
for model in models:
|
||||||
|
table = model.local_table.fullname
|
||||||
|
if table not in tables:
|
||||||
|
continue
|
||||||
|
print(f"processing table: {table}")
|
||||||
|
with session_scope() as session:
|
||||||
|
for row in cur.execute(f"select * from {table}").fetchall():
|
||||||
|
data = {k: row[k] for k in row.keys() if k in model.columns}
|
||||||
|
if "create_time" in data:
|
||||||
|
data["create_time"] = parse(data["create_time"])
|
||||||
|
pprint(data)
|
||||||
|
session.add(model.class_(**data))
|
||||||
|
con.close()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"无法读取备份数据库:{sqlite_path}。错误信息:{e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
|
def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
|
||||||
kb_files = []
|
kb_files = []
|
||||||
for file in files:
|
for file in files:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue