update kb and search chat: disable streaming in swagger besides streaming in ApiRequest
This commit is contained in:
parent
ca49f9d095
commit
222689ed5b
|
|
@ -20,13 +20,14 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||||
history: List[History] = Body([],
|
history: List[History] = Body([],
|
||||||
description="历史对话",
|
description="历史对话",
|
||||||
examples=[[
|
examples=[[
|
||||||
{"role": "user",
|
{"role": "user",
|
||||||
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||||
{"role": "assistant",
|
{"role": "assistant",
|
||||||
"content": "虎头虎脑"}]]
|
"content": "虎头虎脑"}]]
|
||||||
),
|
),
|
||||||
|
stream: bool = Body(False, description="流式输出"),
|
||||||
):
|
):
|
||||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
if kb is None:
|
if kb is None:
|
||||||
|
|
@ -67,11 +68,19 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||||
for inum, doc in enumerate(docs)
|
for inum, doc in enumerate(docs)
|
||||||
]
|
]
|
||||||
|
|
||||||
async for token in callback.aiter():
|
if stream:
|
||||||
# Use server-sent-events to stream the response
|
async for token in callback.aiter():
|
||||||
|
# Use server-sent-events to stream the response
|
||||||
|
yield json.dumps({"answer": token,
|
||||||
|
"docs": source_documents}, ensure_ascii=False)
|
||||||
|
else:
|
||||||
|
answer = ""
|
||||||
|
async for token in callback.aiter():
|
||||||
|
answer += token
|
||||||
yield json.dumps({"answer": token,
|
yield json.dumps({"answer": token,
|
||||||
"docs": source_documents},
|
"docs": source_documents},
|
||||||
ensure_ascii=False)
|
ensure_ascii=False)
|
||||||
|
|
||||||
await task
|
await task
|
||||||
|
|
||||||
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
|
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
|
||||||
|
|
|
||||||
|
|
@ -61,13 +61,14 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
|
||||||
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
|
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
|
||||||
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
|
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
|
||||||
history: List[History] = Body([],
|
history: List[History] = Body([],
|
||||||
description="历史对话",
|
description="历史对话",
|
||||||
examples=[[
|
examples=[[
|
||||||
{"role": "user",
|
{"role": "user",
|
||||||
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||||
{"role": "assistant",
|
{"role": "assistant",
|
||||||
"content": "虎头虎脑"}]]
|
"content": "虎头虎脑"}]]
|
||||||
),
|
),
|
||||||
|
stream: bool = Body(False, description="流式输出"),
|
||||||
):
|
):
|
||||||
if search_engine_name not in SEARCH_ENGINES.keys():
|
if search_engine_name not in SEARCH_ENGINES.keys():
|
||||||
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
|
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
|
||||||
|
|
@ -106,8 +107,15 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
|
||||||
for inum, doc in enumerate(docs)
|
for inum, doc in enumerate(docs)
|
||||||
]
|
]
|
||||||
|
|
||||||
async for token in callback.aiter():
|
if stream:
|
||||||
# Use server-sent-events to stream the response
|
async for token in callback.aiter():
|
||||||
|
# Use server-sent-events to stream the response
|
||||||
|
yield json.dumps({"answer": token,
|
||||||
|
"docs": source_documents}, ensure_ascii=False)
|
||||||
|
else:
|
||||||
|
answer = ""
|
||||||
|
async for token in callback.aiter():
|
||||||
|
answer += token
|
||||||
yield json.dumps({"answer": token,
|
yield json.dumps({"answer": token,
|
||||||
"docs": source_documents},
|
"docs": source_documents},
|
||||||
ensure_ascii=False)
|
ensure_ascii=False)
|
||||||
|
|
|
||||||
|
|
@ -253,6 +253,7 @@ class ApiRequest:
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||||
history: List[Dict] = [],
|
history: List[Dict] = [],
|
||||||
|
stream: bool = True,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
|
|
@ -261,14 +262,22 @@ class ApiRequest:
|
||||||
if no_remote_api is None:
|
if no_remote_api is None:
|
||||||
no_remote_api = self.no_remote_api
|
no_remote_api = self.no_remote_api
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"query": query,
|
||||||
|
"knowledge_base_name": knowledge_base_name,
|
||||||
|
"top_k": top_k,
|
||||||
|
"history": history,
|
||||||
|
"stream": stream,
|
||||||
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||||
response = knowledge_base_chat(query, knowledge_base_name, top_k, history)
|
response = knowledge_base_chat(**data)
|
||||||
return self._fastapi_stream2generator(response, as_json=True)
|
return self._fastapi_stream2generator(response, as_json=True)
|
||||||
else:
|
else:
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/chat/knowledge_base_chat",
|
"/chat/knowledge_base_chat",
|
||||||
json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "history": history},
|
json=data,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
return self._httpx_stream2generator(response, as_json=True)
|
return self._httpx_stream2generator(response, as_json=True)
|
||||||
|
|
@ -278,6 +287,7 @@ class ApiRequest:
|
||||||
query: str,
|
query: str,
|
||||||
search_engine_name: str,
|
search_engine_name: str,
|
||||||
top_k: int = SEARCH_ENGINE_TOP_K,
|
top_k: int = SEARCH_ENGINE_TOP_K,
|
||||||
|
stream: bool = True,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
|
|
@ -286,14 +296,21 @@ class ApiRequest:
|
||||||
if no_remote_api is None:
|
if no_remote_api is None:
|
||||||
no_remote_api = self.no_remote_api
|
no_remote_api = self.no_remote_api
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"query": query,
|
||||||
|
"search_engine_name": search_engine_name,
|
||||||
|
"top_k": top_k,
|
||||||
|
"stream": stream,
|
||||||
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.chat.search_engine_chat import search_engine_chat
|
from server.chat.search_engine_chat import search_engine_chat
|
||||||
response = search_engine_chat(query, search_engine_name, top_k)
|
response = search_engine_chat(**data)
|
||||||
return self._fastapi_stream2generator(response, as_json=True)
|
return self._fastapi_stream2generator(response, as_json=True)
|
||||||
else:
|
else:
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/chat/search_engine_chat",
|
"/chat/search_engine_chat",
|
||||||
json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
|
json=data,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
return self._httpx_stream2generator(response, as_json=True)
|
return self._httpx_stream2generator(response, as_json=True)
|
||||||
|
|
@ -331,14 +348,20 @@ class ApiRequest:
|
||||||
if no_remote_api is None:
|
if no_remote_api is None:
|
||||||
no_remote_api = self.no_remote_api
|
no_remote_api = self.no_remote_api
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"knowledge_base_name": knowledge_base_name,
|
||||||
|
"vector_store_type": vector_store_type,
|
||||||
|
"embed_model": embed_model,
|
||||||
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.knowledge_base.kb_api import create_kb
|
from server.knowledge_base.kb_api import create_kb
|
||||||
response = run_async(create_kb(knowledge_base_name, vector_store_type, embed_model))
|
response = run_async(create_kb(**data))
|
||||||
return response.dict()
|
return response.dict()
|
||||||
else:
|
else:
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/knowledge_base/create_knowledge_base",
|
"/knowledge_base/create_knowledge_base",
|
||||||
json={"knowledge_base_name": knowledge_base_name, "vector_store_type": vector_store_type, "embed_model": embed_model},
|
json=data,
|
||||||
)
|
)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
@ -443,14 +466,20 @@ class ApiRequest:
|
||||||
if no_remote_api is None:
|
if no_remote_api is None:
|
||||||
no_remote_api = self.no_remote_api
|
no_remote_api = self.no_remote_api
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"knowledge_base_name": knowledge_base_name,
|
||||||
|
"doc_name": doc_name,
|
||||||
|
"delete_content": delete_content,
|
||||||
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.knowledge_base.kb_doc_api import delete_doc
|
from server.knowledge_base.kb_doc_api import delete_doc
|
||||||
response = run_async(delete_doc(knowledge_base_name, doc_name, delete_content))
|
response = run_async(delete_doc(**data))
|
||||||
return response.dict()
|
return response.dict()
|
||||||
else:
|
else:
|
||||||
response = self.delete(
|
response = self.delete(
|
||||||
"/knowledge_base/delete_doc",
|
"/knowledge_base/delete_doc",
|
||||||
json={"knowledge_base_name": knowledge_base_name, "doc_name": doc_name, "delete_content": delete_content},
|
json=data,
|
||||||
)
|
)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue