140 lines
4.2 KiB
Python
140 lines
4.2 KiB
Python
from __future__ import annotations
|
||
|
||
from typing import List, Literal
|
||
|
||
from fastapi import APIRouter, Request
|
||
|
||
from chatchat.settings import Settings
|
||
from chatchat.server.api_server.api_schemas import OpenAIChatInput, OpenAIChatOutput
|
||
from chatchat.server.chat.file_chat import upload_temp_docs
|
||
from chatchat.server.chat.kb_chat import kb_chat
|
||
from chatchat.server.knowledge_base.kb_api import create_kb, delete_kb, list_kbs
|
||
from chatchat.server.knowledge_base.kb_doc_api import (
|
||
delete_docs,
|
||
download_doc,
|
||
list_files,
|
||
recreate_vector_store,
|
||
search_docs,
|
||
update_docs,
|
||
update_info,
|
||
upload_docs,
|
||
search_temp_docs,
|
||
)
|
||
from chatchat.server.knowledge_base.kb_summary_api import (
|
||
recreate_summary_vector_store,
|
||
summary_doc_ids_to_vector_store,
|
||
summary_file_to_vector_store,
|
||
)
|
||
from chatchat.server.utils import BaseResponse, ListResponse
|
||
from chatchat.server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
|
||
|
||
|
||
# Router collecting every knowledge-base endpoint under the
# "/knowledge_base" URL prefix.
kb_router = APIRouter(
    prefix="/knowledge_base",
    tags=["Knowledge Base Management"],
)
|
||
|
||
|
||
@kb_router.post(
    "/{mode}/{param}/chat/completions",
    summary="知识库对话,openai 兼容,参数与 /chat/kb_chat 一致",
)
async def kb_chat_endpoint(
    mode: Literal["local_kb", "temp_kb", "search_engine"],
    param: str,
    body: OpenAIChatInput,
    request: Request,
):
    # OpenAI-compatible chat-completions route that delegates to ``kb_chat``.
    #
    # mode    -- retrieval mode taken from the URL path, forwarded unchanged.
    # param   -- second path segment, forwarded to ``kb_chat`` as ``kb_name``.
    # body    -- request payload; the last message supplies the query and the
    #            earlier messages are forwarded as history.
    # request -- the incoming request object, forwarded to ``kb_chat``.
    #
    # Returns whatever ``kb_chat`` returns.
    #
    # NOTE: documented with comments rather than a docstring on purpose;
    # FastAPI publishes an endpoint's docstring as its OpenAPI description.

    # An unset or zero token limit means "use the configured default".
    if body.max_tokens in (None, 0):
        body.max_tokens = Settings.model_settings.MAX_TOKENS

    # ``model_extra`` is None when the model does not permit extra fields;
    # fall back to an empty mapping so the ``.get`` lookups below cannot
    # fail with AttributeError.
    extra = body.model_extra or {}

    return await kb_chat(
        query=body.messages[-1]["content"],
        mode=mode,
        kb_name=param,
        top_k=extra.get("top_k", Settings.kb_settings.VECTOR_SEARCH_TOP_K),
        score_threshold=extra.get(
            "score_threshold", Settings.kb_settings.SCORE_THRESHOLD
        ),
        history=body.messages[:-1],
        stream=body.stream,
        model=body.model,
        temperature=body.temperature,
        max_tokens=body.max_tokens,
        prompt_name=extra.get("prompt_name", "default"),
        return_direct=extra.get("return_direct", False),
        request=request,
    )
|
||
|
||
|
||
def _register_kb_routes(router: APIRouter) -> None:
    """Attach the knowledge-base handlers to *router*.

    Every row of the table is ``(verb, path, handler, options)``. Rows are
    registered in table order by calling ``router.get`` / ``router.post``
    with the path and options, then applying the result to the handler.
    """
    route_table = (
        # Knowledge-base level operations.
        ("get", "/list_knowledge_bases", list_kbs,
         dict(response_model=ListResponse, summary="获取知识库列表")),
        ("post", "/create_knowledge_base", create_kb,
         dict(response_model=BaseResponse, summary="创建知识库")),
        ("post", "/delete_knowledge_base", delete_kb,
         dict(response_model=BaseResponse, summary="删除知识库")),
        # Document level operations.
        ("get", "/list_files", list_files,
         dict(response_model=ListResponse, summary="获取知识库内的文件列表")),
        ("post", "/search_docs", search_docs,
         dict(response_model=List[dict], summary="搜索知识库")),
        ("post", "/upload_docs", upload_docs,
         dict(response_model=BaseResponse, summary="上传文件到知识库,并/或进行向量化")),
        ("post", "/delete_docs", delete_docs,
         dict(response_model=BaseResponse, summary="删除知识库内指定文件")),
        ("post", "/update_info", update_info,
         dict(response_model=BaseResponse, summary="更新知识库介绍")),
        ("post", "/update_docs", update_docs,
         dict(response_model=BaseResponse, summary="更新现有文件到知识库")),
        ("get", "/download_doc", download_doc,
         dict(summary="下载对应的知识文件")),
        ("post", "/recreate_vector_store", recreate_vector_store,
         dict(summary="根据content中文档重建向量库,流式输出处理进度。")),
        # Temporary-document operations.
        ("post", "/upload_temp_docs", upload_temp_docs,
         dict(summary="上传文件到临时目录,用于文件对话。")),
        ("post", "/search_temp_docs", search_temp_docs,
         dict(summary="检索临时知识库")),
    )
    for verb, path, handler, options in route_table:
        getattr(router, verb)(path, **options)(handler)


_register_kb_routes(kb_router)
|
||
|
||
# Disabled endpoint, kept for reference: lists every temporary knowledge base.
# @kb_router.post("/list_temp_kbs", summary="列出所有临时知识库")
# def list_temp_kbs():
#     return list(memo_faiss_pool.keys())
|
||
|
||
|
||
def _build_summary_router() -> APIRouter:
    """Create the "/kb_summary_api" sub-router and register its POST routes."""
    router = APIRouter(prefix="/kb_summary_api")

    router.post(
        "/summary_file_to_vector_store",
        summary="单个知识库根据文件名称摘要",
    )(summary_file_to_vector_store)

    router.post(
        "/summary_doc_ids_to_vector_store",
        summary="单个知识库根据doc_ids摘要",
        response_model=BaseResponse,
    )(summary_doc_ids_to_vector_store)

    router.post(
        "/recreate_summary_vector_store",
        summary="重建单个知识库文件摘要",
    )(recreate_summary_vector_store)

    return router


summary_router = _build_summary_router()

# Mount the summary routes beneath the knowledge-base router.
kb_router.include_router(summary_router)
|