From 135af5f3ff353e7bb92af522e3308adbb9b288db Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Tue, 8 Aug 2023 21:47:20 +0800 Subject: [PATCH] update knowledge base db orm: 1. set default values for file_count, file_version, create_time. 2. fix bug: add_doc_to_db 3. make add_kb_to_db more flexible with existing kb --- server/db/models/knowledge_base_model.py | 6 +++--- server/db/models/knowledge_file_model.py | 7 +++---- server/db/repository/knowledge_base_repository.py | 6 ++++-- server/db/repository/knowledge_file_repository.py | 11 ++++++++++- server/knowledge_base/kb_doc_api.py | 5 +---- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/server/db/models/knowledge_base_model.py b/server/db/models/knowledge_base_model.py index 6f9e9ca..b59d9a6 100644 --- a/server/db/models/knowledge_base_model.py +++ b/server/db/models/knowledge_base_model.py @@ -1,6 +1,6 @@ from sqlalchemy import Column, Integer, String, DateTime - from server.db.base import Base +from datetime import datetime class KnowledgeBaseModel(Base): @@ -12,8 +12,8 @@ class KnowledgeBaseModel(Base): kb_name = Column(String, comment='知识库名称') vs_type = Column(String, comment='嵌入模型类型') embed_model = Column(String, comment='嵌入模型名称') - file_count = Column(Integer, comment='文件数量') - create_time = Column(DateTime, comment='创建时间') + file_count = Column(Integer, comment='文件数量', default=0) + create_time = Column(DateTime, comment='创建时间', default=datetime.now) def __repr__(self): return f"" diff --git a/server/db/models/knowledge_file_model.py b/server/db/models/knowledge_file_model.py index b6798cc..43aba11 100644 --- a/server/db/models/knowledge_file_model.py +++ b/server/db/models/knowledge_file_model.py @@ -1,7 +1,6 @@ from sqlalchemy import Column, Integer, String, DateTime - from server.db.base import Base - +from datetime import datetime class KnowledgeFileModel(Base): """ @@ -14,8 +13,8 @@ class KnowledgeFileModel(Base): kb_name = Column(String, comment='所属知识库名称') document_loader_name = 
Column(String, comment='文档加载器名称') text_splitter_name = Column(String, comment='文本分割器名称') - file_version = Column(Integer, comment='文件版本') - create_time = Column(DateTime, comment='创建时间') + file_version = Column(Integer, comment='文件版本', default=1) + create_time = Column(DateTime, comment='创建时间', default=datetime.now) def __repr__(self): return f"" diff --git a/server/db/repository/knowledge_base_repository.py b/server/db/repository/knowledge_base_repository.py index aa4df4f..3f50c3d 100644 --- a/server/db/repository/knowledge_base_repository.py +++ b/server/db/repository/knowledge_base_repository.py @@ -5,8 +5,10 @@ from server.db.session import with_session @with_session def add_kb_to_db(session, kb_name, vs_type, embed_model): # 创建知识库实例 - kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model) - session.add(kb) + kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() + if not kb: + kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model) + session.add(kb) return True diff --git a/server/db/repository/knowledge_file_repository.py b/server/db/repository/knowledge_file_repository.py index cc7a284..cdb5c78 100644 --- a/server/db/repository/knowledge_file_repository.py +++ b/server/db/repository/knowledge_file_repository.py @@ -20,10 +20,19 @@ def add_doc_to_db(session, kb_file: KnowledgeFile): kb_name=kb_file.kb_name).first() if existing_file: existing_file.file_version += 1 + session.add(existing_file) # 否则,添加新文件 else: - session.add(kb_file) + new_file = KnowledgeFileModel( + file_name=kb_file.filename, + file_ext=kb_file.ext, + kb_name=kb_file.kb_name, + document_loader_name=kb_file.document_loader_name, + text_splitter_name=kb_file.text_splitter_name, + ) kb.file_count += 1 + session.add(new_file) + session.add(kb) return True diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index a995014..90c3deb 100644 --- a/server/knowledge_base/kb_doc_api.py +++ 
b/server/knowledge_base/kb_doc_api.py @@ -8,7 +8,6 @@ import json from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder from server.knowledge_base.kb_service.base import KBServiceFactory from server.knowledge_base.kb_service.base import SupportedVSType -from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache from typing import Union @@ -81,7 +80,6 @@ async def delete_doc(knowledge_base_name: str, async def update_doc(): # TODO: 替换文件 - # refresh_vs_cache(knowledge_base_name) pass @@ -109,6 +107,7 @@ async def recreate_vector_store( return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") async def output(kb): + kb.create_kb() kb.clear_vs() print(f"start to recreate vector store of {kb.kb_name}") docs = list_docs_from_folder(knowledge_base_name) @@ -126,7 +125,5 @@ async def recreate_vector_store( kb.add_doc(kb_file) except ValueError as e: print(e) - if kb.vs_type == SupportedVSType.FAISS: - refresh_vs_cache(knowledge_base_name) return StreamingResponse(output(kb), media_type="text/event-stream")