merge pr1413

liunux4odoo · 2023-09-08 10:30:07 +08:00
commit 1195eb75eb
11 changed files with 730 additions and 309 deletions

View File

@@ -15,8 +15,8 @@ from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
-from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
-                                              update_doc, download_doc, recreate_vector_store,
+from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
+                                              update_docs, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
 from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
 import httpx

@@ -98,23 +98,23 @@ def create_app():
              summary="搜索知识库"
              )(search_docs)

-    app.post("/knowledge_base/upload_doc",
+    app.post("/knowledge_base/upload_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
-             summary="上传文件到知识库"
-             )(upload_doc)
+             summary="上传文件到知识库,并/或进行向量化"
+             )(upload_docs)

-    app.post("/knowledge_base/delete_doc",
+    app.post("/knowledge_base/delete_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
              summary="删除知识库内指定文件"
-             )(delete_doc)
+             )(delete_docs)

-    app.post("/knowledge_base/update_doc",
+    app.post("/knowledge_base/update_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
              summary="更新现有文件到知识库"
-             )(update_doc)
+             )(update_docs)

     app.get("/knowledge_base/download_doc",
             tags=["Knowledge Base Management"],

View File

@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_base_repository import list_kbs_from_db
-from configs.model_config import EMBEDDING_MODEL
+from configs.model_config import EMBEDDING_MODEL, logger
 from fastapi import Body

@@ -30,8 +30,9 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
     try:
         kb.create_kb()
     except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
+        msg = f"创建知识库出错: {e}"
+        logger.error(msg)
+        return BaseResponse(code=500, msg=msg)

     return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")

@@ -55,7 +56,8 @@ async def delete_kb(
         if status:
             return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
     except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
+        msg = f"删除知识库时出现意外: {e}"
+        logger.error(msg)
+        return BaseResponse(code=500, msg=msg)

     return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")

View File

@@ -1,12 +1,17 @@
 import os
 import urllib
 from fastapi import File, Form, Body, Query, UploadFile
-from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
-from server.utils import BaseResponse, ListResponse
-from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
+from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+                                  logger,)
+from server.utils import BaseResponse, ListResponse, run_in_thread_pool
+from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
+                                         files2docs_in_thread, KnowledgeFile)
 from fastapi.responses import StreamingResponse, FileResponse
+from pydantic import Json
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from server.db.repository.knowledge_file_repository import get_file_detail
 from typing import List, Dict
+from langchain.docstore.document import Document
@@ -44,11 +49,83 @@ async def list_files(
     return ListResponse(data=all_doc_names)


-async def upload_doc(file: UploadFile = File(..., description="上传文件"),
-                     knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
+def _save_files_in_thread(files: List[UploadFile],
+                          knowledge_base_name: str,
+                          override: bool):
+    '''
+    Save the uploaded files into the target knowledge base directory using a thread pool.
+    The generator yields one save result per file: {"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
+    '''
+    def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
+        '''
+        Save a single file.
+        '''
+        try:
+            filename = file.filename
+            file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)
+            data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
+
+            file_content = file.file.read()  # read the uploaded file's content
+            if (os.path.isfile(file_path)
+                and not override
+                and os.path.getsize(file_path) == len(file_content)
+            ):
+                # TODO: handle the case where the file exists but its size differs
+                file_status = f"文件 {filename} 已存在。"
+                logger.warn(file_status)
+                return dict(code=404, msg=file_status, data=data)
+
+            with open(file_path, "wb") as f:
+                f.write(file_content)
+            return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
+        except Exception as e:
+            msg = f"{filename} 文件上传失败,报错信息为: {e}"
+            logger.error(msg)
+            return dict(code=500, msg=msg, data=data)
+
+    params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
+    for result in run_in_thread_pool(save_file, params=params):
+        yield result
+
+
+# There seems to be no need for a standalone file-upload API endpoint
+# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
+#                 knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
+#                 override: bool = Form(False, description="覆盖已有文件")):
+#     '''
+#     API endpoint: upload files. Streams the save results: {"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
+#     '''
+#     def generate(files, knowledge_base_name, override):
+#         for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
+#             yield json.dumps(result, ensure_ascii=False)
+#
+#     return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream")
+
+
+# TODO: enable this once langchain.document_loaders supports in-memory files
+# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
+#             knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
+#             override: bool = Form(False, description="覆盖已有文件"),
+#             save: bool = Form(True, description="是否将文件保存到知识库目录")):
+#     def save_files(files, knowledge_base_name, override):
+#         for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
+#             yield json.dumps(result, ensure_ascii=False)
+#
+#     def files_to_docs(files):
+#         for result in files2docs_in_thread(files):
+#             yield json.dumps(result, ensure_ascii=False)
+
+
+async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
+                      knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
+                      override: bool = Form(False, description="覆盖已有文件"),
+                      to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
+                      docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
+                      not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
+                      ) -> BaseResponse:
+    '''
+    API endpoint: upload files, and/or vectorize them
+    '''
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
@@ -56,37 +133,36 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

-    file_content = await file.read()  # read the uploaded file's content
+    failed_files = {}
+    file_names = list(docs.keys())

-    try:
-        kb_file = KnowledgeFile(filename=file.filename,
-                                knowledge_base_name=knowledge_base_name)
-
-        if (os.path.exists(kb_file.filepath)
-            and not override
-            and os.path.getsize(kb_file.filepath) == len(file_content)
-        ):
-            # TODO: handle the case where the file exists but its size differs
-            file_status = f"文件 {kb_file.filename} 已存在。"
-            return BaseResponse(code=404, msg=file_status)
+    # save the uploaded files to disk first
+    for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
+        filename = result["data"]["file_name"]
+        if result["code"] != 200:
+            failed_files[filename] = result["msg"]

-        with open(kb_file.filepath, "wb") as f:
-            f.write(file_content)
-    except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
+        if filename not in file_names:
+            file_names.append(filename)

-    try:
-        kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-    except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
+    # vectorize the saved files
+    if to_vector_store:
+        result = await update_docs(
+            knowledge_base_name=knowledge_base_name,
+            file_names=file_names,
+            override_custom_docs=True,
+            docs=docs,
+            not_refresh_vs_cache=True,
+        )
+        failed_files.update(result.data["failed_files"])
+        if not not_refresh_vs_cache:
+            kb.save_vector_store()

-    return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
+    return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
-async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
-                     doc_name: str = Body(..., examples=["file_name.md"]),
+async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
+                      file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
                       delete_content: bool = Body(False),
                       not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
                       ) -> BaseResponse:

@@ -98,23 +174,31 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

-    if not kb.exist_doc(doc_name):
-        return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
+    failed_files = {}
+    for file_name in file_names:
+        if not kb.exist_doc(file_name):
+            failed_files[file_name] = f"未找到文件 {file_name}"

-    try:
-        kb_file = KnowledgeFile(filename=doc_name,
-                                knowledge_base_name=knowledge_base_name)
-        kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
-    except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
+        try:
+            kb_file = KnowledgeFile(filename=file_name,
+                                    knowledge_base_name=knowledge_base_name)
+            kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
+        except Exception as e:
+            msg = f"{file_name} 文件删除失败,错误信息:{e}"
+            logger.error(msg)
+            failed_files[file_name] = msg

-    return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
+    if not not_refresh_vs_cache:
+        kb.save_vector_store()
+
+    return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
-async def update_doc(
-        knowledge_base_name: str = Body(..., examples=["samples"]),
-        file_name: str = Body(..., examples=["file_name"]),
+async def update_docs(
+        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
+        file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
+        override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
+        docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
         not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
     ) -> BaseResponse:
     '''

@@ -127,22 +211,57 @@ async def update_doc(
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

-    try:
-        kb_file = KnowledgeFile(filename=file_name,
-                                knowledge_base_name=knowledge_base_name)
-        if os.path.exists(kb_file.filepath):
-            kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-            return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
-    except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
-
-    return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
+    failed_files = {}
+    kb_files = []
+
+    # build the list of files whose docs need to be loaded
+    for file_name in file_names:
+        file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
+        # if the file previously used custom docs, skip or override it according to the parameter
+        if file_detail.get("custom_docs") and not override_custom_docs:
+            continue
+        if file_name not in docs:
+            try:
+                kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
+            except Exception as e:
+                msg = f"加载文档 {file_name} 时出错:{e}"
+                logger.error(msg)
+                failed_files[file_name] = msg
+
+    # generate docs from the files and vectorize them.
+    # this uses KnowledgeFile's caching: Documents are loaded in worker threads, then handed back to a KnowledgeFile
+    for status, result in files2docs_in_thread(kb_files):
+        if status:
+            kb_name, file_name, new_docs = result
+            kb_file = KnowledgeFile(filename=file_name,
+                                    knowledge_base_name=knowledge_base_name)
+            kb_file.splited_docs = new_docs
+            kb.update_doc(kb_file, not_refresh_vs_cache=True)
+        else:
+            kb_name, file_name, error = result
+            failed_files[file_name] = error
+
+    # vectorize the custom docs
+    for file_name, v in docs.items():
+        try:
+            v = [x if isinstance(x, Document) else Document(**x) for x in v]
+            kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
+            kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
+        except Exception as e:
+            msg = f"{file_name} 添加自定义docs时出错:{e}"
+            logger.error(msg)
+            failed_files[file_name] = msg
+
+    if not not_refresh_vs_cache:
+        kb.save_vector_store()
+
+    return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files})
 async def download_doc(
-        knowledge_base_name: str = Query(..., examples=["samples"]),
-        file_name: str = Query(..., examples=["test.txt"]),
+        knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]),
+        file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
+        preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
     ):
     '''
     Download a knowledge base document

@@ -154,6 +273,11 @@ async def download_doc(
     if kb is None:
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

+    if preview:
+        content_disposition_type = "inline"
+    else:
+        content_disposition_type = None
+
     try:
         kb_file = KnowledgeFile(filename=file_name,
                                 knowledge_base_name=knowledge_base_name)

@@ -162,10 +286,13 @@ async def download_doc(
             return FileResponse(
                 path=kb_file.filepath,
                 filename=kb_file.filename,
-                media_type="multipart/form-data")
+                media_type="multipart/form-data",
+                content_disposition_type=content_disposition_type,
+            )
     except Exception as e:
-        print(e)
-        return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
+        msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}"
+        logger.error(msg)
+        return BaseResponse(code=500, msg=msg)

     return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
@@ -190,27 +317,30 @@ async def recreate_vector_store(
         else:
             kb.create_kb()
             kb.clear_vs()
-            docs = list_files_from_folder(knowledge_base_name)
-            for i, doc in enumerate(docs):
-                try:
-                    kb_file = KnowledgeFile(doc, knowledge_base_name)
+            files = list_files_from_folder(knowledge_base_name)
+            kb_files = [(file, knowledge_base_name) for file in files]
+            i = 0
+            for status, result in files2docs_in_thread(kb_files):
+                if status:
+                    kb_name, file_name, docs = result
+                    kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
+                    kb_file.splited_docs = docs
                     yield json.dumps({
                         "code": 200,
-                        "msg": f"({i + 1} / {len(docs)}): {doc}",
-                        "total": len(docs),
+                        "msg": f"({i + 1} / {len(files)}): {file_name}",
+                        "total": len(files),
                         "finished": i,
-                        "doc": doc,
+                        "doc": file_name,
                     }, ensure_ascii=False)
-                    if i == len(docs) - 1:
-                        not_refresh_vs_cache = False
-                    else:
-                        not_refresh_vs_cache = True
-                    kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-                except Exception as e:
-                    print(e)
+                    kb.add_doc(kb_file, not_refresh_vs_cache=True)
+                else:
+                    kb_name, file_name, error = result
+                    msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
+                    logger.error(msg)
                     yield json.dumps({
                         "code": 500,
-                        "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
+                        "msg": msg,
                     })
+                i += 1

     return StreamingResponse(output(), media_type="text/event-stream")
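Note: recreate_vector_store streams one JSON object per processed file. The companion test further down consumes it through ApiRequest; a matching sketch (base URL and kb name assumed):

    from webui_pages.utils import ApiRequest

    api = ApiRequest("http://127.0.0.1:7861")  # assumed address
    for chunk in api.recreate_vector_store("samples"):
        print(chunk["msg"])  # e.g. "(1 / 3): README.MD"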

View File

@@ -51,6 +51,13 @@ class KBService(ABC):
     def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
         return load_embeddings(self.embed_model, embed_device)

+    def save_vector_store(self, vector_store=None):
+        '''
+        Persist the vector store. Only FAISS needs this; for other vector stores it is a no-op,
+        sparing callers a type check around FAISS-specific operations.
+        '''
+        pass
+
     def create_kb(self):
         """
         Create the knowledge base

@@ -84,6 +91,8 @@
         """
         if docs:
             custom_docs = True
+            for doc in docs:
+                doc.metadata.setdefault("source", kb_file.filepath)
         else:
             docs = kb_file.file2text()
             custom_docs = False
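Note on the setdefault call: custom Documents passed through the docs parameter may lack a "source" in their metadata, yet FAISS deletion looks entries up by source (see the faiss_kb_service change below). Backfilling it keeps update and delete working for custom docs:

    from langchain.docstore.document import Document

    doc = Document(page_content="custom doc")                  # no "source" yet
    doc.metadata.setdefault("source", "/kb/samples/test.txt")  # hypothetical filepath
    doc.metadata.setdefault("source", "ignored")               # no-op: key already set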

View File

@@ -5,7 +5,8 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CACHED_VS_NUM,
     EMBEDDING_MODEL,
-    SCORE_THRESHOLD
+    SCORE_THRESHOLD,
+    logger,
 )
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from functools import lru_cache

@@ -28,7 +29,7 @@ def load_faiss_vector_store(
         embeddings: Embeddings = None,
         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
 ) -> FAISS:
-    print(f"loading vector store in '{knowledge_base_name}'.")
+    logger.info(f"loading vector store in '{knowledge_base_name}'.")
     vs_path = get_vs_path(knowledge_base_name)
     if embeddings is None:
         embeddings = load_embeddings(embed_model, embed_device)

@@ -57,7 +58,7 @@ def refresh_vs_cache(kb_name: str):
     make vector store cache refreshed when next loading
     """
     _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
-    print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
+    logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")


 class FaissKBService(KBService):

@@ -133,7 +134,7 @@ class FaissKBService(KBService):
                    **kwargs):
         vector_store = self.load_vector_store()

-        ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
+        ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
         if len(ids) == 0:
             return None

View File

@@ -7,16 +7,20 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CHUNK_SIZE,
     OVERLAP_SIZE,
-    ZH_TITLE_ENHANCE
+    ZH_TITLE_ENHANCE,
+    logger,
 )
 from functools import lru_cache
 import importlib
 from text_splitter import zh_title_enhance
 import langchain.document_loaders
+from langchain.docstore.document import Document
+from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor
+from server.utils import run_in_thread_pool
 import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
@@ -173,12 +177,74 @@ def get_LoaderClass(file_extension):
     return LoaderClass


+# Shared vectorization logic pulled out of KnowledgeFile; once langchain supports
+# in-memory files, non-disk files can be vectorized as well
+def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
+    '''
+    Return a document loader based on loader_name and the file path or content.
+    '''
+    try:
+        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+            document_loaders_module = importlib.import_module('document_loaders')
+        else:
+            document_loaders_module = importlib.import_module('langchain.document_loaders')
+        DocumentLoader = getattr(document_loaders_module, loader_name)
+    except Exception as e:
+        logger.error(f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}")
+        document_loaders_module = importlib.import_module('langchain.document_loaders')
+        DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
+
+    if loader_name == "UnstructuredFileLoader":
+        loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
+    elif loader_name == "CSVLoader":
+        loader = DocumentLoader(file_path_or_content, encoding="utf-8")
+    elif loader_name == "JSONLoader":
+        loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
+    elif loader_name == "CustomJSONLoader":
+        loader = DocumentLoader(file_path_or_content, text_content=False)
+    elif loader_name == "UnstructuredMarkdownLoader":
+        loader = DocumentLoader(file_path_or_content, mode="elements")
+    elif loader_name == "UnstructuredHTMLLoader":
+        loader = DocumentLoader(file_path_or_content, mode="elements")
+    else:
+        loader = DocumentLoader(file_path_or_content)
+    return loader
+
+
+def make_text_splitter(
+    splitter_name: str = "SpacyTextSplitter",
+    chunk_size: int = CHUNK_SIZE,
+    chunk_overlap: int = OVERLAP_SIZE,
+):
+    '''
+    Return a text splitter according to the given parameters.
+    '''
+    splitter_name = splitter_name or "SpacyTextSplitter"
+    text_splitter_module = importlib.import_module('langchain.text_splitter')
+    try:
+        TextSplitter = getattr(text_splitter_module, splitter_name)
+        text_splitter = TextSplitter(
+            pipeline="zh_core_web_sm",
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+    except Exception as e:
+        logger.error(f"查找分词器 {splitter_name} 时出错:{e}")
+        TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
+        text_splitter = TextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+    return text_splitter
+
+
 class KnowledgeFile:
     def __init__(
             self,
             filename: str,
             knowledge_base_name: str
     ):
+        '''
+        Corresponds to a file in the knowledge base directory; the file must exist
+        on disk before it can be vectorized etc.
+        '''
         self.kb_name = knowledge_base_name
         self.filename = filename
         self.ext = os.path.splitext(filename)[-1].lower()
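Note: the loader and splitter selection that used to live inline in KnowledgeFile.file2text is now reusable on its own. A sketch of composing the two helpers by hand (file path hypothetical):

    loader = get_loader("UnstructuredFileLoader", "/kb/samples/test.txt")
    docs = loader.load()
    splitter = make_text_splitter(chunk_size=250, chunk_overlap=50)  # SpacyTextSplitter by default, RecursiveCharacterTextSplitter on failure
    chunks = splitter.split_documents(docs)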
@@ -186,76 +252,62 @@
             raise ValueError(f"暂未支持的文件格式 {self.ext}")
         self.filepath = get_file_path(knowledge_base_name, filename)
         self.docs = None
+        self.splited_docs = None
         self.document_loader_name = get_LoaderClass(self.ext)
+        # TODO: pick a text_splitter based on the file format
         self.text_splitter_name = None

-    def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
-        if self.docs is not None and not refresh:
-            return self.docs
-
-        print(f"{self.document_loader_name} used for {self.filepath}")
-        try:
-            if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
-                document_loaders_module = importlib.import_module('document_loaders')
-            else:
-                document_loaders_module = importlib.import_module('langchain.document_loaders')
-            DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
-        except Exception as e:
-            print(e)
-            document_loaders_module = importlib.import_module('langchain.document_loaders')
-            DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
-        if self.document_loader_name == "UnstructuredFileLoader":
-            loader = DocumentLoader(self.filepath, autodetect_encoding=True)
-        elif self.document_loader_name == "CSVLoader":
-            loader = DocumentLoader(self.filepath, encoding="utf-8")
-        elif self.document_loader_name == "JSONLoader":
-            loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
-        elif self.document_loader_name == "CustomJSONLoader":
-            loader = DocumentLoader(self.filepath, text_content=False)
-        elif self.document_loader_name == "UnstructuredMarkdownLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")
-        elif self.document_loader_name == "UnstructuredHTMLLoader":
-            loader = DocumentLoader(self.filepath, mode="elements")
-        else:
-            loader = DocumentLoader(self.filepath)
+    def file2docs(self, refresh: bool = False):
+        if self.docs is None or refresh:
+            logger.info(f"{self.document_loader_name} used for {self.filepath}")
+            loader = get_loader(self.document_loader_name, self.filepath)
+            self.docs = loader.load()
+        return self.docs

-        if self.ext in ".csv":
-            docs = loader.load()
-        else:
-            try:
-                if self.text_splitter_name is None:
-                    text_splitter_module = importlib.import_module('langchain.text_splitter')
-                    TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
-                    text_splitter = TextSplitter(
-                        pipeline="zh_core_web_sm",
-                        chunk_size=CHUNK_SIZE,
-                        chunk_overlap=OVERLAP_SIZE,
-                    )
-                    self.text_splitter_name = "SpacyTextSplitter"
-                else:
-                    text_splitter_module = importlib.import_module('langchain.text_splitter')
-                    TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
-                    text_splitter = TextSplitter(
-                        chunk_size=CHUNK_SIZE,
-                        chunk_overlap=OVERLAP_SIZE)
-            except Exception as e:
-                print(e)
-                text_splitter_module = importlib.import_module('langchain.text_splitter')
-                TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
-                text_splitter = TextSplitter(
-                    chunk_size=CHUNK_SIZE,
-                    chunk_overlap=OVERLAP_SIZE,
-                )
-            docs = loader.load_and_split(text_splitter)
-
-        print(docs[0])
+    def docs2texts(
+            self,
+            docs: List[Document] = None,
+            using_zh_title_enhance=ZH_TITLE_ENHANCE,
+            refresh: bool = False,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+            text_splitter: TextSplitter = None,
+    ):
+        docs = docs or self.file2docs(refresh=refresh)
+        if not docs:
+            return []
+        if self.ext not in [".csv"]:
+            if text_splitter is None:
+                text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+            docs = text_splitter.split_documents(docs)
+
+        print(f"文档切分示例:{docs[0]}")
         if using_zh_title_enhance:
             docs = zh_title_enhance(docs)
-        self.docs = docs
-        return docs
+        self.splited_docs = docs
+        return self.splited_docs
+
+    def file2text(
+            self,
+            using_zh_title_enhance=ZH_TITLE_ENHANCE,
+            refresh: bool = False,
+            chunk_size: int = CHUNK_SIZE,
+            chunk_overlap: int = OVERLAP_SIZE,
+            text_splitter: TextSplitter = None,
+    ):
+        if self.splited_docs is None or refresh:
+            docs = self.file2docs()
+            self.splited_docs = self.docs2texts(docs=docs,
+                                                using_zh_title_enhance=using_zh_title_enhance,
+                                                refresh=refresh,
+                                                chunk_size=chunk_size,
+                                                chunk_overlap=chunk_overlap,
+                                                text_splitter=text_splitter)
+        return self.splited_docs

     def file_exist(self):
         return os.path.isfile(self.filepath)

     def get_mtime(self):
         return os.path.getmtime(self.filepath)
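Note: the old monolithic file2text is split in two — file2docs only loads, docs2texts only splits (plus optional Chinese title enhancement) — with file2text kept as the cached composition. This is what lets callers inject pre-loaded Documents via splited_docs, as update_docs does above. For example:

    kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
    raw_docs = kb_file.file2docs()          # load only
    chunks = kb_file.docs2texts(raw_docs)   # split only
    same_chunks = kb_file.file2text()       # load + split, cached in splited_docs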
@@ -264,53 +316,47 @@
         return os.path.getsize(self.filepath)


-def run_in_thread_pool(
-        func: Callable,
-        params: List[Dict] = [],
-        pool: ThreadPoolExecutor = None,
-) -> Generator:
-    '''
-    Run a batch of tasks in a thread pool and yield the results as a generator.
-    Make sure everything inside a task is thread-safe; pass all task arguments as keyword arguments.
-    '''
-    tasks = []
-    if pool is None:
-        pool = ThreadPoolExecutor()
-
-    for kwargs in params:
-        thread = pool.submit(func, **kwargs)
-        tasks.append(thread)
-
-    for obj in as_completed(tasks):
-        yield obj.result()
-
-
 def files2docs_in_thread(
         files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
         pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
-    Convert files into langchain Documents in batches using multiple threads.
-    The generator yields {(kb_name, file_name): docs}
+    Convert disk files into langchain Documents in batches using multiple threads.
+    If an item is a Tuple, its form is (filename, kb_name).
+    The generator yields status, (kb_name, file_name, docs | error)
     '''
-    def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]:
+    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
         try:
             return True, (file.kb_name, file.filename, file.file2text(**kwargs))
         except Exception as e:
-            return False, e
+            msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
+            logger.error(msg)
+            return False, (file.kb_name, file.filename, msg)

     kwargs_list = []
     for i, file in enumerate(files):
         kwargs = {}
         if isinstance(file, tuple) and len(file) >= 2:
-            files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
+            file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
         elif isinstance(file, dict):
             filename = file.pop("filename")
             kb_name = file.pop("kb_name")
-            files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
             kwargs = file
+            file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+        kwargs["file"] = file
         kwargs_list.append(kwargs)

-    for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
+    for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
         yield result


 if __name__ == "__main__":
     from pprint import pprint

     kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
+    # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
+    docs = kb_file.file2docs()
+    pprint(docs[-1])
+
     docs = kb_file.file2text()
     pprint(docs[-1])
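Note: files2docs_in_thread now yields (status, payload) tuples instead of returning bare exceptions, so callers can stream partial progress. Usage sketch (knowledge base and file assumed to exist):

    for status, result in files2docs_in_thread([("test.txt", "samples")]):
        if status:
            kb_name, file_name, docs = result
            print(f"{file_name}: {len(docs)} docs")
        else:
            kb_name, file_name, error = result
            print(f"{file_name} failed: {error}")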

View File

@@ -8,7 +8,11 @@ from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
 from configs.server_config import FSCHAT_MODEL_WORKERS
 import os
 from server import model_workers
-from typing import Literal, Optional, Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Literal, Optional, Callable, Generator, Dict, Any
+
+thread_pool = ThreadPoolExecutor()


 class BaseResponse(BaseModel):

@@ -305,3 +309,24 @@ def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
     if device not in ["cuda", "mps", "cpu"]:
         device = detect_device()
     return device
+
+
+def run_in_thread_pool(
+        func: Callable,
+        params: List[Dict] = [],
+        pool: ThreadPoolExecutor = None,
+) -> Generator:
+    '''
+    Run a batch of tasks in a thread pool and yield the results as a generator.
+    Make sure everything inside a task is thread-safe; pass all task arguments as keyword arguments.
+    '''
+    tasks = []
+    pool = pool or thread_pool
+
+    for kwargs in params:
+        thread = pool.submit(func, **kwargs)
+        tasks.append(thread)
+
+    for obj in as_completed(tasks):
+        yield obj.result()
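Note: run_in_thread_pool moved here from server/knowledge_base/utils.py and now reuses a module-level executor by default. Tasks must take keyword arguments only, and results arrive in completion order, not submission order:

    def work(*, x: int, y: int) -> int:  # keyword-only, as the helper requires
        return x + y

    for value in run_in_thread_pool(work, params=[{"x": 1, "y": 2}, {"x": 3, "y": 4}]):
        print(value)  # 3 and 7, in whichever order the threads finish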

View File

@@ -7,19 +7,23 @@ root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from server.utils import api_address
 from configs.model_config import VECTOR_SEARCH_TOP_K
-from server.knowledge_base.utils import get_kb_path
+from server.knowledge_base.utils import get_kb_path, get_file_path

 from pprint import pprint


 api_base_url = api_address()
 kb = "kb_for_api_test"
 test_files = {
+    "FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
     "README.MD": str(root_path / "README.MD"),
-    "FAQ.MD": str(root_path / "docs" / "FAQ.MD")
+    "test.txt": get_file_path("samples", "test.txt"),
 }
+
+print("\n\n直接url访问\n")


 def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
     if not Path(get_kb_path(kb)).exists():
@@ -78,37 +82,36 @@ def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
     assert kb in data["data"]


-def test_upload_doc(api="/knowledge_base/upload_doc"):
+def test_upload_docs(api="/knowledge_base/upload_docs"):
     url = api_base_url + api
-    for name, path in test_files.items():
-        print(f"\n上传知识文件: {name}")
-        data = {"knowledge_base_name": kb, "override": True}
-        files = {"file": (name, open(path, "rb"))}
-        r = requests.post(url, data=data, files=files)
-        data = r.json()
-        pprint(data)
-        assert data["code"] == 200
-        assert data["msg"] == f"成功上传文件 {name}"
+    files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]

-    for name, path in test_files.items():
-        print(f"\n尝试重新上传知识文件: {name}, 不覆盖")
-        data = {"knowledge_base_name": kb, "override": False}
-        files = {"file": (name, open(path, "rb"))}
-        r = requests.post(url, data=data, files=files)
-        data = r.json()
-        pprint(data)
-        assert data["code"] == 404
-        assert data["msg"] == f"文件 {name} 已存在。"
+    print(f"\n上传知识文件")
+    data = {"knowledge_base_name": kb, "override": True}
+    r = requests.post(url, data=data, files=files)
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0

-    for name, path in test_files.items():
-        print(f"\n尝试重新上传知识文件: {name}, 覆盖")
-        data = {"knowledge_base_name": kb, "override": True}
-        files = {"file": (name, open(path, "rb"))}
-        r = requests.post(url, data=data, files=files)
-        data = r.json()
-        pprint(data)
-        assert data["code"] == 200
-        assert data["msg"] == f"成功上传文件 {name}"
+    print(f"\n尝试重新上传知识文件, 不覆盖")
+    data = {"knowledge_base_name": kb, "override": False}
+    files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
+    r = requests.post(url, data=data, files=files)
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == len(test_files)
+
+    print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
+    docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
+    data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
+    files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
+    r = requests.post(url, data=data, files=files)
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0


 def test_list_files(api="/knowledge_base/list_files"):
@@ -134,26 +137,26 @@ def test_search_docs(api="/knowledge_base/search_docs"):
     assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K


-def test_update_doc(api="/knowledge_base/update_doc"):
+def test_update_docs(api="/knowledge_base/update_docs"):
     url = api_base_url + api
-    for name, path in test_files.items():
-        print(f"\n更新知识文件: {name}")
-        r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
-        data = r.json()
-        pprint(data)
-        assert data["code"] == 200
-        assert data["msg"] == f"成功更新文件 {name}"
+    print(f"\n更新知识文件")
+    r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0


-def test_delete_doc(api="/knowledge_base/delete_doc"):
+def test_delete_docs(api="/knowledge_base/delete_docs"):
     url = api_base_url + api
-    for name, path in test_files.items():
-        print(f"\n删除知识文件: {name}")
-        r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
-        data = r.json()
-        pprint(data)
-        assert data["code"] == 200
-        assert data["msg"] == f"{name} 文件删除成功"
+    print(f"\n删除知识文件")
+    r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0

     url = api_base_url + "/knowledge_base/search_docs"
     query = "介绍一下langchain-chatchat项目"

View File

@@ -0,0 +1,161 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from server.utils import api_address
+from configs.model_config import VECTOR_SEARCH_TOP_K
+from server.knowledge_base.utils import get_kb_path, get_file_path
+from webui_pages.utils import ApiRequest
+
+from pprint import pprint
+
+
+api_base_url = api_address()
+api: ApiRequest = ApiRequest(api_base_url)
+
+kb = "kb_for_api_test"
+test_files = {
+    "FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
+    "README.MD": str(root_path / "README.MD"),
+    "test.txt": get_file_path("samples", "test.txt"),
+}
+
+print("\n\nApiRequest调用\n")
+
+
+def test_delete_kb_before():
+    if not Path(get_kb_path(kb)).exists():
+        return
+
+    data = api.delete_knowledge_base(kb)
+    pprint(data)
+    assert data["code"] == 200
+    assert isinstance(data["data"], list) and len(data["data"]) > 0
+    assert kb not in data["data"]
+
+
+def test_create_kb():
+    print(f"\n尝试用空名称创建知识库:")
+    data = api.create_knowledge_base(" ")
+    pprint(data)
+    assert data["code"] == 404
+    assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
+
+    print(f"\n创建新知识库: {kb}")
+    data = api.create_knowledge_base(kb)
+    pprint(data)
+    assert data["code"] == 200
+    assert data["msg"] == f"已新增知识库 {kb}"
+
+    print(f"\n尝试创建同名知识库: {kb}")
+    data = api.create_knowledge_base(kb)
+    pprint(data)
+    assert data["code"] == 404
+    assert data["msg"] == f"已存在同名知识库 {kb}"
+
+
+def test_list_kbs():
+    data = api.list_knowledge_bases()
+    pprint(data)
+    assert isinstance(data, list) and len(data) > 0
+    assert kb in data
+
+
+def test_upload_docs():
+    files = list(test_files.values())
+
+    print(f"\n上传知识文件")
+    data = {"knowledge_base_name": kb, "override": True}
+    data = api.upload_kb_docs(files, **data)
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+    print(f"\n尝试重新上传知识文件, 不覆盖")
+    data = {"knowledge_base_name": kb, "override": False}
+    data = api.upload_kb_docs(files, **data)
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == len(test_files)
+
+    print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
+    docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
+    data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
+    data = api.upload_kb_docs(files, **data)
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+
+def test_list_files():
+    print("\n获取知识库中文件列表:")
+    data = api.list_kb_docs(knowledge_base_name=kb)
+    pprint(data)
+    assert isinstance(data, list)
+    for name in test_files:
+        assert name in data
+
+
+def test_search_docs():
+    query = "介绍一下langchain-chatchat项目"
+    print("\n检索知识库:")
+    print(query)
+    data = api.search_kb_docs(query, kb)
+    pprint(data)
+    assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
+
+
+def test_update_docs():
+    print(f"\n更新知识文件")
+    data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+
+def test_delete_docs():
+    print(f"\n删除知识文件")
+    data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+    query = "介绍一下langchain-chatchat项目"
+    print("\n尝试检索删除后的知识库:")
+    print(query)
+    data = api.search_kb_docs(query, kb)
+    pprint(data)
+    assert isinstance(data, list) and len(data) == 0
+
+
+def test_recreate_vs():
+    print("\n重建知识库:")
+    r = api.recreate_vector_store(kb)
+    for data in r:
+        assert isinstance(data, dict)
+        assert data["code"] == 200
+        print(data["msg"])
+
+    query = "本项目支持哪些文件格式?"
+    print("\n尝试检索重建后的知识库:")
+    print(query)
+    data = api.search_kb_docs(query, kb)
+    pprint(data)
+    assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
+
+
+def test_delete_kb_after():
+    print("\n删除知识库")
+    data = api.delete_knowledge_base(kb)
+    pprint(data)
+
+    # check kb not exists anymore
+    print("\n获取知识库列表:")
+    data = api.list_knowledge_bases()
+    pprint(data)
+    assert isinstance(data, list) and len(data) > 0
+    assert kb not in data

View File

@@ -138,14 +138,11 @@ def knowledge_base_page(api: ApiRequest):
             # use_container_width=True,
             disabled=len(files) == 0,
         ):
-            data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
-            data[-1]["not_refresh_vs_cache"] = False
-            for k in data:
-                ret = api.upload_kb_doc(**k)
-                if msg := check_success_msg(ret):
-                    st.toast(msg, icon="")
-                elif msg := check_error_msg(ret):
-                    st.toast(msg, icon="")
+            ret = api.upload_kb_docs(files, knowledge_base_name=kb, override=True)
+            if msg := check_success_msg(ret):
+                st.toast(msg, icon="")
+            elif msg := check_error_msg(ret):
+                st.toast(msg, icon="")
             st.session_state.files = []

     st.divider()

@@ -218,8 +215,8 @@ def knowledge_base_page(api: ApiRequest):
                 disabled=not file_exists(kb, selected_rows)[0],
                 use_container_width=True,
             ):
-                for row in selected_rows:
-                    api.update_kb_doc(kb, row["file_name"])
+                file_names = [row["file_name"] for row in selected_rows]
+                api.update_kb_docs(kb, file_names=file_names)
                 st.experimental_rerun()

             # remove the files from the vector store, but keep the files themselves

@@ -228,8 +225,8 @@ def knowledge_base_page(api: ApiRequest):
                 disabled=not (selected_rows and selected_rows[0]["in_db"]),
                 use_container_width=True,
            ):
-                for row in selected_rows:
-                    api.delete_kb_doc(kb, row["file_name"])
+                file_names = [row["file_name"] for row in selected_rows]
+                api.delete_kb_docs(kb, file_names=file_names)
                 st.experimental_rerun()

             if cols[3].button(

@@ -237,9 +234,8 @@ def knowledge_base_page(api: ApiRequest):
                 type="primary",
                 use_container_width=True,
             ):
-                for row in selected_rows:
-                    ret = api.delete_kb_doc(kb, row["file_name"], True)
-                    st.toast(ret.get("msg", " "))
+                file_names = [row["file_name"] for row in selected_rows]
+                api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
                 st.experimental_rerun()

     st.divider()

View File

@@ -21,9 +21,7 @@ from fastapi.responses import StreamingResponse
 import contextlib
 import json
 from io import BytesIO
-from server.db.repository.knowledge_base_repository import get_kb_detail
-from server.db.repository.knowledge_file_repository import get_file_detail
-from server.utils import run_async, iter_over_async, set_httpx_timeout
+from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
 from configs.model_config import NLTK_DATA_PATH
 import nltk

@@ -43,7 +41,7 @@ class ApiRequest:
     '''
     def __init__(
         self,
-        base_url: str = "http://127.0.0.1:7861",
+        base_url: str = api_address(),
         timeout: float = HTTPX_DEFAULT_TIMEOUT,
         no_remote_api: bool = False,  # call api view function directly
     ):
@@ -78,7 +76,7 @@ class ApiRequest:
             else:
                 return httpx.get(url, params=params, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when get {url}: {e}")
             retry -= 1

     async def aget(

@@ -99,7 +97,7 @@ class ApiRequest:
             else:
                 return await client.get(url, params=params, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when aget {url}: {e}")
             retry -= 1

     def post(

@@ -121,7 +119,7 @@ class ApiRequest:
             else:
                 return httpx.post(url, data=data, json=json, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when post {url}: {e}")
             retry -= 1

     async def apost(

@@ -143,7 +141,7 @@ class ApiRequest:
             else:
                 return await client.post(url, data=data, json=json, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when apost {url}: {e}")
             retry -= 1

     def delete(

@@ -164,7 +162,7 @@ class ApiRequest:
             else:
                 return httpx.delete(url, data=data, json=json, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when delete {url}: {e}")
             retry -= 1

     async def adelete(

@@ -186,7 +184,7 @@ class ApiRequest:
             else:
                 return await client.delete(url, data=data, json=json, **kwargs)
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when adelete {url}: {e}")
             retry -= 1

     def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool = False):

@@ -205,7 +203,7 @@ class ApiRequest:
             elif chunk.strip():
                 yield chunk
         except Exception as e:
-            logger.error(e)
+            logger.error(f"error when run fastapi router: {e}")

     def _httpx_stream2generator(
         self,

@@ -231,18 +229,18 @@ class ApiRequest:
                         print(chunk, end="", flush=True)
                     yield chunk
         except httpx.ConnectError as e:
-            msg = f"无法连接API服务器,请确认 api.py 已正常启动。"
+            msg = f"无法连接API服务器,请确认 api.py 已正常启动。({e})"
             logger.error(msg)
-            logger.error(e)
             yield {"code": 500, "msg": msg}
         except httpx.ReadTimeout as e:
-            msg = f"API通信超时,请确认已启动FastChat与API服务(详见README '5. 启动 API 服务或 Web UI')"
+            msg = f"API通信超时,请确认已启动FastChat与API服务(详见README '5. 启动 API 服务或 Web UI')。({e})"
             logger.error(msg)
-            logger.error(e)
             yield {"code": 500, "msg": msg}
         except Exception as e:
-            logger.error(e)
-            yield {"code": 500, "msg": str(e)}
+            msg = f"API通信遇到错误:{e}"
+            logger.error(msg)
+            yield {"code": 500, "msg": msg}

     # chat-related operations
@@ -413,8 +411,9 @@ class ApiRequest:
         try:
             return response.json()
         except Exception as e:
-            logger.error(e)
-            return {"code": 500, "msg": errorMsg or str(e)}
+            msg = "API未能返回正确的JSON。" + (errorMsg or str(e))
+            logger.error(msg)
+            return {"code": 500, "msg": msg}

     def list_knowledge_bases(
         self,
@@ -510,12 +509,45 @@ class ApiRequest:
         data = self._check_httpx_json_response(response)
         return data.get("data", [])

-    def upload_kb_doc(
+    def search_kb_docs(
         self,
-        file: Union[str, Path, bytes],
+        query: str,
         knowledge_base_name: str,
+        top_k: int = VECTOR_SEARCH_TOP_K,
+        score_threshold: int = SCORE_THRESHOLD,
+        no_remote_api: bool = None,
+    ) -> List:
+        '''
+        Corresponds to the api.py /knowledge_base/search_docs endpoint
+        '''
+        if no_remote_api is None:
+            no_remote_api = self.no_remote_api
+
+        data = {
+            "query": query,
+            "knowledge_base_name": knowledge_base_name,
+            "top_k": top_k,
+            "score_threshold": score_threshold,
+        }
+
+        if no_remote_api:
+            from server.knowledge_base.kb_doc_api import search_docs
+            return search_docs(**data)
+        else:
+            response = self.post(
+                "/knowledge_base/search_docs",
+                json=data,
+            )
+            data = self._check_httpx_json_response(response)
+            return data
+
+    def upload_kb_docs(
+        self,
+        files: List[Union[str, Path, bytes]],
+        knowledge_base_name: str,
-        filename: str = None,
         override: bool = False,
+        to_vector_store: bool = True,
+        docs: Dict = {},
         not_refresh_vs_cache: bool = False,
         no_remote_api: bool = None,
     ):
@@ -525,97 +557,113 @@ class ApiRequest:
         if no_remote_api is None:
             no_remote_api = self.no_remote_api

-        if isinstance(file, bytes):  # raw bytes
-            file = BytesIO(file)
-        elif hasattr(file, "read"):  # a file io like object
-            filename = filename or file.name
-        else:  # a local path
-            file = Path(file).absolute().open("rb")
-            filename = filename or file.name
+        def convert_file(file, filename=None):
+            if isinstance(file, bytes):  # raw bytes
+                file = BytesIO(file)
+            elif hasattr(file, "read"):  # a file io like object
+                filename = filename or file.name
+            else:  # a local path
+                file = Path(file).absolute().open("rb")
+                filename = filename or file.name
+            return filename, file
+
+        files = [convert_file(file) for file in files]
+        data = {
+            "knowledge_base_name": knowledge_base_name,
+            "override": override,
+            "to_vector_store": to_vector_store,
+            "docs": docs,
+            "not_refresh_vs_cache": not_refresh_vs_cache,
+        }

         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import upload_doc
+            from server.knowledge_base.kb_doc_api import upload_docs
             from fastapi import UploadFile
             from tempfile import SpooledTemporaryFile

-            temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
-            temp_file.write(file.read())
-            temp_file.seek(0)
-            response = run_async(upload_doc(
-                UploadFile(file=temp_file, filename=filename),
-                knowledge_base_name,
-                override,
-            ))
+            upload_files = []
+            for filename, file in files:
+                temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
+                temp_file.write(file.read())
+                temp_file.seek(0)
+                upload_files.append(UploadFile(file=temp_file, filename=filename))
+
+            response = run_async(upload_docs(upload_files, **data))
             return response.dict()
         else:
+            if isinstance(data["docs"], dict):
+                data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
             response = self.post(
-                "/knowledge_base/upload_doc",
-                data={
-                    "knowledge_base_name": knowledge_base_name,
-                    "override": override,
-                    "not_refresh_vs_cache": not_refresh_vs_cache,
-                },
-                files={"file": (filename, file)},
+                "/knowledge_base/upload_docs",
+                data=data,
+                files=[("files", (filename, file)) for filename, file in files],
             )
             return self._check_httpx_json_response(response)
-    def delete_kb_doc(
+    def delete_kb_docs(
         self,
         knowledge_base_name: str,
-        doc_name: str,
+        file_names: List[str],
         delete_content: bool = False,
         not_refresh_vs_cache: bool = False,
         no_remote_api: bool = None,
     ):
         '''
-        Corresponds to the api.py /knowledge_base/delete_doc endpoint
+        Corresponds to the api.py /knowledge_base/delete_docs endpoint
         '''
         if no_remote_api is None:
             no_remote_api = self.no_remote_api

         data = {
             "knowledge_base_name": knowledge_base_name,
-            "doc_name": doc_name,
+            "file_names": file_names,
             "delete_content": delete_content,
             "not_refresh_vs_cache": not_refresh_vs_cache,
         }

         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import delete_doc
-            response = run_async(delete_doc(**data))
+            from server.knowledge_base.kb_doc_api import delete_docs
+            response = run_async(delete_docs(**data))
             return response.dict()
         else:
             response = self.post(
-                "/knowledge_base/delete_doc",
+                "/knowledge_base/delete_docs",
                 json=data,
             )
             return self._check_httpx_json_response(response)
-    def update_kb_doc(
+    def update_kb_docs(
         self,
         knowledge_base_name: str,
-        file_name: str,
+        file_names: List[str],
+        override_custom_docs: bool = False,
+        docs: Dict = {},
         not_refresh_vs_cache: bool = False,
         no_remote_api: bool = None,
     ):
         '''
-        Corresponds to the api.py /knowledge_base/update_doc endpoint
+        Corresponds to the api.py /knowledge_base/update_docs endpoint
         '''
         if no_remote_api is None:
             no_remote_api = self.no_remote_api

+        data = {
+            "knowledge_base_name": knowledge_base_name,
+            "file_names": file_names,
+            "override_custom_docs": override_custom_docs,
+            "docs": docs,
+            "not_refresh_vs_cache": not_refresh_vs_cache,
+        }
+
         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import update_doc
-            response = run_async(update_doc(knowledge_base_name, file_name))
+            from server.knowledge_base.kb_doc_api import update_docs
+            response = run_async(update_docs(**data))
             return response.dict()
         else:
+            if isinstance(data["docs"], dict):
+                data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
             response = self.post(
-                "/knowledge_base/update_doc",
-                json={
-                    "knowledge_base_name": knowledge_base_name,
-                    "file_name": file_name,
-                    "not_refresh_vs_cache": not_refresh_vs_cache,
-                },
+                "/knowledge_base/update_docs",
+                json=data,
            )
            return self._check_httpx_json_response(response)
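Note: taken together, the client side mirrors the new batch API. A short end-to-end sketch with ApiRequest (knowledge base name and file path are illustrative):

    from webui_pages.utils import ApiRequest

    api = ApiRequest()  # base_url now defaults to api_address()
    api.upload_kb_docs(["/tmp/README.MD"], knowledge_base_name="samples", override=True)  # hypothetical path
    api.update_kb_docs("samples", file_names=["README.MD"])
    hits = api.search_kb_docs("介绍一下langchain-chatchat项目", "samples")
    api.delete_kb_docs("samples", file_names=["README.MD"], delete_content=True)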