Langchain-Chatchat/libs/chatchat-server/chatchat/server/api_server/kb_routes.py

140 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
from typing import List, Literal
from fastapi import APIRouter, Request
from chatchat.settings import Settings
from chatchat.server.api_server.api_schemas import OpenAIChatInput, OpenAIChatOutput
from chatchat.server.chat.file_chat import upload_temp_docs
from chatchat.server.chat.kb_chat import kb_chat
from chatchat.server.knowledge_base.kb_api import create_kb, delete_kb, list_kbs
from chatchat.server.knowledge_base.kb_doc_api import (
delete_docs,
download_doc,
list_files,
recreate_vector_store,
search_docs,
update_docs,
update_info,
upload_docs,
search_temp_docs,
)
from chatchat.server.knowledge_base.kb_summary_api import (
recreate_summary_vector_store,
summary_doc_ids_to_vector_store,
summary_file_to_vector_store,
)
from chatchat.server.utils import BaseResponse, ListResponse
from chatchat.server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
# Router that collects every knowledge-base endpoint declared in this module.
# All routes registered on it are mounted under the "/knowledge_base" prefix
# and carry the "Knowledge Base Management" tag.
kb_router = APIRouter(
    prefix="/knowledge_base",
    tags=["Knowledge Base Management"],
)
# OpenAI-style chat-completions endpoint that forwards the request to
# `kb_chat`. The path carries the retrieval `mode` and a free-form `param`
# that is handed to `kb_chat` as `kb_name`.
#
# NOTE: this endpoint is documented with comments instead of a docstring on
# purpose: FastAPI publishes a handler's docstring as the route description.
@kb_router.post(
    "/{mode}/{param}/chat/completions",
    summary="知识库对话openai 兼容,参数与 /chat/kb_chat 一致",
)
async def kb_chat_endpoint(
    mode: Literal["local_kb", "temp_kb", "search_engine"],
    param: str,
    body: OpenAIChatInput,
    request: Request,
):
    # A missing or zero token budget falls back to the configured default.
    if body.max_tokens in (None, 0):
        body.max_tokens = Settings.model_settings.MAX_TOKENS

    # Fields of the request body beyond the declared schema; they carry the
    # knowledge-base specific options read below.
    extra_options = body.model_extra

    # The last message is the current query; everything before it is history.
    conversation = body.messages
    current_query = conversation[-1]["content"]
    top_k = extra_options.get("top_k", Settings.kb_settings.VECTOR_SEARCH_TOP_K)
    score_threshold = extra_options.get(
        "score_threshold", Settings.kb_settings.SCORE_THRESHOLD
    )
    previous_turns = conversation[:-1]

    return await kb_chat(
        query=current_query,
        mode=mode,
        kb_name=param,
        top_k=top_k,
        score_threshold=score_threshold,
        history=previous_turns,
        stream=body.stream,
        model=body.model,
        temperature=body.temperature,
        max_tokens=body.max_tokens,
        prompt_name=extra_options.get("prompt_name", "default"),
        return_direct=extra_options.get("return_direct", False),
        request=request,
    )
# ---------------------------------------------------------------------------
# Route table for `kb_router`.
# Each statement builds a path-operation decorator and applies it immediately
# to a handler imported from the knowledge-base API modules.
# ---------------------------------------------------------------------------

# Knowledge base lifecycle: list / create / delete.
kb_router.get(
    "/list_knowledge_bases",
    summary="获取知识库列表",
    response_model=ListResponse,
)(list_kbs)

kb_router.post(
    "/create_knowledge_base",
    summary="创建知识库",
    response_model=BaseResponse,
)(create_kb)

kb_router.post(
    "/delete_knowledge_base",
    summary="删除知识库",
    response_model=BaseResponse,
)(delete_kb)

# Listing files and searching documents of a knowledge base.
kb_router.get(
    "/list_files",
    summary="获取知识库内的文件列表",
    response_model=ListResponse,
)(list_files)

kb_router.post(
    "/search_docs",
    summary="搜索知识库",
    response_model=List[dict],
)(search_docs)

# Document management: upload / delete / update info / update docs / download.
kb_router.post(
    "/upload_docs",
    summary="上传文件到知识库,并/或进行向量化",
    response_model=BaseResponse,
)(upload_docs)

kb_router.post(
    "/delete_docs",
    summary="删除知识库内指定文件",
    response_model=BaseResponse,
)(delete_docs)

kb_router.post(
    "/update_info",
    summary="更新知识库介绍",
    response_model=BaseResponse,
)(update_info)

kb_router.post(
    "/update_docs",
    summary="更新现有文件到知识库",
    response_model=BaseResponse,
)(update_docs)

kb_router.get(
    "/download_doc",
    summary="下载对应的知识文件",
)(download_doc)

# Vector store rebuild (no response_model is declared for this route).
kb_router.post(
    "/recreate_vector_store",
    summary="根据content中文档重建向量库流式输出处理进度。",
)(recreate_vector_store)

# Temporary knowledge bases used for file chat: upload and search.
kb_router.post(
    "/upload_temp_docs",
    summary="上传文件到临时目录,用于文件对话。",
)(upload_temp_docs)

kb_router.post(
    "/search_temp_docs",
    summary="检索临时知识库",
)(search_temp_docs)
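# Disabled endpoint, kept for reference: it would return the keys of
# `memo_faiss_pool` (summary text: "list all temporary knowledge bases").
# @kb_router.post("/list_temp_kbs", summary="列出所有临时知识库")
# def list_temp_kbs():
#     return list(memo_faiss_pool.keys())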
# Sub-router for the knowledge-base summary endpoints. It has its own
# "/kb_summary_api" prefix and is attached to `kb_router` at the end of this
# section.
summary_router = APIRouter(prefix="/kb_summary_api")

# Summarize by file name within a single knowledge base.
summary_router.post(
    "/summary_file_to_vector_store",
    summary="单个知识库根据文件名称摘要",
)(summary_file_to_vector_store)

# Summarize by doc_ids within a single knowledge base.
summary_router.post(
    "/summary_doc_ids_to_vector_store",
    response_model=BaseResponse,
    summary="单个知识库根据doc_ids摘要",
)(summary_doc_ids_to_vector_store)

# Rebuild the file summaries of a single knowledge base.
summary_router.post(
    "/recreate_summary_vector_store",
    summary="重建单个知识库文件摘要",
)(recreate_summary_vector_store)

# Attach the summary routes to the main router. This stays last so that every
# route registered above is already present on `summary_router`.
kb_router.include_router(summary_router)