From cc08e2cb96ee78912a2045408799557036fb1ffe Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Mon, 14 Aug 2023 11:46:36 +0800 Subject: [PATCH] update api and webui: 1. add download_doc to api 2. return local path or http url in knowledge_base_chat depending on no_remote_api 3. change assistant avatar in webui --- server/api.py | 6 +++++- server/chat/knowledge_base_chat.py | 19 ++++++++++++----- server/knowledge_base/kb_doc_api.py | 33 ++++++++++++++++++++++++----- webui.py | 1 - webui_pages/dialogue/dialogue.py | 1 - webui_pages/utils.py | 12 +++++++++-- 6 files changed, 57 insertions(+), 15 deletions(-) diff --git a/server/api.py b/server/api.py index 873b887..458b1d7 100644 --- a/server/api.py +++ b/server/api.py @@ -14,7 +14,7 @@ from server.chat import (chat, knowledge_base_chat, openai_chat, search_engine_chat) from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc, - update_doc, recreate_vector_store) + update_doc, download_doc, recreate_vector_store) from server.utils import BaseResponse, ListResponse nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -101,6 +101,10 @@ def create_app(): summary="更新现有文件到知识库" )(update_doc) + app.get("/knowledge_base/download_doc", + tags=["Knowledge Base Management"], + summary="下载对应的知识文件")(download_doc) + app.post("/knowledge_base/recreate_vector_store", tags=["Knowledge Base Management"], summary="根据content中文档重建向量库,流式输出处理进度。" diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index d6dafe2..0ecabf0 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -1,4 +1,4 @@ -from fastapi import Body +from fastapi import Body, Request from fastapi.responses import StreamingResponse from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE, VECTOR_SEARCH_TOP_K) @@ -15,6 +15,7 @@ from server.chat.utils import History from 
server.knowledge_base.kb_service.base import KBService, KBServiceFactory import json import os +from urllib.parse import urlencode def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), @@ -29,6 +30,8 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp "content": "虎头虎脑"}]] ), stream: bool = Body(False, description="流式输出"), + local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"), + request: Request = None, ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: @@ -64,10 +67,16 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp callback.done), ) - source_documents = [ - f"""出处 [{inum + 1}] [{os.path.split(doc.metadata["source"])[-1]}] \n\n{doc.page_content}\n\n""" - for inum, doc in enumerate(docs) - ] + source_documents = [] + for inum, doc in enumerate(docs): + filename = os.path.split(doc.metadata["source"])[-1] + if local_doc_url: + url = "file://" + doc.metadata["source"] + else: + parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename}) + url = f"{request.base_url}knowledge_base/download_doc?" 
+ parameters + text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" + source_documents.append(text) if stream: async for token in callback.aiter(): diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index b467d96..3f27fb1 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -1,10 +1,10 @@ import os import urllib -from fastapi import File, Form, Body, UploadFile +from fastapi import File, Form, Body, Query, UploadFile from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, FileResponse import json from server.knowledge_base.kb_service.base import KBServiceFactory from typing import List @@ -104,9 +104,32 @@ async def update_doc( return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败") -async def download_doc(): - # TODO: 下载文件 - pass +async def download_doc( + knowledge_base_name: str = Query(..., examples=["samples"]), + file_name: str = Query(..., examples=["test.txt"]), + ): + ''' + 下载知识库文档 + ''' + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + + kb_file = KnowledgeFile(filename=file_name, + knowledge_base_name=knowledge_base_name) + + if os.path.exists(kb_file.filepath): + return FileResponse( + path=kb_file.filepath, + filename=kb_file.filename, + media_type="multipart/form-data") + else: + return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") + + async def recreate_vector_store( diff --git a/webui.py b/webui.py index 703108c..d84da42 100644 --- a/webui.py +++ b/webui.py @@ -35,7 +35,6 @@ 
if __name__ == "__main__": with st.sidebar: st.image( os.path.join( - os.path.dirname(__file__), "img", "logo-long-chatchat-trans-v2.png" ), diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 6a806bd..a39a196 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -8,7 +8,6 @@ import os chat_box = ChatBox( assistant_avatar=os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "img", "chatchat_icon_blue_square_v2.png" ) diff --git a/webui_pages/utils.py b/webui_pages/utils.py index e64403c..d2772d5 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -259,6 +259,7 @@ class ApiRequest: self, query: str, history: List[Dict] = [], + stream: bool = True, no_remote_api: bool = None, ): ''' @@ -267,12 +268,18 @@ class ApiRequest: if no_remote_api is None: no_remote_api = self.no_remote_api + data = { + "query": query, + "history": history, + "stream": stream, + } + if no_remote_api: from server.chat.chat import chat - response = chat(query, history) + response = chat(**data) return self._fastapi_stream2generator(response) else: - response = self.post("/chat/chat", json={"query": query, "history": history}, stream=True) + response = self.post("/chat/chat", json=data, stream=True) return self._httpx_stream2generator(response) def knowledge_base_chat( @@ -296,6 +303,7 @@ class ApiRequest: "top_k": top_k, "history": history, "stream": stream, + "local_doc_url": no_remote_api, } if no_remote_api: