From ce1001a043cdcad160992e543e24d9350ebba132 Mon Sep 17 00:00:00 2001 From: WilliamChen-luckbob <58684828+WilliamChen-luckbob@users.noreply.github.com> Date: Thu, 9 Nov 2023 17:45:21 +0800 Subject: [PATCH] =?UTF-8?q?bugfix:dev=E5=88=86=E6=94=AF=E5=88=9B=E5=BB=BA?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E5=BF=85=E7=84=B6=E5=A4=B1=E8=B4=A5?= =?UTF-8?q?=E7=9A=84bug=E4=BF=AE=E5=A4=8D=20(#1980)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bugfix:dev分支创建知识库必然失败的bug修复 * 统一 KBServiceFactory.get_service_by_name 的逻辑,数据库中不存在知识库时返回 None --------- Co-authored-by: liunux4odoo --- server/knowledge_base/kb_service/base.py | 14 ++++++-------- server/knowledge_base/migrate.py | 4 ++-- tests/test_migrate.py | 2 +- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index a81474f..d823740 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -278,15 +278,10 @@ class KBServiceFactory: return DefaultKBService(kb_name) @staticmethod - def get_service_by_name(kb_name: str, - default_vs_type: SupportedVSType = SupportedVSType.FAISS, - default_embed_model: str = EMBEDDING_MODEL, - ) -> KBService: + def get_service_by_name(kb_name: str) -> KBService: _, vs_type, embed_model = load_kb_from_db(kb_name) - if vs_type is None: # faiss knowledge base not in db - vs_type = default_vs_type - if embed_model is None: - embed_model = default_embed_model + if _ is None: # kb not in db, just return None + return None return KBServiceFactory.get_service(kb_name, vs_type, embed_model) @staticmethod @@ -331,6 +326,9 @@ def get_kb_details() -> List[Dict]: def get_kb_file_details(kb_name: str) -> List[Dict]: kb = KBServiceFactory.get_service_by_name(kb_name) + if kb is None: + return [] + files_in_folder = list_files_from_folder(kb_name) files_in_db = kb.list_files() result = {} diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 146dc77..0c21946 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -157,7 +157,7 @@ def prune_db_docs(kb_names: List[str]): """ for kb_name in kb_names: kb = KBServiceFactory.get_service_by_name(kb_name) - if kb and kb.exists(): + if kb is not None: files_in_db = kb.list_files() files_in_folder = list_files_from_folder(kb_name) files = list(set(files_in_db) - set(files_in_folder)) @@ -175,7 +175,7 @@ def prune_folder_files(kb_names: List[str]): """ for kb_name in kb_names: kb = KBServiceFactory.get_service_by_name(kb_name) - if kb and kb.exists(): + if kb is not None: files_in_db = kb.list_files() files_in_folder = list_files_from_folder(kb_name) files = list(set(files_in_folder) - set(files_in_db)) diff --git a/tests/test_migrate.py b/tests/test_migrate.py index d694b02..7195dd3 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -33,7 +33,7 @@ def test_recreate_vs(): folder2db([kb_name], "recreate_vs") kb = KBServiceFactory.get_service_by_name(kb_name) - assert kb.exists() + assert kb and kb.exists() files = kb.list_files() print(files)