Langchain-Chatchat/server/knowledge_base/kb_api.py

66 lines
2.6 KiB
Python
Raw Normal View History

2023-07-27 23:22:07 +08:00
import urllib
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_base_repository import list_kbs_from_db
2023-09-08 20:48:31 +08:00
from configs.model_config import EMBEDDING_MODEL, logger, log_verbose
from fastapi import Body
2023-07-27 23:22:07 +08:00
def list_kbs():
    """Return every knowledge base recorded in the database.

    The result of ``list_kbs_from_db()`` is passed through unchanged as the
    ``data`` payload of a ``ListResponse``.
    """
    known_kbs = list_kbs_from_db()
    return ListResponse(data=known_kbs)
2023-07-27 23:22:07 +08:00
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
              vector_store_type: str = Body("faiss"),
              embed_model: str = Body(EMBEDDING_MODEL),
              ) -> BaseResponse:
    """Create a new knowledge base.

    Args:
        knowledge_base_name: Name of the knowledge base to create.
        vector_store_type: Vector store backend handed to
            ``KBServiceFactory.get_service``.
        embed_model: Embedding model handed to
            ``KBServiceFactory.get_service``.

    Returns:
        A ``BaseResponse`` whose ``code`` is:
          * 403 - ``validate_kb_name`` rejected the name;
          * 404 - the name is blank, or a knowledge base with that name
            already exists;
          * 500 - ``create_kb()`` on the service raised;
          * 200 - the knowledge base was created.
    """
    # Guard clauses: reject bad input before touching any service.
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    if knowledge_base_name is None or not knowledge_base_name.strip():
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")

    existing = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if existing is not None:
        return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")

    service = KBServiceFactory.get_service(
        knowledge_base_name, vector_store_type, embed_model
    )
    try:
        service.create_kb()
    except Exception as exc:
        msg = f"创建知识库出错: {exc}"
        # Attach the traceback only when verbose logging is enabled.
        traceback_info = exc if log_verbose else None
        logger.error(f'{exc.__class__.__name__}: {msg}',
                     exc_info=traceback_info)
        return BaseResponse(code=500, msg=msg)
    else:
        return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
def delete_kb(
        knowledge_base_name: str = Body(..., examples=["samples"])
) -> BaseResponse:
    """Delete a knowledge base and its vector store.

    Args:
        knowledge_base_name: Name of the knowledge base to delete. It may be
            percent-encoded; it is decoded before the lookup.

    Returns:
        A ``BaseResponse`` whose ``code`` is:
          * 403 - ``validate_kb_name`` rejected the raw or the decoded name;
          * 404 - no knowledge base with the decoded name was found;
          * 500 - ``clear_vs()``/``drop_kb()`` raised, or ``drop_kb()``
            returned a falsy status;
          * 200 - the knowledge base was deleted.
    """
    # Imported here because the module-level ``import urllib`` does not by
    # itself guarantee that the ``urllib.parse`` submodule is loaded.
    from urllib.parse import unquote

    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    knowledge_base_name = unquote(knowledge_base_name)
    # The decoded name is what is actually used for the lookup and deletion,
    # so it must be validated too: otherwise a percent-encoded payload
    # (e.g. "..%2F") passes the check above and only becomes "../" afterwards.
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    try:
        # Only the result of drop_kb() decides success; the return value of
        # clear_vs() was never consulted by the original code either.
        kb.clear_vs()
        if kb.drop_kb():
            return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
    except Exception as e:
        msg = f"删除知识库时出现意外: {e}"
        # Attach the traceback only when verbose logging is enabled.
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(code=500, msg=msg)

    # drop_kb() completed without raising but reported failure.
    return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")