diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index abad0cc..b339b6f 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -2,18 +2,18 @@ import os import urllib from fastapi import File, Form, Body, Query, UploadFile from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, - VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, - CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, - logger, log_verbose,) + VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, + CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, + logger, log_verbose, ) from server.utils import BaseResponse, ListResponse, run_in_thread_pool -from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path, - files2docs_in_thread, KnowledgeFile) +from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path, + files2docs_in_thread, KnowledgeFile) from fastapi.responses import StreamingResponse, FileResponse from pydantic import Json import json from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_file_repository import get_file_detail -from typing import List, Dict +from typing import List from langchain.docstore.document import Document @@ -21,11 +21,16 @@ class DocumentWithScore(Document): score: float = None -def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]), - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), - score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), - ) -> List[DocumentWithScore]: +def search_docs( + query: str = Body(..., description="用户输入", examples=["你好"]), + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body(SCORE_THRESHOLD, + description="知识库匹配相关度阈值,取值范围在0-1之间," + "SCORE越小,相关度越高," + "取到1相当于不筛选,建议设置在0.5左右", + ge=0, le=1), +) -> List[DocumentWithScore]: kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: return [] @@ -35,7 +40,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=[" def list_files( - knowledge_base_name: str + knowledge_base_name: str ) -> ListResponse: if not validate_kb_name(knowledge_base_name): return ListResponse(code=403, msg="Don't attack me", data=[]) @@ -50,12 +55,13 @@ def list_files( def _save_files_in_thread(files: List[UploadFile], - knowledge_base_name: str, - override: bool): - ''' + knowledge_base_name: str, + override: bool): + """ 通过多线程将上传的文件保存到对应知识库目录内。 生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}} - ''' + """ + def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict: ''' 保存单个文件。 @@ -67,8 +73,8 @@ def _save_files_in_thread(files: List[UploadFile], file_content = file.file.read() # 读取上传文件的内容 if (os.path.isfile(file_path) - and not override - and os.path.getsize(file_path) == len(file_content) + and not override + and os.path.getsize(file_path) == len(file_content) ): # TODO: filesize 不同后的处理 file_status = f"文件 {filename} 已存在。" @@ -117,19 +123,21 @@ def _save_files_in_thread(files: List[UploadFile], # yield json.dumps(result, ensure_ascii=False) -def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), - knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), - override: bool = Form(False, description="覆盖已有文件"), - to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"), - chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"), - chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), - zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), - docs: Json = Form({}, description="自定义的docs,需要转为json字符串", examples=[{"test.txt": [Document(page_content="custom doc")]}]), - not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), - ) -> BaseResponse: - ''' +def upload_docs( + files: List[UploadFile] = File(..., description="上传文件,支持多文件"), + knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), + override: bool = Form(False, description="覆盖已有文件"), + to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"), + chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + docs: Json = Form({}, description="自定义的docs,需要转为json字符串", + examples=[{"test.txt": [Document(page_content="custom doc")]}]), + not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), +) -> BaseResponse: + """ API接口:上传文件,并/或向量化 - ''' + """ if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -168,11 +176,12 @@ def upload_docs(files: List[UploadFile] = File(..., description="上传文件, return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files}) -def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]), - file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]), - delete_content: bool = Body(False), - not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), - ) -> BaseResponse: +def delete_docs( + knowledge_base_name: str = Body(..., examples=["samples"]), + file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]), + delete_content: bool = Body(False), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), +) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -202,9 +211,10 @@ def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]), return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files}) -def update_info(knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - kb_info:str = Body(..., description="知识库介绍", examples=["这是一个知识库"]), - ): +def update_info( + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]), +): if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -217,18 +227,19 @@ def update_info(knowledge_base_name: str = Body(..., description="知识库名 def update_docs( - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]), - chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), - chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), - zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), - override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), - docs: Json = Body({}, description="自定义的docs,需要转为json字符串", examples=[{"test.txt": [Document(page_content="custom doc")]}]), - not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), - ) -> BaseResponse: - ''' + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]), + chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), + docs: Json = Body({}, description="自定义的docs,需要转为json字符串", + examples=[{"test.txt": [Document(page_content="custom doc")]}]), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), +) -> BaseResponse: + """ 更新知识库文档 - ''' + """ if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -241,7 +252,7 @@ def update_docs( # 生成需要加载docs的文件列表 for file_name in file_names: - file_detail= get_file_detail(kb_name=knowledge_base_name, filename=file_name) + file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name) # 如果该文件之前使用了自定义docs,则根据参数决定略过或覆盖 if file_detail.get("custom_docs") and not override_custom_docs: continue @@ -289,13 +300,13 @@ def update_docs( def download_doc( - knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]), - file_name: str = Query(...,description="文件名称", examples=["test.txt"]), - preview: bool = Query(False, description="是:浏览器内预览;否:下载"), - ): - ''' + knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]), + file_name: str = Query(..., description="文件名称", examples=["test.txt"]), + preview: bool = Query(False, description="是:浏览器内预览;否:下载"), +): + """ 下载知识库文档 - ''' + """ if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -329,21 +340,21 @@ def download_doc( def recreate_vector_store( - knowledge_base_name: str = Body(..., examples=["samples"]), - allow_empty_kb: bool = Body(True), - vs_type: str = Body(DEFAULT_VS_TYPE), - embed_model: str = Body(EMBEDDING_MODEL), - chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), - chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), - zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), - not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), + knowledge_base_name: str = Body(..., examples=["samples"]), + allow_empty_kb: bool = Body(True), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(EMBEDDING_MODEL), + chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ): - ''' + """ recreate vector store from the content. this is usefull when user can copy files to content folder directly instead of upload through network. by default, get_service_by_name only return knowledge base in the info.db and having document files in it. set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents. - ''' + """ def output(): kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) @@ -357,9 +368,9 @@ def recreate_vector_store( kb_files = [(file, knowledge_base_name) for file in files] i = 0 for status, result in files2docs_in_thread(kb_files, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance): + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance): if status: kb_name, file_name, docs = result kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)