update api and webui:

1. add download_doc to api
2. return local path or HTTP URL in knowledge_base_chat depending on
   no_remote_api
3. change assistant avatar in webui
liunux4odoo 2023-08-14 11:46:36 +08:00
parent 0d6a9cf8f3
commit cc08e2cb96
6 changed files with 57 additions and 15 deletions

View File

@@ -14,7 +14,7 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
-                                              update_doc, recreate_vector_store)
+                                              update_doc, download_doc, recreate_vector_store)
 from server.utils import BaseResponse, ListResponse

 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -101,6 +101,10 @@ def create_app():
              summary="更新现有文件到知识库"
              )(update_doc)

+    app.get("/knowledge_base/download_doc",
+            tags=["Knowledge Base Management"],
+            summary="下载对应的知识文件")(download_doc)
+
     app.post("/knowledge_base/recreate_vector_store",
              tags=["Knowledge Base Management"],
              summary="根据content中文档重建向量库流式输出处理进度。"

View File

@@ -1,4 +1,4 @@
-from fastapi import Body
+from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
                                   VECTOR_SEARCH_TOP_K)
@@ -15,6 +15,7 @@ from server.chat.utils import History
 from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
 import json
 import os
+from urllib.parse import urlencode


 def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
@@ -29,6 +30,8 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                                                       "content": "虎头虎脑"}]]
                                              ),
                        stream: bool = Body(False, description="流式输出"),
+                       local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
+                       request: Request = None,
                        ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
@@ -64,10 +67,16 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                      callback.done),
         )

-        source_documents = [
-            f"""出处 [{inum + 1}] [{os.path.split(doc.metadata["source"])[-1]}] \n\n{doc.page_content}\n\n"""
-            for inum, doc in enumerate(docs)
-        ]
+        source_documents = []
+        for inum, doc in enumerate(docs):
+            filename = os.path.split(doc.metadata["source"])[-1]
+            if local_doc_url:
+                url = "file://" + doc.metadata["source"]
+            else:
+                parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
+                url = f"{request.base_url}knowledge_base/download_doc?" + parameters
+            text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
+            source_documents.append(text)

         if stream:
             async for token in callback.aiter():
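The link for each source document is assembled from Starlette's request.base_url (which ends with a trailing slash, so the route path is appended without one) plus the URL-encoded query string. A small sketch of the resulting URL, with an assumed host and file name:

from urllib.parse import urlencode

base_url = "http://127.0.0.1:7861/"  # what request.base_url would yield here (assumption)
parameters = urlencode({"knowledge_base_name": "samples", "file_name": "test.txt"})
url = f"{base_url}knowledge_base/download_doc?" + parameters
# -> http://127.0.0.1:7861/knowledge_base/download_doc?knowledge_base_name=samples&file_name=test.txt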

View File

@@ -1,10 +1,10 @@
 import os
 import urllib
-from fastapi import File, Form, Body, UploadFile
+from fastapi import File, Form, Body, Query, UploadFile
 from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL
 from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, FileResponse
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from typing import List
@@ -104,9 +104,32 @@ async def update_doc(
         return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")


-async def download_doc():
-    # TODO: 下载文件
-    pass
+async def download_doc(
+    knowledge_base_name: str = Query(..., examples=["samples"]),
+    file_name: str = Query(..., examples=["test.txt"]),
+):
+    '''
+    下载知识库文档
+    '''
+    if not validate_kb_name(knowledge_base_name):
+        return BaseResponse(code=403, msg="Don't attack me")
+
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
+        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
+
+    kb_file = KnowledgeFile(filename=file_name,
+                            knowledge_base_name=knowledge_base_name)
+
+    if os.path.exists(kb_file.filepath):
+        return FileResponse(
+            path=kb_file.filepath,
+            filename=kb_file.filename,
+            media_type="multipart/form-data")
+    else:
+        return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")


 async def recreate_vector_store(
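Because download_doc takes plain keyword arguments, it can also be driven in-process, matching the pattern the WebUI uses when no_remote_api is set. A sketch, with the knowledge base and file name as assumed examples:

import asyncio
from server.knowledge_base.kb_doc_api import download_doc

resp = asyncio.run(download_doc(knowledge_base_name="samples", file_name="test.txt"))
# FileResponse when the file exists on disk, BaseResponse with code=500 otherwise
print(type(resp).__name__)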

View File

@@ -35,7 +35,6 @@ if __name__ == "__main__":
     with st.sidebar:
         st.image(
             os.path.join(
-                os.path.dirname(__file__),
                 "img",
                 "logo-long-chatchat-trans-v2.png"
             ),

View File

@@ -8,7 +8,6 @@ import os

 chat_box = ChatBox(
     assistant_avatar=os.path.join(
-        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
         "img",
         "chatchat_icon_blue_square_v2.png"
     )
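With the os.path.dirname(...) prefixes removed, both image paths are now resolved against the process working directory rather than the source file location, so the WebUI presumably has to be launched from the project root (an assumption; the commit itself does not state this). For illustration:

import os

avatar = os.path.join("img", "chatchat_icon_blue_square_v2.png")
print(os.path.abspath(avatar))  # depends on where `streamlit run webui.py` is started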

View File

@@ -259,6 +259,7 @@ class ApiRequest:
         self,
         query: str,
         history: List[Dict] = [],
+        stream: bool = True,
         no_remote_api: bool = None,
     ):
         '''
@@ -267,12 +268,18 @@
         if no_remote_api is None:
             no_remote_api = self.no_remote_api

+        data = {
+            "query": query,
+            "history": history,
+            "stream": stream,
+        }
+
         if no_remote_api:
             from server.chat.chat import chat
-            response = chat(query, history)
+            response = chat(**data)
             return self._fastapi_stream2generator(response)
         else:
-            response = self.post("/chat/chat", json={"query": query, "history": history}, stream=True)
+            response = self.post("/chat/chat", json=data, stream=True)
             return self._httpx_stream2generator(response)

     def knowledge_base_chat(
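On the client side, the request payload is now built once as data and either unpacked into the in-process chat() call or posted to /chat/chat, so both paths stay in sync and the new stream flag is honored in each. A usage sketch (the module path and constructor arguments are assumptions; only the chat() parameters are visible in this diff):

from webui_pages.utils import ApiRequest  # module path assumed

api = ApiRequest(no_remote_api=False)      # talk to the HTTP API
for chunk in api.chat("你好", history=[], stream=True):
    print(chunk, end="", flush=True)       # streamed answer chunks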
@@ -296,6 +303,7 @@ class ApiRequest:
             "top_k": top_k,
             "history": history,
             "stream": stream,
+            "local_doc_url": no_remote_api,
         }

         if no_remote_api:
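Setting local_doc_url to no_remote_api means an in-process WebUI receives file:// paths to the knowledge files, while a WebUI talking to a remote API receives HTTP links to the new download endpoint. Illustrative values only (paths and host are assumptions):

# local_doc_url=True  -> file:///path/to/knowledge_base/samples/content/test.txt
# local_doc_url=False -> http://127.0.0.1:7861/knowledge_base/download_doc?knowledge_base_name=samples&file_name=test.txt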