git rebase and pull
This commit is contained in:
commit
8b18cf2a5e
|
|
@ -14,7 +14,7 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||||
search_engine_chat)
|
search_engine_chat)
|
||||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||||
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
|
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
|
||||||
update_doc, recreate_vector_store)
|
update_doc, download_doc, recreate_vector_store)
|
||||||
from server.utils import BaseResponse, ListResponse
|
from server.utils import BaseResponse, ListResponse
|
||||||
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
@ -101,6 +101,10 @@ def create_app():
|
||||||
summary="更新现有文件到知识库"
|
summary="更新现有文件到知识库"
|
||||||
)(update_doc)
|
)(update_doc)
|
||||||
|
|
||||||
|
app.get("/knowledge_base/download_doc",
|
||||||
|
tags=["Knowledge Base Management"],
|
||||||
|
summary="下载对应的知识文件")(download_doc)
|
||||||
|
|
||||||
app.post("/knowledge_base/recreate_vector_store",
|
app.post("/knowledge_base/recreate_vector_store",
|
||||||
tags=["Knowledge Base Management"],
|
tags=["Knowledge Base Management"],
|
||||||
summary="根据content中文档重建向量库,流式输出处理进度。"
|
summary="根据content中文档重建向量库,流式输出处理进度。"
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from fastapi import Body
|
from fastapi import Body, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||||
VECTOR_SEARCH_TOP_K)
|
VECTOR_SEARCH_TOP_K)
|
||||||
|
|
@ -15,6 +15,7 @@ from server.chat.utils import History
|
||||||
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
|
||||||
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||||
|
|
@ -29,6 +30,8 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||||
"content": "虎头虎脑"}]]
|
"content": "虎头虎脑"}]]
|
||||||
),
|
),
|
||||||
stream: bool = Body(False, description="流式输出"),
|
stream: bool = Body(False, description="流式输出"),
|
||||||
|
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||||
|
request: Request = None,
|
||||||
):
|
):
|
||||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
if kb is None:
|
if kb is None:
|
||||||
|
|
@ -64,10 +67,16 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||||
callback.done),
|
callback.done),
|
||||||
)
|
)
|
||||||
|
|
||||||
source_documents = [
|
source_documents = []
|
||||||
f"""出处 [{inum + 1}] [{os.path.split(doc.metadata["source"])[-1]}] \n\n{doc.page_content}\n\n"""
|
for inum, doc in enumerate(docs):
|
||||||
for inum, doc in enumerate(docs)
|
filename = os.path.split(doc.metadata["source"])[-1]
|
||||||
]
|
if local_doc_url:
|
||||||
|
url = "file://" + doc.metadata["source"]
|
||||||
|
else:
|
||||||
|
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
|
||||||
|
url = f"{request.base_url}knowledge_base/download_doc?" + parameters
|
||||||
|
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
|
||||||
|
source_documents.append(text)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
async for token in callback.aiter():
|
async for token in callback.aiter():
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
from fastapi import File, Form, Body, UploadFile
|
from fastapi import File, Form, Body, Query, UploadFile
|
||||||
from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL
|
from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL
|
||||||
from server.utils import BaseResponse, ListResponse
|
from server.utils import BaseResponse, ListResponse
|
||||||
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
|
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, FileResponse
|
||||||
import json
|
import json
|
||||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
@ -104,9 +104,32 @@ async def update_doc(
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
|
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
|
||||||
|
|
||||||
|
|
||||||
async def download_doc():
|
async def download_doc(
|
||||||
# TODO: 下载文件
|
knowledge_base_name: str = Query(..., examples=["samples"]),
|
||||||
pass
|
file_name: str = Query(..., examples=["test.txt"]),
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
下载知识库文档
|
||||||
|
'''
|
||||||
|
if not validate_kb_name(knowledge_base_name):
|
||||||
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
||||||
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
|
if kb is None:
|
||||||
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
|
kb_file = KnowledgeFile(filename=file_name,
|
||||||
|
knowledge_base_name=knowledge_base_name)
|
||||||
|
|
||||||
|
if os.path.exists(kb_file.filepath):
|
||||||
|
return FileResponse(
|
||||||
|
path=kb_file.filepath,
|
||||||
|
filename=kb_file.filename,
|
||||||
|
media_type="multipart/form-data")
|
||||||
|
else:
|
||||||
|
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def recreate_vector_store(
|
async def recreate_vector_store(
|
||||||
|
|
|
||||||
1
webui.py
1
webui.py
|
|
@ -35,7 +35,6 @@ if __name__ == "__main__":
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
st.image(
|
st.image(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
os.path.dirname(__file__),
|
|
||||||
"img",
|
"img",
|
||||||
"logo-long-chatchat-trans-v2.png"
|
"logo-long-chatchat-trans-v2.png"
|
||||||
),
|
),
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ import os
|
||||||
|
|
||||||
chat_box = ChatBox(
|
chat_box = ChatBox(
|
||||||
assistant_avatar=os.path.join(
|
assistant_avatar=os.path.join(
|
||||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
|
||||||
"img",
|
"img",
|
||||||
"chatchat_icon_blue_square_v2.png"
|
"chatchat_icon_blue_square_v2.png"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -259,6 +259,7 @@ class ApiRequest:
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
history: List[Dict] = [],
|
history: List[Dict] = [],
|
||||||
|
stream: bool = True,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
|
|
@ -267,12 +268,18 @@ class ApiRequest:
|
||||||
if no_remote_api is None:
|
if no_remote_api is None:
|
||||||
no_remote_api = self.no_remote_api
|
no_remote_api = self.no_remote_api
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"query": query,
|
||||||
|
"history": history,
|
||||||
|
"stream": stream,
|
||||||
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.chat.chat import chat
|
from server.chat.chat import chat
|
||||||
response = chat(query, history)
|
response = chat(**data)
|
||||||
return self._fastapi_stream2generator(response)
|
return self._fastapi_stream2generator(response)
|
||||||
else:
|
else:
|
||||||
response = self.post("/chat/chat", json={"query": query, "history": history}, stream=True)
|
response = self.post("/chat/chat", json=data, stream=True)
|
||||||
return self._httpx_stream2generator(response)
|
return self._httpx_stream2generator(response)
|
||||||
|
|
||||||
def knowledge_base_chat(
|
def knowledge_base_chat(
|
||||||
|
|
@ -296,6 +303,7 @@ class ApiRequest:
|
||||||
"top_k": top_k,
|
"top_k": top_k,
|
||||||
"history": history,
|
"history": history,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
|
"local_doc_url": no_remote_api,
|
||||||
}
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue