update kb and search chat: disable streaming by default in swagger, keep streaming enabled in ApiRequest

liunux4odoo 2023-08-09 23:35:36 +08:00
parent ca49f9d095
commit 222689ed5b
3 changed files with 74 additions and 28 deletions


@@ -20,13 +20,14 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                         knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         history: List[History] = Body([],
                                                       description="历史对话",
                                                       examples=[[
                                                           {"role": "user",
                                                            "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                           {"role": "assistant",
                                                            "content": "虎头虎脑"}]]
                                                       ),
+                        stream: bool = Body(False, description="流式输出"),
                         ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
@@ -67,11 +68,19 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
             for inum, doc in enumerate(docs)
         ]
-        async for token in callback.aiter():
-            # Use server-sent-events to stream the response
+        if stream:
+            async for token in callback.aiter():
+                # Use server-sent-events to stream the response
+                yield json.dumps({"answer": token,
+                                  "docs": source_documents}, ensure_ascii=False)
+        else:
+            answer = ""
+            async for token in callback.aiter():
+                answer += token
             yield json.dumps({"answer": token,
                               "docs": source_documents},
                              ensure_ascii=False)
         await task

     return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
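Both chat iterators now branch on the new stream flag: when it is true, each token from callback.aiter() is emitted as its own JSON chunk; when it is false, the tokens are first accumulated and a single JSON payload is yielded after generation finishes. A minimal, self-contained sketch of that pattern (toy token source and names, not the project's actual callback):

    import asyncio
    import json
    from typing import AsyncIterable

    async def fake_tokens() -> AsyncIterable[str]:
        # Stand-in for callback.aiter(): yields model tokens one at a time.
        for tok in ["虎", "头", "虎", "脑"]:
            yield tok

    async def chat_iterator(stream: bool) -> AsyncIterable[str]:
        source_documents = ["[1] sample source document"]
        if stream:
            # Streaming: one JSON chunk per token.
            async for token in fake_tokens():
                yield json.dumps({"answer": token, "docs": source_documents}, ensure_ascii=False)
        else:
            # Non-streaming: collect everything, then yield a single JSON payload.
            answer = ""
            async for token in fake_tokens():
                answer += token
            yield json.dumps({"answer": answer, "docs": source_documents}, ensure_ascii=False)

    async def main():
        print([chunk async for chunk in chat_iterator(stream=True)])   # four chunks
        print([chunk async for chunk in chat_iterator(stream=False)])  # one chunk

    asyncio.run(main())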


@@ -61,13 +61,14 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                        search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
                        top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                        history: List[History] = Body([],
                                                      description="历史对话",
                                                      examples=[[
                                                          {"role": "user",
                                                           "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                          {"role": "assistant",
                                                           "content": "虎头虎脑"}]]
                                                      ),
+                       stream: bool = Body(False, description="流式输出"),
                        ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -106,8 +107,15 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
             for inum, doc in enumerate(docs)
         ]
-        async for token in callback.aiter():
-            # Use server-sent-events to stream the response
+        if stream:
+            async for token in callback.aiter():
+                # Use server-sent-events to stream the response
+                yield json.dumps({"answer": token,
+                                  "docs": source_documents}, ensure_ascii=False)
+        else:
+            answer = ""
+            async for token in callback.aiter():
+                answer += token
             yield json.dumps({"answer": token,
                               "docs": source_documents},
                              ensure_ascii=False)
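Because stream now defaults to False in the request body, the Swagger UI gets one complete JSON response unless the flag is set, while a streaming client passes "stream": true explicitly. A rough client-side sketch using httpx; the base URL and example values are assumptions, only the route and field names come from this commit:

    import httpx

    BASE_URL = "http://127.0.0.1:7861"  # assumed local API address, adjust to your deployment

    payload = {
        "query": "你好",
        "knowledge_base_name": "samples",
        "top_k": 3,
        "history": [],
        "stream": True,  # set to False to receive a single accumulated answer
    }

    with httpx.stream("POST", f"{BASE_URL}/chat/knowledge_base_chat", json=payload, timeout=None) as resp:
        for text in resp.iter_text():
            # Each server-side yield is a JSON object with "answer" and "docs";
            # chunk boundaries are not guaranteed to align with those objects.
            print(text, end="", flush=True)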


@@ -253,6 +253,7 @@ class ApiRequest:
         knowledge_base_name: str,
         top_k: int = VECTOR_SEARCH_TOP_K,
         history: List[Dict] = [],
+        stream: bool = True,
         no_remote_api: bool = None,
     ):
         '''
@@ -261,14 +262,22 @@ class ApiRequest:
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
+        data = {
+            "query": query,
+            "knowledge_base_name": knowledge_base_name,
+            "top_k": top_k,
+            "history": history,
+            "stream": stream,
+        }
         if no_remote_api:
             from server.chat.knowledge_base_chat import knowledge_base_chat
-            response = knowledge_base_chat(query, knowledge_base_name, top_k, history)
+            response = knowledge_base_chat(**data)
             return self._fastapi_stream2generator(response, as_json=True)
         else:
             response = self.post(
                 "/chat/knowledge_base_chat",
-                json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "history": history},
+                json=data,
+                stream=True,
             )
             return self._httpx_stream2generator(response, as_json=True)
@@ -278,6 +287,7 @@ class ApiRequest:
         query: str,
         search_engine_name: str,
         top_k: int = SEARCH_ENGINE_TOP_K,
+        stream: bool = True,
         no_remote_api: bool = None,
     ):
         '''
@@ -286,14 +296,21 @@ class ApiRequest:
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
+        data = {
+            "query": query,
+            "search_engine_name": search_engine_name,
+            "top_k": top_k,
+            "stream": stream,
+        }
         if no_remote_api:
             from server.chat.search_engine_chat import search_engine_chat
-            response = search_engine_chat(query, search_engine_name, top_k)
+            response = search_engine_chat(**data)
             return self._fastapi_stream2generator(response, as_json=True)
         else:
             response = self.post(
                 "/chat/search_engine_chat",
-                json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
+                json=data,
+                stream=True,
             )
             return self._httpx_stream2generator(response, as_json=True)
@@ -331,14 +348,20 @@ class ApiRequest:
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
+        data = {
+            "knowledge_base_name": knowledge_base_name,
+            "vector_store_type": vector_store_type,
+            "embed_model": embed_model,
+        }
         if no_remote_api:
             from server.knowledge_base.kb_api import create_kb
-            response = run_async(create_kb(knowledge_base_name, vector_store_type, embed_model))
+            response = run_async(create_kb(**data))
             return response.dict()
         else:
             response = self.post(
                 "/knowledge_base/create_knowledge_base",
-                json={"knowledge_base_name": knowledge_base_name, "vector_store_type": vector_store_type, "embed_model": embed_model},
+                json=data,
             )
             return response.json()
@@ -443,14 +466,20 @@ class ApiRequest:
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
+        data = {
+            "knowledge_base_name": knowledge_base_name,
+            "doc_name": doc_name,
+            "delete_content": delete_content,
+        }
         if no_remote_api:
             from server.knowledge_base.kb_doc_api import delete_doc
-            response = run_async(delete_doc(knowledge_base_name, doc_name, delete_content))
+            response = run_async(delete_doc(**data))
             return response.dict()
         else:
             response = self.delete(
                 "/knowledge_base/delete_doc",
-                json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name, "delete_content": delete_content},
+                json=data,
            )
             return response.json()
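On the webui side both chat wrappers now forward stream (default True) and reuse one data dict for the local and remote code paths; the HTTP call itself is always read as a stream, and the payload's stream field decides whether the server sends per-token chunks or one accumulated answer. A usage sketch, under the assumption that ApiRequest takes a base_url argument and that, with as_json=True, the returned generator yields parsed dicts:

    from webui_pages.utils import ApiRequest  # assumed module path for ApiRequest

    api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumed constructor argument

    # Token-by-token output from the knowledge base chat endpoint.
    for chunk in api.knowledge_base_chat("你好", knowledge_base_name="samples", stream=True):
        print(chunk.get("answer", ""), end="", flush=True)

    # Same pattern for the search engine chat endpoint.
    for chunk in api.search_engine_chat("今天的天气", search_engine_name="duckduckgo", stream=True):
        print(chunk.get("answer", ""), end="", flush=True)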