update knowledge base db orm:

1. set default values for file_count, file_version, create_time.
2. fix bug: add_doc_to_db
3. make add_kb_to_db more flexible with existing kb
This commit is contained in:
liunux4odoo 2023-08-08 21:47:20 +08:00
parent 2b0f8caa62
commit 135af5f3ff
5 changed files with 21 additions and 14 deletions

View File

@ -1,6 +1,6 @@
from sqlalchemy import Column, Integer, String, DateTime
from server.db.base import Base
from datetime import datetime
class KnowledgeBaseModel(Base):
@ -12,8 +12,8 @@ class KnowledgeBaseModel(Base):
kb_name = Column(String, comment='知识库名称')
vs_type = Column(String, comment='嵌入模型类型')
embed_model = Column(String, comment='嵌入模型名称')
file_count = Column(Integer, comment='文件数量')
create_time = Column(DateTime, comment='创建时间')
file_count = Column(Integer, comment='文件数量', default=0)
create_time = Column(DateTime, comment='创建时间', default=datetime.now)
def __repr__(self):
return f"<KnowledgeBase(id='{self.id}', kb_name='{self.kb_name}', vs_type='{self.vs_type}', embed_model='{self.embed_model}', file_count='{self.file_count}', create_time='{self.create_time}')>"

View File

@ -1,7 +1,6 @@
from sqlalchemy import Column, Integer, String, DateTime
from server.db.base import Base
from datetime import datetime
class KnowledgeFileModel(Base):
"""
@ -14,8 +13,8 @@ class KnowledgeFileModel(Base):
kb_name = Column(String, comment='所属知识库名称')
document_loader_name = Column(String, comment='文档加载器名称')
text_splitter_name = Column(String, comment='文本分割器名称')
file_version = Column(Integer, comment='文件版本')
create_time = Column(DateTime, comment='创建时间')
file_version = Column(Integer, comment='文件版本', default=1)
create_time = Column(DateTime, comment='创建时间', default=datetime.now)
def __repr__(self):
return f"<KnowledgeFile(id='{self.id}', file_name='{self.file_name}', file_ext='{self.file_ext}', kb_name='{self.kb_name}', document_loader_name='{self.document_loader_name}', text_splitter_name='{self.text_splitter_name}', file_version='{self.file_version}', create_time='{self.create_time}')>"

View File

@ -5,6 +5,8 @@ from server.db.session import with_session
@with_session
def add_kb_to_db(session, kb_name, vs_type, embed_model):
    """Ensure a knowledge base record named ``kb_name`` exists.

    Looks up the first ``KnowledgeBaseModel`` row whose ``kb_name`` matches.
    When none is found, a new row is built from the given arguments and
    added to the session. An existing row is left untouched.

    Args:
        session: Database session used for the query and the insert
            (presumably injected by ``with_session`` -- confirm against
            that decorator).
        kb_name: Name of the knowledge base to look up or create.
        vs_type: Value stored in the new row's ``vs_type`` column.
        embed_model: Value stored in the new row's ``embed_model`` column.

    Returns:
        Always ``True``, whether or not a new row was created.
    """
    existing = (
        session.query(KnowledgeBaseModel)
        .filter_by(kb_name=kb_name)
        .first()
    )
    if existing:
        # A record with this name is already present; nothing to create.
        return True

    # No record yet: create the knowledge base entry.
    session.add(
        KnowledgeBaseModel(
            kb_name=kb_name,
            vs_type=vs_type,
            embed_model=embed_model,
        )
    )
    return True

View File

@ -20,10 +20,19 @@ def add_doc_to_db(session, kb_file: KnowledgeFile):
kb_name=kb_file.kb_name).first()
if existing_file:
existing_file.file_version += 1
session.add(existing_file)
# 否则,添加新文件
else:
session.add(kb_file)
new_file = KnowledgeFileModel(
file_name=kb_file.filename,
file_ext=kb_file.ext,
kb_name=kb_file.kb_name,
document_loader_name=kb_file.document_loader_name,
text_splitter_name=kb_file.text_splitter_name,
)
kb.file_count += 1
session.add(new_file)
session.add(kb)
return True

View File

@ -8,7 +8,6 @@ import json
from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.kb_service.base import SupportedVSType
from server.knowledge_base.kb_service.faiss_kb_service import refresh_vs_cache
from typing import Union
@ -81,7 +80,6 @@ async def delete_doc(knowledge_base_name: str,
async def update_doc():
# TODO: 替换文件
# refresh_vs_cache(knowledge_base_name)
pass
@ -109,6 +107,7 @@ async def recreate_vector_store(
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
async def output(kb):
kb.create_kb()
kb.clear_vs()
print(f"start to recreate vector store of {kb.kb_name}")
docs = list_docs_from_folder(knowledge_base_name)
@ -126,7 +125,5 @@ async def recreate_vector_store(
kb.add_doc(kb_file)
except ValueError as e:
print(e)
if kb.vs_type == SupportedVSType.FAISS:
refresh_vs_cache(knowledge_base_name)
return StreamingResponse(output(kb), media_type="text/event-stream")