Langchain-Chatchat/server/db/repository/knowledge_file_repository.py

59 lines
2.2 KiB
Python
Raw Normal View History

2023-08-08 14:25:55 +08:00
from server.db.models.knowledge_base_model import KnowledgeBaseModel
from server.db.models.knowledge_file_model import KnowledgeFileModel
from server.db.session import with_session
from server.knowledge_base.utils import KnowledgeFile
2023-08-08 14:25:55 +08:00
@with_session
def list_docs_from_db(session, kb_name):
    """Return the file names recorded for the knowledge base ``kb_name``.

    ``session`` is presumably injected by the ``with_session`` decorator —
    confirm against ``server.db.session``.

    Returns:
        A list with the ``file_name`` of every ``KnowledgeFileModel`` row whose
        ``kb_name`` matches; empty if there are none.
    """
    records = (
        session.query(KnowledgeFileModel)
        .filter_by(kb_name=kb_name)
        .all()
    )
    return [record.file_name for record in records]
@with_session
def add_doc_to_db(session, kb_file: KnowledgeFile):
    """Record ``kb_file`` in the file table of its knowledge base.

    Behavior depends on the existing rows:

    * If the knowledge base named ``kb_file.kb_name`` is not found, nothing
      is written.
    * If a file row with the same name already exists in that knowledge base,
      its ``file_version`` is incremented.
    * Otherwise a new ``KnowledgeFileModel`` row is built from the file's
      metadata, and the knowledge base's ``file_count`` is incremented.

    This function does not commit. Presumably the ``with_session`` decorator
    handles the commit — TODO confirm.

    Returns:
        Always ``True``, even when the knowledge base does not exist.
    """
    kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
    if not kb:
        # Unknown knowledge base: there is nothing to attach the file to.
        return True

    same_file = (
        session.query(KnowledgeFileModel)
        .filter_by(file_name=kb_file.filename, kb_name=kb_file.kb_name)
        .first()
    )
    if not same_file:
        # First time this file is added: create its row and bump the KB's count.
        record = KnowledgeFileModel(
            file_name=kb_file.filename,
            file_ext=kb_file.ext,
            kb_name=kb_file.kb_name,
            document_loader_name=kb_file.document_loader_name,
            text_splitter_name=kb_file.text_splitter_name,
        )
        kb.file_count += 1
        session.add(record)
        session.add(kb)
        return True

    # The file is already recorded: only its version number changes.
    same_file.file_version += 1
    session.add(same_file)
    return True
@with_session
def delete_file_from_db(session, kb_file: KnowledgeFile):
    """Remove ``kb_file``'s row from its knowledge base and update the file count.

    The file row is matched on both ``file_name`` and ``kb_name``. When no row
    matches, nothing is changed. When a row is deleted, the owning knowledge
    base's ``file_count`` is decremented (if that knowledge base is found).

    Returns:
        Always ``True``.
    """
    existing_file = (
        session.query(KnowledgeFileModel)
        .filter_by(file_name=kb_file.filename, kb_name=kb_file.kb_name)
        .first()
    )
    if not existing_file:
        # Nothing was deleted, so the knowledge base's counter must not change.
        return True

    session.delete(existing_file)
    kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
    if kb:
        kb.file_count -= 1
    # Commit once, so the row deletion and the counter update succeed or fail
    # together. The previous version committed the delete first, which could
    # leave a stale file_count if the second commit failed.
    session.commit()
    return True
@with_session
def doc_exists(session, kb_file: KnowledgeFile):
    """Return whether ``kb_file`` is already recorded in its knowledge base.

    Looks for a ``KnowledgeFileModel`` row that matches both the file name
    and the knowledge-base name of ``kb_file``.

    Returns:
        ``True`` if such a row exists, ``False`` otherwise.
    """
    existing_file = (
        session.query(KnowledgeFileModel)
        .filter_by(file_name=kb_file.filename, kb_name=kb_file.kb_name)
        .first()
    )
    # Query.first() yields None when no row matches.
    return existing_file is not None