reformat kb_doc_api.py
parent fe73ceab15
commit fb32c31a70
@@ -2,18 +2,18 @@ import os
 import urllib
 from fastapi import File, Form, Body, Query, UploadFile
 from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
                      VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
-                     logger, log_verbose,)
+                     logger, log_verbose, )
 from server.utils import BaseResponse, ListResponse, run_in_thread_pool
-from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path,
+from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
                                           files2docs_in_thread, KnowledgeFile)
 from fastapi.responses import StreamingResponse, FileResponse
 from pydantic import Json
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_file_repository import get_file_detail
-from typing import List, Dict
+from typing import List
 from langchain.docstore.document import Document


@@ -21,11 +21,16 @@ class DocumentWithScore(Document):
     score: float = None


-def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
-                knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
-                top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
-                score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
-                ) -> List[DocumentWithScore]:
+def search_docs(
+        query: str = Body(..., description="用户输入", examples=["你好"]),
+        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
+        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
+        score_threshold: float = Body(SCORE_THRESHOLD,
+                                      description="知识库匹配相关度阈值,取值范围在0-1之间,"
+                                                  "SCORE越小,相关度越高,"
+                                                  "取到1相当于不筛选,建议设置在0.5左右",
+                                      ge=0, le=1),
+) -> List[DocumentWithScore]:
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
         return []
@@ -35,7 +40,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["


 def list_files(
-    knowledge_base_name: str
+        knowledge_base_name: str
 ) -> ListResponse:
     if not validate_kb_name(knowledge_base_name):
         return ListResponse(code=403, msg="Don't attack me", data=[])
@@ -50,12 +55,13 @@ def list_files(


 def _save_files_in_thread(files: List[UploadFile],
                           knowledge_base_name: str,
                           override: bool):
-    '''
+    """
     通过多线程将上传的文件保存到对应知识库目录内。
     生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
-    '''
+    """
+
     def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
         '''
         保存单个文件。
@@ -67,8 +73,8 @@ def _save_files_in_thread(files: List[UploadFile],

-        file_content = file.file.read() # 读取上传文件的内容
+        file_content = file.file.read()  # 读取上传文件的内容
         if (os.path.isfile(file_path)
                 and not override
                 and os.path.getsize(file_path) == len(file_content)
         ):
             # TODO: filesize 不同后的处理
             file_status = f"文件 {filename} 已存在。"
@@ -117,19 +123,21 @@ def _save_files_in_thread(files: List[UploadFile],
         # yield json.dumps(result, ensure_ascii=False)


-def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
-                knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
-                override: bool = Form(False, description="覆盖已有文件"),
-                to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
-                chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
-                chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
-                zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
-                docs: Json = Form({}, description="自定义的docs,需要转为json字符串", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
-                not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
-                ) -> BaseResponse:
-    '''
+def upload_docs(
+        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
+        knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
+        override: bool = Form(False, description="覆盖已有文件"),
+        to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
+        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
+        chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
+        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
+        docs: Json = Form({}, description="自定义的docs,需要转为json字符串",
+                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
+        not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
+) -> BaseResponse:
+    """
     API接口:上传文件,并/或向量化
-    '''
+    """
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")

@@ -168,11 +176,12 @@ def upload_docs(files: List[UploadFile] = File(..., description="上传文件,
     return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})


-def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
-                file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
-                delete_content: bool = Body(False),
-                not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
-                ) -> BaseResponse:
+def delete_docs(
+        knowledge_base_name: str = Body(..., examples=["samples"]),
+        file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
+        delete_content: bool = Body(False),
+        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
+) -> BaseResponse:
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")

@@ -202,9 +211,10 @@ def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
     return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})


-def update_info(knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
-                kb_info:str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
-                ):
+def update_info(
+        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
+        kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
+):
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")

@@ -217,18 +227,19 @@ def update_info(knowledge_base_name: str = Body(..., description="知识库名


 def update_docs(
         knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
         file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
         chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
         chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
         zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
         override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
-        docs: Json = Body({}, description="自定义的docs,需要转为json字符串", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
+        docs: Json = Body({}, description="自定义的docs,需要转为json字符串",
+                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
         not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
 ) -> BaseResponse:
-    '''
+    """
     更新知识库文档
-    '''
+    """
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")

@@ -241,7 +252,7 @@ def update_docs(

     # 生成需要加载docs的文件列表
     for file_name in file_names:
-        file_detail= get_file_detail(kb_name=knowledge_base_name, filename=file_name)
+        file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
         # 如果该文件之前使用了自定义docs,则根据参数决定略过或覆盖
         if file_detail.get("custom_docs") and not override_custom_docs:
             continue
@@ -289,13 +300,13 @@ def update_docs(


 def download_doc(
-        knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]),
-        file_name: str = Query(...,description="文件名称", examples=["test.txt"]),
+        knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]),
+        file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
         preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
 ):
-    '''
+    """
     下载知识库文档
-    '''
+    """
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")

@@ -329,21 +340,21 @@ def download_doc(


 def recreate_vector_store(
         knowledge_base_name: str = Body(..., examples=["samples"]),
         allow_empty_kb: bool = Body(True),
         vs_type: str = Body(DEFAULT_VS_TYPE),
         embed_model: str = Body(EMBEDDING_MODEL),
         chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
         chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
         zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
         not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
 ):
-    '''
+    """
     recreate vector store from the content.
     this is usefull when user can copy files to content folder directly instead of upload through network.
     by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
     set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
-    '''
+    """

     def output():
         kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
@@ -357,9 +368,9 @@ def recreate_vector_store(
         kb_files = [(file, knowledge_base_name) for file in files]
         i = 0
         for status, result in files2docs_in_thread(kb_files,
-                                                    chunk_size=chunk_size,
-                                                    chunk_overlap=chunk_overlap,
-                                                    zh_title_enhance=zh_title_enhance):
+                                                   chunk_size=chunk_size,
+                                                   chunk_overlap=chunk_overlap,
+                                                   zh_title_enhance=zh_title_enhance):
             if status:
                 kb_name, file_name, docs = result
                 kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
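
Below is a minimal usage sketch (not part of this commit) of how the reformatted search_docs and upload_docs endpoints might be exercised over HTTP once kb_doc_api.py is mounted by the project's API server. The /knowledge_base/* route paths and the 127.0.0.1:7861 address are assumptions based on the project's usual defaults, not something this diff establishes.

# Illustrative client sketch only -- endpoints, paths, and port are assumptions.
import requests

BASE_URL = "http://127.0.0.1:7861"  # assumed default port of the API server

# Query the "samples" knowledge base via the reformatted search_docs signature.
resp = requests.post(
    f"{BASE_URL}/knowledge_base/search_docs",
    json={
        "query": "你好",
        "knowledge_base_name": "samples",
        "top_k": 3,
        "score_threshold": 0.5,
    },
)
print(resp.json())

# Upload a file and vectorize it via upload_docs (multipart form fields).
with open("test.txt", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/knowledge_base/upload_docs",
        files=[("files", ("test.txt", f, "text/plain"))],
        data={
            "knowledge_base_name": "samples",
            "override": "true",
            "to_vector_store": "true",
        },
    )
print(resp.json())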