merge pr1413
This commit is contained in:
commit
1195eb75eb
|
|
@ -15,8 +15,8 @@ from starlette.responses import RedirectResponse
|
|||
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||
search_engine_chat)
|
||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
|
||||
update_doc, download_doc, recreate_vector_store,
|
||||
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
|
||||
update_docs, download_doc, recreate_vector_store,
|
||||
search_docs, DocumentWithScore)
|
||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
|
||||
import httpx
|
||||
|
|
@ -98,23 +98,23 @@ def create_app():
|
|||
summary="搜索知识库"
|
||||
)(search_docs)
|
||||
|
||||
app.post("/knowledge_base/upload_doc",
|
||||
app.post("/knowledge_base/upload_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="上传文件到知识库"
|
||||
)(upload_doc)
|
||||
summary="上传文件到知识库,并/或进行向量化"
|
||||
)(upload_docs)
|
||||
|
||||
app.post("/knowledge_base/delete_doc",
|
||||
app.post("/knowledge_base/delete_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="删除知识库内指定文件"
|
||||
)(delete_doc)
|
||||
)(delete_docs)
|
||||
|
||||
app.post("/knowledge_base/update_doc",
|
||||
app.post("/knowledge_base/update_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="更新现有文件到知识库"
|
||||
)(update_doc)
|
||||
)(update_docs)
|
||||
|
||||
app.get("/knowledge_base/download_doc",
|
||||
tags=["Knowledge Base Management"],
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
|
|||
from server.knowledge_base.utils import validate_kb_name
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.db.repository.knowledge_base_repository import list_kbs_from_db
|
||||
from configs.model_config import EMBEDDING_MODEL
|
||||
from configs.model_config import EMBEDDING_MODEL, logger
|
||||
from fastapi import Body
|
||||
|
||||
|
||||
|
|
@ -30,8 +30,9 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||
try:
|
||||
kb.create_kb()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
|
||||
msg = f"创建知识库出错: {e}"
|
||||
logger.error(msg)
|
||||
return BaseResponse(code=500, msg=msg)
|
||||
|
||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||
|
||||
|
|
@ -55,7 +56,8 @@ async def delete_kb(
|
|||
if status:
|
||||
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
|
||||
msg = f"删除知识库时出现意外: {e}"
|
||||
logger.error(msg)
|
||||
return BaseResponse(code=500, msg=msg)
|
||||
|
||||
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
import os
|
||||
import urllib
|
||||
from fastapi import File, Form, Body, Query, UploadFile
|
||||
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
|
||||
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
logger,)
|
||||
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
|
||||
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path,
|
||||
files2docs_in_thread, KnowledgeFile)
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from pydantic import Json
|
||||
import json
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.db.repository.knowledge_file_repository import get_file_detail
|
||||
from typing import List, Dict
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
|
@ -44,11 +49,83 @@ async def list_files(
|
|||
return ListResponse(data=all_doc_names)
|
||||
|
||||
|
||||
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
|
||||
def _save_files_in_thread(files: List[UploadFile],
|
||||
knowledge_base_name: str,
|
||||
override: bool):
|
||||
'''
|
||||
通过多线程将上传的文件保存到对应知识库目录内。
|
||||
生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
|
||||
'''
|
||||
def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
|
||||
'''
|
||||
保存单个文件。
|
||||
'''
|
||||
try:
|
||||
filename = file.filename
|
||||
file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)
|
||||
data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
|
||||
|
||||
file_content = file.file.read() # 读取上传文件的内容
|
||||
if (os.path.isfile(file_path)
|
||||
and not override
|
||||
and os.path.getsize(file_path) == len(file_content)
|
||||
):
|
||||
# TODO: filesize 不同后的处理
|
||||
file_status = f"文件 {filename} 已存在。"
|
||||
logger.warn(file_status)
|
||||
return dict(code=404, msg=file_status, data=data)
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
|
||||
except Exception as e:
|
||||
msg = f"{filename} 文件上传失败,报错信息为: {e}"
|
||||
logger.error(msg)
|
||||
return dict(code=500, msg=msg, data=data)
|
||||
|
||||
params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
|
||||
for result in run_in_thread_pool(save_file, params=params):
|
||||
yield result
|
||||
|
||||
|
||||
# 似乎没有单独增加一个文件上传API接口的必要
|
||||
# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||
# override: bool = Form(False, description="覆盖已有文件")):
|
||||
# '''
|
||||
# API接口:上传文件。流式返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
|
||||
# '''
|
||||
# def generate(files, knowledge_base_name, override):
|
||||
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
|
||||
# yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
# return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream")
|
||||
|
||||
|
||||
# TODO: 等langchain.document_loaders支持内存文件的时候再开通
|
||||
# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||
# override: bool = Form(False, description="覆盖已有文件"),
|
||||
# save: bool = Form(True, description="是否将文件保存到知识库目录")):
|
||||
# def save_files(files, knowledge_base_name, override):
|
||||
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
|
||||
# yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
# def files_to_docs(files):
|
||||
# for result in files2docs_in_thread(files):
|
||||
# yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||
override: bool = Form(False, description="覆盖已有文件"),
|
||||
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
|
||||
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
API接口:上传文件,并/或向量化
|
||||
'''
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
|
|
@ -56,37 +133,36 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
|||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
file_content = await file.read() # 读取上传文件的内容
|
||||
failed_files = {}
|
||||
file_names = list(docs.keys())
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
# 先将上传的文件保存到磁盘
|
||||
for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
|
||||
filename = result["data"]["file_name"]
|
||||
if result["code"] != 200:
|
||||
failed_files[filename] = result["msg"]
|
||||
|
||||
if filename not in file_names:
|
||||
file_names.append(filename)
|
||||
|
||||
if (os.path.exists(kb_file.filepath)
|
||||
and not override
|
||||
and os.path.getsize(kb_file.filepath) == len(file_content)
|
||||
):
|
||||
# TODO: filesize 不同后的处理
|
||||
file_status = f"文件 {kb_file.filename} 已存在。"
|
||||
return BaseResponse(code=404, msg=file_status)
|
||||
# 对保存的文件进行向量化
|
||||
if to_vector_store:
|
||||
result = await update_docs(
|
||||
knowledge_base_name=knowledge_base_name,
|
||||
file_names=file_names,
|
||||
override_custom_docs=True,
|
||||
docs=docs,
|
||||
not_refresh_vs_cache=True,
|
||||
)
|
||||
failed_files.update(result.data["failed_files"])
|
||||
if not not_refresh_vs_cache:
|
||||
kb.save_vector_store()
|
||||
|
||||
with open(kb_file.filepath, "wb") as f:
|
||||
f.write(file_content)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
|
||||
|
||||
try:
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
|
||||
|
||||
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
||||
return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
doc_name: str = Body(..., examples=["file_name.md"]),
|
||||
async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
|
||||
delete_content: bool = Body(False),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
|
|
@ -98,23 +174,31 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
if not kb.exist_doc(doc_name):
|
||||
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
||||
failed_files = {}
|
||||
for file_name in file_names:
|
||||
if not kb.exist_doc(file_name):
|
||||
failed_files[file_name] = f"未找到文件 {file_name}"
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=doc_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
|
||||
except Exception as e:
|
||||
msg = f"{file_name} 文件删除失败,错误信息:{e}"
|
||||
logger.error(msg)
|
||||
failed_files[file_name] = msg
|
||||
|
||||
if not not_refresh_vs_cache:
|
||||
kb.save_vector_store()
|
||||
|
||||
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
|
||||
return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
async def update_doc(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
file_name: str = Body(..., examples=["file_name"]),
|
||||
async def update_docs(
|
||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
|
||||
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
|
||||
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
|
|
@ -127,22 +211,57 @@ async def update_doc(
|
|||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
if os.path.exists(kb_file.filepath):
|
||||
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
|
||||
failed_files = {}
|
||||
kb_files = []
|
||||
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
|
||||
# 生成需要加载docs的文件列表
|
||||
for file_name in file_names:
|
||||
file_detail= get_file_detail(kb_name=knowledge_base_name, filename=file_name)
|
||||
# 如果该文件之前使用了自定义docs,则根据参数决定略过或覆盖
|
||||
if file_detail.get("custom_docs") and not override_custom_docs:
|
||||
continue
|
||||
if file_name not in docs:
|
||||
try:
|
||||
kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
|
||||
except Exception as e:
|
||||
msg = f"加载文档 {file_name} 时出错:{e}"
|
||||
logger.error(msg)
|
||||
failed_files[file_name] = msg
|
||||
|
||||
# 从文件生成docs,并进行向量化。
|
||||
# 这里利用了KnowledgeFile的缓存功能,在多线程中加载Document,然后传给KnowledgeFile
|
||||
for status, result in files2docs_in_thread(kb_files):
|
||||
if status:
|
||||
kb_name, file_name, new_docs = result
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb_file.splited_docs = new_docs
|
||||
kb.update_doc(kb_file, not_refresh_vs_cache=True)
|
||||
else:
|
||||
kb_name, file_name, error = result
|
||||
failed_files[file_name] = error
|
||||
|
||||
# 将自定义的docs进行向量化
|
||||
for file_name, v in docs.items():
|
||||
try:
|
||||
v = [x if isinstance(x, Document) else Document(**x) for x in v]
|
||||
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
|
||||
kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
|
||||
except Exception as e:
|
||||
msg = f"为 {file_name} 添加自定义docs时出错:{e}"
|
||||
logger.error(msg)
|
||||
failed_files[file_name] = msg
|
||||
|
||||
if not not_refresh_vs_cache:
|
||||
kb.save_vector_store()
|
||||
|
||||
return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
async def download_doc(
|
||||
knowledge_base_name: str = Query(..., examples=["samples"]),
|
||||
file_name: str = Query(..., examples=["test.txt"]),
|
||||
knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]),
|
||||
file_name: str = Query(...,description="文件名称", examples=["test.txt"]),
|
||||
preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
|
||||
):
|
||||
'''
|
||||
下载知识库文档
|
||||
|
|
@ -154,6 +273,11 @@ async def download_doc(
|
|||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
if preview:
|
||||
content_disposition_type = "inline"
|
||||
else:
|
||||
content_disposition_type = None
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
|
|
@ -162,10 +286,13 @@ async def download_doc(
|
|||
return FileResponse(
|
||||
path=kb_file.filepath,
|
||||
filename=kb_file.filename,
|
||||
media_type="multipart/form-data")
|
||||
media_type="multipart/form-data",
|
||||
content_disposition_type=content_disposition_type,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
|
||||
msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}"
|
||||
logger.error(msg)
|
||||
return BaseResponse(code=500, msg=msg)
|
||||
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
|
||||
|
||||
|
|
@ -190,27 +317,30 @@ async def recreate_vector_store(
|
|||
else:
|
||||
kb.create_kb()
|
||||
kb.clear_vs()
|
||||
docs = list_files_from_folder(knowledge_base_name)
|
||||
for i, doc in enumerate(docs):
|
||||
try:
|
||||
kb_file = KnowledgeFile(doc, knowledge_base_name)
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
kb_files = [(file, knowledge_base_name) for file in files]
|
||||
i = 0
|
||||
for status, result in files2docs_in_thread(kb_files):
|
||||
if status:
|
||||
kb_name, file_name, docs = result
|
||||
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
||||
kb_file.splited_docs = docs
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(docs)}): {doc}",
|
||||
"total": len(docs),
|
||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||
"total": len(files),
|
||||
"finished": i,
|
||||
"doc": doc,
|
||||
"doc": file_name,
|
||||
}, ensure_ascii=False)
|
||||
if i == len(docs) - 1:
|
||||
not_refresh_vs_cache = False
|
||||
else:
|
||||
not_refresh_vs_cache = True
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
||||
else:
|
||||
kb_name, file_name, error = result
|
||||
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
||||
logger.error(msg)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
|
||||
"msg": msg,
|
||||
})
|
||||
i += 1
|
||||
|
||||
return StreamingResponse(output(), media_type="text/event-stream")
|
||||
|
|
|
|||
|
|
@ -51,6 +51,13 @@ class KBService(ABC):
|
|||
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
|
||||
return load_embeddings(self.embed_model, embed_device)
|
||||
|
||||
def save_vector_store(self, vector_store=None):
|
||||
'''
|
||||
保存向量库,仅支持FAISS。对于其它向量库该函数不做任何操作。
|
||||
减少FAISS向量库操作时的类型判断。
|
||||
'''
|
||||
pass
|
||||
|
||||
def create_kb(self):
|
||||
"""
|
||||
创建知识库
|
||||
|
|
@ -84,6 +91,8 @@ class KBService(ABC):
|
|||
"""
|
||||
if docs:
|
||||
custom_docs = True
|
||||
for doc in docs:
|
||||
doc.metadata.setdefault("source", kb_file.filepath)
|
||||
else:
|
||||
docs = kb_file.file2text()
|
||||
custom_docs = False
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ from configs.model_config import (
|
|||
KB_ROOT_PATH,
|
||||
CACHED_VS_NUM,
|
||||
EMBEDDING_MODEL,
|
||||
SCORE_THRESHOLD
|
||||
SCORE_THRESHOLD,
|
||||
logger,
|
||||
)
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from functools import lru_cache
|
||||
|
|
@ -28,7 +29,7 @@ def load_faiss_vector_store(
|
|||
embeddings: Embeddings = None,
|
||||
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||
) -> FAISS:
|
||||
print(f"loading vector store in '{knowledge_base_name}'.")
|
||||
logger.info(f"loading vector store in '{knowledge_base_name}'.")
|
||||
vs_path = get_vs_path(knowledge_base_name)
|
||||
if embeddings is None:
|
||||
embeddings = load_embeddings(embed_model, embed_device)
|
||||
|
|
@ -57,7 +58,7 @@ def refresh_vs_cache(kb_name: str):
|
|||
make vector store cache refreshed when next loading
|
||||
"""
|
||||
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
|
||||
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
||||
logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
||||
|
||||
|
||||
class FaissKBService(KBService):
|
||||
|
|
@ -133,7 +134,7 @@ class FaissKBService(KBService):
|
|||
**kwargs):
|
||||
vector_store = self.load_vector_store()
|
||||
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
|
||||
if len(ids) == 0:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -7,16 +7,20 @@ from configs.model_config import (
|
|||
KB_ROOT_PATH,
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
ZH_TITLE_ENHANCE
|
||||
ZH_TITLE_ENHANCE,
|
||||
logger,
|
||||
)
|
||||
from functools import lru_cache
|
||||
import importlib
|
||||
from text_splitter import zh_title_enhance
|
||||
import langchain.document_loaders
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.text_splitter import TextSplitter
|
||||
from pathlib import Path
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from server.utils import run_in_thread_pool
|
||||
import io
|
||||
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
||||
|
||||
|
||||
|
|
@ -173,12 +177,74 @@ def get_LoaderClass(file_extension):
|
|||
return LoaderClass
|
||||
|
||||
|
||||
# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
|
||||
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
|
||||
'''
|
||||
根据loader_name和文件路径或内容返回文档加载器。
|
||||
'''
|
||||
try:
|
||||
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
|
||||
document_loaders_module = importlib.import_module('document_loaders')
|
||||
else:
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, loader_name)
|
||||
except Exception as e:
|
||||
logger.error(f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}")
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||
|
||||
if loader_name == "UnstructuredFileLoader":
|
||||
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
|
||||
elif loader_name == "CSVLoader":
|
||||
loader = DocumentLoader(file_path_or_content, encoding="utf-8")
|
||||
elif loader_name == "JSONLoader":
|
||||
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
|
||||
elif loader_name == "CustomJSONLoader":
|
||||
loader = DocumentLoader(file_path_or_content, text_content=False)
|
||||
elif loader_name == "UnstructuredMarkdownLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
elif loader_name == "UnstructuredHTMLLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
else:
|
||||
loader = DocumentLoader(file_path_or_content)
|
||||
return loader
|
||||
|
||||
|
||||
def make_text_splitter(
|
||||
splitter_name: str = "SpacyTextSplitter",
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
):
|
||||
'''
|
||||
根据参数获取特定的分词器
|
||||
'''
|
||||
splitter_name = splitter_name or "SpacyTextSplitter"
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
try:
|
||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||
text_splitter = TextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查找分词器 {splitter_name} 时出错:{e}")
|
||||
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
return text_splitter
|
||||
|
||||
class KnowledgeFile:
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
knowledge_base_name: str
|
||||
):
|
||||
'''
|
||||
对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。
|
||||
'''
|
||||
self.kb_name = knowledge_base_name
|
||||
self.filename = filename
|
||||
self.ext = os.path.splitext(filename)[-1].lower()
|
||||
|
|
@ -186,76 +252,62 @@ class KnowledgeFile:
|
|||
raise ValueError(f"暂未支持的文件格式 {self.ext}")
|
||||
self.filepath = get_file_path(knowledge_base_name, filename)
|
||||
self.docs = None
|
||||
self.splited_docs = None
|
||||
self.document_loader_name = get_LoaderClass(self.ext)
|
||||
|
||||
# TODO: 增加依据文件格式匹配text_splitter
|
||||
self.text_splitter_name = None
|
||||
|
||||
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
|
||||
if self.docs is not None and not refresh:
|
||||
return self.docs
|
||||
def file2docs(self, refresh: bool=False):
|
||||
if self.docs is None or refresh:
|
||||
logger.info(f"{self.document_loader_name} used for {self.filepath}")
|
||||
loader = get_loader(self.document_loader_name, self.filepath)
|
||||
self.docs = loader.load()
|
||||
return self.docs
|
||||
|
||||
print(f"{self.document_loader_name} used for {self.filepath}")
|
||||
try:
|
||||
if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
|
||||
document_loaders_module = importlib.import_module('document_loaders')
|
||||
else:
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||
if self.document_loader_name == "UnstructuredFileLoader":
|
||||
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
||||
elif self.document_loader_name == "CSVLoader":
|
||||
loader = DocumentLoader(self.filepath, encoding="utf-8")
|
||||
elif self.document_loader_name == "JSONLoader":
|
||||
loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
|
||||
elif self.document_loader_name == "CustomJSONLoader":
|
||||
loader = DocumentLoader(self.filepath, text_content=False)
|
||||
elif self.document_loader_name == "UnstructuredMarkdownLoader":
|
||||
loader = DocumentLoader(self.filepath, mode="elements")
|
||||
elif self.document_loader_name == "UnstructuredHTMLLoader":
|
||||
loader = DocumentLoader(self.filepath, mode="elements")
|
||||
else:
|
||||
loader = DocumentLoader(self.filepath)
|
||||
def docs2texts(
|
||||
self,
|
||||
docs: List[Document] = None,
|
||||
using_zh_title_enhance=ZH_TITLE_ENHANCE,
|
||||
refresh: bool = False,
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
text_splitter: TextSplitter = None,
|
||||
):
|
||||
docs = docs or self.file2docs(refresh=refresh)
|
||||
if not docs:
|
||||
return []
|
||||
if self.ext not in [".csv"]:
|
||||
if text_splitter is None:
|
||||
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
docs = text_splitter.split_documents(docs)
|
||||
|
||||
if self.ext in ".csv":
|
||||
docs = loader.load()
|
||||
else:
|
||||
try:
|
||||
if self.text_splitter_name is None:
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
|
||||
text_splitter = TextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=CHUNK_SIZE,
|
||||
chunk_overlap=OVERLAP_SIZE,
|
||||
)
|
||||
self.text_splitter_name = "SpacyTextSplitter"
|
||||
else:
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=CHUNK_SIZE,
|
||||
chunk_overlap=OVERLAP_SIZE)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=CHUNK_SIZE,
|
||||
chunk_overlap=OVERLAP_SIZE,
|
||||
)
|
||||
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
|
||||
print(docs[0])
|
||||
print(f"文档切分示例:{docs[0]}")
|
||||
if using_zh_title_enhance:
|
||||
docs = zh_title_enhance(docs)
|
||||
self.docs = docs
|
||||
return docs
|
||||
self.splited_docs = docs
|
||||
return self.splited_docs
|
||||
|
||||
def file2text(
|
||||
self,
|
||||
using_zh_title_enhance=ZH_TITLE_ENHANCE,
|
||||
refresh: bool = False,
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
text_splitter: TextSplitter = None,
|
||||
):
|
||||
if self.splited_docs is None or refresh:
|
||||
docs = self.file2docs()
|
||||
self.splited_docs = self.docs2texts(docs=docs,
|
||||
using_zh_title_enhance=using_zh_title_enhance,
|
||||
refresh=refresh,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
text_splitter=text_splitter)
|
||||
return self.splited_docs
|
||||
|
||||
def file_exist(self):
|
||||
return os.path.isfile(self.filepath)
|
||||
|
||||
def get_mtime(self):
|
||||
return os.path.getmtime(self.filepath)
|
||||
|
|
@ -264,53 +316,47 @@ class KnowledgeFile:
|
|||
return os.path.getsize(self.filepath)
|
||||
|
||||
|
||||
def run_in_thread_pool(
|
||||
func: Callable,
|
||||
params: List[Dict] = [],
|
||||
pool: ThreadPoolExecutor = None,
|
||||
) -> Generator:
|
||||
'''
|
||||
在线程池中批量运行任务,并将运行结果以生成器的形式返回。
|
||||
请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
|
||||
'''
|
||||
tasks = []
|
||||
if pool is None:
|
||||
pool = ThreadPoolExecutor()
|
||||
|
||||
for kwargs in params:
|
||||
thread = pool.submit(func, **kwargs)
|
||||
tasks.append(thread)
|
||||
|
||||
for obj in as_completed(tasks):
|
||||
yield obj.result()
|
||||
|
||||
|
||||
def files2docs_in_thread(
|
||||
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
|
||||
pool: ThreadPoolExecutor = None,
|
||||
) -> Generator:
|
||||
'''
|
||||
利用多线程批量将文件转化成langchain Document.
|
||||
生成器返回值为{(kb_name, file_name): docs}
|
||||
利用多线程批量将磁盘文件转化成langchain Document.
|
||||
如果传入参数是Tuple,形式为(filename, kb_name)
|
||||
生成器返回值为 status, (kb_name, file_name, docs | error)
|
||||
'''
|
||||
def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]:
|
||||
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
|
||||
try:
|
||||
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
|
||||
except Exception as e:
|
||||
return False, e
|
||||
msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
|
||||
logger.error(msg)
|
||||
return False, (file.kb_name, file.filename, msg)
|
||||
|
||||
kwargs_list = []
|
||||
for i, file in enumerate(files):
|
||||
kwargs = {}
|
||||
if isinstance(file, tuple) and len(file) >= 2:
|
||||
files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
|
||||
file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
|
||||
elif isinstance(file, dict):
|
||||
filename = file.pop("filename")
|
||||
kb_name = file.pop("kb_name")
|
||||
files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
|
||||
kwargs = file
|
||||
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
|
||||
kwargs["file"] = file
|
||||
kwargs_list.append(kwargs)
|
||||
|
||||
for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
|
||||
|
||||
for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
|
||||
yield result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pprint import pprint
|
||||
|
||||
kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
|
||||
# kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
|
||||
docs = kb_file.file2docs()
|
||||
pprint(docs[-1])
|
||||
|
||||
docs = kb_file.file2text()
|
||||
pprint(docs[-1])
|
||||
|
|
|
|||
|
|
@ -8,7 +8,11 @@ from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDIN
|
|||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||
import os
|
||||
from server import model_workers
|
||||
from typing import Literal, Optional, Any
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Literal, Optional, Callable, Generator, Dict, Any
|
||||
|
||||
|
||||
thread_pool = ThreadPoolExecutor()
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
|
|
@ -305,3 +309,24 @@ def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "
|
|||
if device not in ["cuda", "mps", "cpu"]:
|
||||
device = detect_device()
|
||||
return device
|
||||
|
||||
|
||||
def run_in_thread_pool(
|
||||
func: Callable,
|
||||
params: List[Dict] = [],
|
||||
pool: ThreadPoolExecutor = None,
|
||||
) -> Generator:
|
||||
'''
|
||||
在线程池中批量运行任务,并将运行结果以生成器的形式返回。
|
||||
请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
|
||||
'''
|
||||
tasks = []
|
||||
pool = pool or thread_pool
|
||||
|
||||
for kwargs in params:
|
||||
thread = pool.submit(func, **kwargs)
|
||||
tasks.append(thread)
|
||||
|
||||
for obj in as_completed(tasks):
|
||||
yield obj.result()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,19 +7,23 @@ root_path = Path(__file__).parent.parent.parent
|
|||
sys.path.append(str(root_path))
|
||||
from server.utils import api_address
|
||||
from configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
from server.knowledge_base.utils import get_kb_path
|
||||
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
test_files = {
|
||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
|
||||
"README.MD": str(root_path / "README.MD"),
|
||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD")
|
||||
"test.txt": get_file_path("samples", "test.txt"),
|
||||
}
|
||||
|
||||
print("\n\n直接url访问\n")
|
||||
|
||||
|
||||
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
|
||||
if not Path(get_kb_path(kb)).exists():
|
||||
|
|
@ -78,37 +82,36 @@ def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
|
|||
assert kb in data["data"]
|
||||
|
||||
|
||||
def test_upload_doc(api="/knowledge_base/upload_doc"):
|
||||
def test_upload_docs(api="/knowledge_base/upload_docs"):
|
||||
url = api_base_url + api
|
||||
for name, path in test_files.items():
|
||||
print(f"\n上传知识文件: {name}")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
files = {"file": (name, open(path, "rb"))}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"成功上传文件 {name}"
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
|
||||
for name, path in test_files.items():
|
||||
print(f"\n尝试重新上传知识文件: {name}, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
files = {"file": (name, open(path, "rb"))}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"文件 {name} 已存在。"
|
||||
print(f"\n上传知识文件")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
for name, path in test_files.items():
|
||||
print(f"\n尝试重新上传知识文件: {name}, 覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
files = {"file": (name, open(path, "rb"))}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"成功上传文件 {name}"
|
||||
print(f"\n尝试重新上传知识文件, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == len(test_files)
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
|
||||
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
|
||||
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_list_files(api="/knowledge_base/list_files"):
|
||||
|
|
@ -134,26 +137,26 @@ def test_search_docs(api="/knowledge_base/search_docs"):
|
|||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_update_doc(api="/knowledge_base/update_doc"):
|
||||
def test_update_docs(api="/knowledge_base/update_docs"):
|
||||
url = api_base_url + api
|
||||
for name, path in test_files.items():
|
||||
print(f"\n更新知识文件: {name}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"成功更新文件 {name}"
|
||||
|
||||
print(f"\n更新知识文件")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_delete_doc(api="/knowledge_base/delete_doc"):
|
||||
def test_delete_docs(api="/knowledge_base/delete_docs"):
|
||||
url = api_base_url + api
|
||||
for name, path in test_files.items():
|
||||
print(f"\n删除知识文件: {name}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"{name} 文件删除成功"
|
||||
|
||||
print(f"\n删除知识文件")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
url = api_base_url + "/knowledge_base/search_docs"
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,161 @@
|
|||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from server.utils import api_address
|
||||
from configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||
from webui_pages.utils import ApiRequest
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
api: ApiRequest = ApiRequest(api_base_url)
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
test_files = {
|
||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
|
||||
"README.MD": str(root_path / "README.MD"),
|
||||
"test.txt": get_file_path("samples", "test.txt"),
|
||||
}
|
||||
|
||||
print("\n\nApiRquest调用\n")
|
||||
|
||||
|
||||
def test_delete_kb_before():
|
||||
if not Path(get_kb_path(kb)).exists():
|
||||
return
|
||||
|
||||
data = api.delete_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb not in data["data"]
|
||||
|
||||
|
||||
def test_create_kb():
|
||||
print(f"\n尝试用空名称创建知识库:")
|
||||
data = api.create_knowledge_base(" ")
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
|
||||
|
||||
print(f"\n创建新知识库: {kb}")
|
||||
data = api.create_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"已新增知识库 {kb}"
|
||||
|
||||
print(f"\n尝试创建同名知识库: {kb}")
|
||||
data = api.create_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"已存在同名知识库 {kb}"
|
||||
|
||||
|
||||
def test_list_kbs():
|
||||
data = api.list_knowledge_bases()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) > 0
|
||||
assert kb in data
|
||||
|
||||
|
||||
def test_upload_docs():
|
||||
files = list(test_files.values())
|
||||
|
||||
print(f"\n上传知识文件")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == len(test_files)
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
|
||||
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
|
||||
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_list_files():
|
||||
print("\n获取知识库中文件列表:")
|
||||
data = api.list_kb_docs(knowledge_base_name=kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list)
|
||||
for name in test_files:
|
||||
assert name in data
|
||||
|
||||
|
||||
def test_search_docs():
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_update_docs():
|
||||
print(f"\n更新知识文件")
|
||||
data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_delete_docs():
|
||||
print(f"\n删除知识文件")
|
||||
data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n尝试检索删除后的检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == 0
|
||||
|
||||
|
||||
def test_recreate_vs():
|
||||
print("\n重建知识库:")
|
||||
r = api.recreate_vector_store(kb)
|
||||
for data in r:
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data["msg"])
|
||||
|
||||
query = "本项目支持哪些文件格式?"
|
||||
print("\n尝试检索重建后的检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_delete_kb_after():
|
||||
print("\n删除知识库")
|
||||
data = api.delete_knowledge_base(kb)
|
||||
pprint(data)
|
||||
|
||||
# check kb not exists anymore
|
||||
print("\n获取知识库列表:")
|
||||
data = api.list_knowledge_bases()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) > 0
|
||||
assert kb not in data
|
||||
|
|
@ -138,14 +138,11 @@ def knowledge_base_page(api: ApiRequest):
|
|||
# use_container_width=True,
|
||||
disabled=len(files) == 0,
|
||||
):
|
||||
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
|
||||
data[-1]["not_refresh_vs_cache"]=False
|
||||
for k in data:
|
||||
ret = api.upload_kb_doc(**k)
|
||||
if msg := check_success_msg(ret):
|
||||
st.toast(msg, icon="✔")
|
||||
elif msg := check_error_msg(ret):
|
||||
st.toast(msg, icon="✖")
|
||||
ret = api.upload_kb_docs(files, knowledge_base_name=kb, override=True)
|
||||
if msg := check_success_msg(ret):
|
||||
st.toast(msg, icon="✔")
|
||||
elif msg := check_error_msg(ret):
|
||||
st.toast(msg, icon="✖")
|
||||
st.session_state.files = []
|
||||
|
||||
st.divider()
|
||||
|
|
@ -218,8 +215,8 @@ def knowledge_base_page(api: ApiRequest):
|
|||
disabled=not file_exists(kb, selected_rows)[0],
|
||||
use_container_width=True,
|
||||
):
|
||||
for row in selected_rows:
|
||||
api.update_kb_doc(kb, row["file_name"])
|
||||
file_names = [row["file_name"] for row in selected_rows]
|
||||
api.update_kb_docs(kb, file_names=file_names)
|
||||
st.experimental_rerun()
|
||||
|
||||
# 将文件从向量库中删除,但不删除文件本身。
|
||||
|
|
@ -228,8 +225,8 @@ def knowledge_base_page(api: ApiRequest):
|
|||
disabled=not (selected_rows and selected_rows[0]["in_db"]),
|
||||
use_container_width=True,
|
||||
):
|
||||
for row in selected_rows:
|
||||
api.delete_kb_doc(kb, row["file_name"])
|
||||
file_names = [row["file_name"] for row in selected_rows]
|
||||
api.delete_kb_docs(kb, file_names=file_names)
|
||||
st.experimental_rerun()
|
||||
|
||||
if cols[3].button(
|
||||
|
|
@ -237,9 +234,8 @@ def knowledge_base_page(api: ApiRequest):
|
|||
type="primary",
|
||||
use_container_width=True,
|
||||
):
|
||||
for row in selected_rows:
|
||||
ret = api.delete_kb_doc(kb, row["file_name"], True)
|
||||
st.toast(ret.get("msg", " "))
|
||||
file_names = [row["file_name"] for row in selected_rows]
|
||||
api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
|
||||
st.experimental_rerun()
|
||||
|
||||
st.divider()
|
||||
|
|
|
|||
|
|
@ -21,9 +21,7 @@ from fastapi.responses import StreamingResponse
|
|||
import contextlib
|
||||
import json
|
||||
from io import BytesIO
|
||||
from server.db.repository.knowledge_base_repository import get_kb_detail
|
||||
from server.db.repository.knowledge_file_repository import get_file_detail
|
||||
from server.utils import run_async, iter_over_async, set_httpx_timeout
|
||||
from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
|
||||
|
||||
from configs.model_config import NLTK_DATA_PATH
|
||||
import nltk
|
||||
|
|
@ -43,7 +41,7 @@ class ApiRequest:
|
|||
'''
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://127.0.0.1:7861",
|
||||
base_url: str = api_address(),
|
||||
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
||||
no_remote_api: bool = False, # call api view function directly
|
||||
):
|
||||
|
|
@ -78,7 +76,7 @@ class ApiRequest:
|
|||
else:
|
||||
return httpx.get(url, params=params, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when get {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
async def aget(
|
||||
|
|
@ -99,7 +97,7 @@ class ApiRequest:
|
|||
else:
|
||||
return await client.get(url, params=params, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when aget {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
def post(
|
||||
|
|
@ -121,7 +119,7 @@ class ApiRequest:
|
|||
else:
|
||||
return httpx.post(url, data=data, json=json, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when post {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
async def apost(
|
||||
|
|
@ -143,7 +141,7 @@ class ApiRequest:
|
|||
else:
|
||||
return await client.post(url, data=data, json=json, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when apost {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
def delete(
|
||||
|
|
@ -164,7 +162,7 @@ class ApiRequest:
|
|||
else:
|
||||
return httpx.delete(url, data=data, json=json, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when delete {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
async def adelete(
|
||||
|
|
@ -186,7 +184,7 @@ class ApiRequest:
|
|||
else:
|
||||
return await client.delete(url, data=data, json=json, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when adelete {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
|
||||
|
|
@ -205,7 +203,7 @@ class ApiRequest:
|
|||
elif chunk.strip():
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when run fastapi router: {e}")
|
||||
|
||||
def _httpx_stream2generator(
|
||||
self,
|
||||
|
|
@ -231,18 +229,18 @@ class ApiRequest:
|
|||
print(chunk, end="", flush=True)
|
||||
yield chunk
|
||||
except httpx.ConnectError as e:
|
||||
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
|
||||
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})"
|
||||
logger.error(msg)
|
||||
logger.error(msg)
|
||||
logger.error(e)
|
||||
yield {"code": 500, "msg": msg}
|
||||
except httpx.ReadTimeout as e:
|
||||
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')"
|
||||
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')。({e})"
|
||||
logger.error(msg)
|
||||
logger.error(e)
|
||||
yield {"code": 500, "msg": msg}
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
yield {"code": 500, "msg": str(e)}
|
||||
msg = f"API通信遇到错误:{e}"
|
||||
logger.error(msg)
|
||||
yield {"code": 500, "msg": msg}
|
||||
|
||||
# 对话相关操作
|
||||
|
||||
|
|
@ -413,8 +411,9 @@ class ApiRequest:
|
|||
try:
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return {"code": 500, "msg": errorMsg or str(e)}
|
||||
msg = "API未能返回正确的JSON。" + (errorMsg or str(e))
|
||||
logger.error(msg)
|
||||
return {"code": 500, "msg": msg}
|
||||
|
||||
def list_knowledge_bases(
|
||||
self,
|
||||
|
|
@ -510,12 +509,45 @@ class ApiRequest:
|
|||
data = self._check_httpx_json_response(response)
|
||||
return data.get("data", [])
|
||||
|
||||
def upload_kb_doc(
|
||||
def search_kb_docs(
|
||||
self,
|
||||
file: Union[str, Path, bytes],
|
||||
query: str,
|
||||
knowledge_base_name: str,
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
score_threshold: int = SCORE_THRESHOLD,
|
||||
no_remote_api: bool = None,
|
||||
) -> List:
|
||||
'''
|
||||
对应api.py/knowledge_base/search_docs接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
return search_docs(**data)
|
||||
else:
|
||||
response = self.post(
|
||||
"/knowledge_base/search_docs",
|
||||
json=data,
|
||||
)
|
||||
data = self._check_httpx_json_response(response)
|
||||
return data
|
||||
|
||||
def upload_kb_docs(
|
||||
self,
|
||||
files: List[Union[str, Path, bytes]],
|
||||
knowledge_base_name: str,
|
||||
filename: str = None,
|
||||
override: bool = False,
|
||||
to_vector_store: bool = True,
|
||||
docs: Dict = {},
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
|
|
@ -525,97 +557,113 @@ class ApiRequest:
|
|||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
if isinstance(file, bytes): # raw bytes
|
||||
file = BytesIO(file)
|
||||
elif hasattr(file, "read"): # a file io like object
|
||||
filename = filename or file.name
|
||||
else: # a local path
|
||||
file = Path(file).absolute().open("rb")
|
||||
filename = filename or file.name
|
||||
def convert_file(file, filename=None):
|
||||
if isinstance(file, bytes): # raw bytes
|
||||
file = BytesIO(file)
|
||||
elif hasattr(file, "read"): # a file io like object
|
||||
filename = filename or file.name
|
||||
else: # a local path
|
||||
file = Path(file).absolute().open("rb")
|
||||
filename = filename or file.name
|
||||
return filename, file
|
||||
|
||||
files = [convert_file(file) for file in files]
|
||||
data={
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"override": override,
|
||||
"to_vector_store": to_vector_store,
|
||||
"docs": docs,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import upload_doc
|
||||
from server.knowledge_base.kb_doc_api import upload_docs
|
||||
from fastapi import UploadFile
|
||||
from tempfile import SpooledTemporaryFile
|
||||
|
||||
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
||||
temp_file.write(file.read())
|
||||
temp_file.seek(0)
|
||||
response = run_async(upload_doc(
|
||||
UploadFile(file=temp_file, filename=filename),
|
||||
knowledge_base_name,
|
||||
override,
|
||||
))
|
||||
upload_files = []
|
||||
for file, filename in files:
|
||||
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
||||
temp_file.write(file.read())
|
||||
temp_file.seek(0)
|
||||
upload_files.append(UploadFile(file=temp_file, filename=filename))
|
||||
|
||||
response = run_async(upload_docs(upload_files, **data))
|
||||
return response.dict()
|
||||
else:
|
||||
if isinstance(data["docs"], dict):
|
||||
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
|
||||
response = self.post(
|
||||
"/knowledge_base/upload_doc",
|
||||
data={
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"override": override,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
},
|
||||
files={"file": (filename, file)},
|
||||
"/knowledge_base/upload_docs",
|
||||
data=data,
|
||||
files=[("files", (filename, file)) for filename, file in files],
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def delete_kb_doc(
|
||||
def delete_kb_docs(
|
||||
self,
|
||||
knowledge_base_name: str,
|
||||
doc_name: str,
|
||||
file_names: List[str],
|
||||
delete_content: bool = False,
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/delete_doc接口
|
||||
对应api.py/knowledge_base/delete_docs接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"doc_name": doc_name,
|
||||
"file_names": file_names,
|
||||
"delete_content": delete_content,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import delete_doc
|
||||
response = run_async(delete_doc(**data))
|
||||
from server.knowledge_base.kb_doc_api import delete_docs
|
||||
response = run_async(delete_docs(**data))
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.post(
|
||||
"/knowledge_base/delete_doc",
|
||||
"/knowledge_base/delete_docs",
|
||||
json=data,
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def update_kb_doc(
|
||||
def update_kb_docs(
|
||||
self,
|
||||
knowledge_base_name: str,
|
||||
file_name: str,
|
||||
file_names: List[str],
|
||||
override_custom_docs: bool = False,
|
||||
docs: Dict = {},
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/update_doc接口
|
||||
对应api.py/knowledge_base/update_docs接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"file_names": file_names,
|
||||
"override_custom_docs": override_custom_docs,
|
||||
"docs": docs,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
}
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import update_doc
|
||||
response = run_async(update_doc(knowledge_base_name, file_name))
|
||||
from server.knowledge_base.kb_doc_api import update_docs
|
||||
response = run_async(update_docs(**data))
|
||||
return response.dict()
|
||||
else:
|
||||
if isinstance(data["docs"], dict):
|
||||
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
|
||||
response = self.post(
|
||||
"/knowledge_base/update_doc",
|
||||
json={
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"file_name": file_name,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
},
|
||||
"/knowledge_base/update_docs",
|
||||
json=data,
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue