reformat kb_doc_api.py

This commit is contained in:
imClumsyPanda 2023-11-06 22:44:50 +08:00
parent fe73ceab15
commit fb32c31a70
1 changed files with 80 additions and 69 deletions

View File

@ -4,16 +4,16 @@ from fastapi import File, Form, Body, Query, UploadFile
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
logger, log_verbose,)
logger, log_verbose, )
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path,
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
files2docs_in_thread, KnowledgeFile)
from fastapi.responses import StreamingResponse, FileResponse
from pydantic import Json
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
from typing import List, Dict
from typing import List
from langchain.docstore.document import Document
@ -21,11 +21,16 @@ class DocumentWithScore(Document):
score: float = None
def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
def search_docs(
query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
) -> List[DocumentWithScore]:
score_threshold: float = Body(SCORE_THRESHOLD,
description="知识库匹配相关度阈值取值范围在0-1之间"
"SCORE越小相关度越高"
"取到1相当于不筛选建议设置在0.5左右",
ge=0, le=1),
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return []
@ -52,10 +57,11 @@ def list_files(
def _save_files_in_thread(files: List[UploadFile],
knowledge_base_name: str,
override: bool):
'''
"""
通过多线程将上传的文件保存到对应知识库目录内
生成器返回保存结果{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
'''
"""
def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
'''
保存单个文件
@ -117,19 +123,21 @@ def _save_files_in_thread(files: List[UploadFile],
# yield json.dumps(result, ensure_ascii=False)
def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
def upload_docs(
files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
override: bool = Form(False, description="覆盖已有文件"),
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
docs: Json = Form({}, description="自定义的docs需要转为json字符串", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
docs: Json = Form({}, description="自定义的docs需要转为json字符串",
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
) -> BaseResponse:
"""
API接口上传文件/或向量化
'''
"""
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -168,11 +176,12 @@ def upload_docs(files: List[UploadFile] = File(..., description="上传文件,
return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
def delete_docs(
knowledge_base_name: str = Body(..., examples=["samples"]),
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
delete_content: bool = Body(False),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -202,9 +211,10 @@ def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
def update_info(knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
kb_info:str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
):
def update_info(
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
):
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -223,12 +233,13 @@ def update_docs(
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
docs: Json = Body({}, description="自定义的docs需要转为json字符串", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
docs: Json = Body({}, description="自定义的docs需要转为json字符串",
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
) -> BaseResponse:
"""
更新知识库文档
'''
"""
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -241,7 +252,7 @@ def update_docs(
# 生成需要加载docs的文件列表
for file_name in file_names:
file_detail= get_file_detail(kb_name=knowledge_base_name, filename=file_name)
file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
# 如果该文件之前使用了自定义docs则根据参数决定略过或覆盖
if file_detail.get("custom_docs") and not override_custom_docs:
continue
@ -289,13 +300,13 @@ def update_docs(
def download_doc(
knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]),
file_name: str = Query(...,description="文件名称", examples=["test.txt"]),
knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]),
file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
):
'''
):
"""
下载知识库文档
'''
"""
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -338,12 +349,12 @@ def recreate_vector_store(
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
):
'''
"""
recreate vector store from the content.
this is usefull when user can copy files to content folder directly instead of upload through network.
by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
'''
"""
def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)