make apirequest support streaming dict
This commit is contained in:
parent c4994e85df
commit 27d49be706
@@ -8,12 +8,15 @@ from configs.model_config import (
     KB_ROOT_PATH,
     LLM_MODEL,
     llm_model_dict,
+    VECTOR_SEARCH_TOP_K,
+    SEARCH_ENGINE_TOP_K,
 )
 import httpx
 import asyncio
 from server.chat.openai_chat import OpenAiChatMsgIn
 from fastapi.responses import StreamingResponse
 import contextlib
+import json
 
 
 def set_httpx_timeout(timeout=60.0):
@@ -30,26 +33,6 @@ KB_ROOT_PATH = Path(KB_ROOT_PATH)
 set_httpx_timeout()
 
 
-def get_kb_list() -> List[str]:
-    '''
-    Get the list of knowledge bases
-    '''
-    kb_list = os.listdir(KB_ROOT_PATH)
-    return [x for x in kb_list if (KB_ROOT_PATH / x).is_dir()]
-
-
-def get_kb_files(kb: str) -> List[str]:
-    '''
-    Get all files in a knowledge base (top level of its content directory only)
-    '''
-    kb = KB_ROOT_PATH / kb / "content"
-    if kb.is_dir():
-        kb_files = os.listdir(kb)
-        return kb_files
-    else:
-        return []
-
-
 def run_async(cor):
     '''
     Run async code in a synchronous environment.
@@ -174,7 +157,7 @@ class ApiRequest:
             except:
                 retry -= 1
 
-    def _fastapi_stream2generator(self, response: StreamingResponse):
+    def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool = False):
         '''
         Convert the StreamingResponse returned by view functions in api.py into a synchronous generator
         '''
@@ -182,15 +165,27 @@
             loop = asyncio.get_event_loop()
         except:
             loop = asyncio.new_event_loop()
-        return iter_over_async(response.body_iterator, loop)
+
+        for chunk in iter_over_async(response.body_iterator, loop):
+            if as_json and chunk:
+                yield json.loads(chunk)
+            else:
+                yield chunk
 
-    def _httpx_stream2generator(self, response: contextlib._GeneratorContextManager):
+    def _httpx_stream2generator(
+        self,
+        response: contextlib._GeneratorContextManager,
+        as_json: bool = False,
+    ):
         '''
         Convert the GeneratorContextManager returned by httpx.stream into a plain generator
         '''
         with response as r:
             for chunk in r.iter_text(None):
-                yield chunk
+                if as_json and chunk:
+                    yield json.loads(chunk)
+                else:
+                    yield chunk
 
     # Chat-related operations
 
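In short: with as_json set, both generators pass every non-empty text chunk through json.loads before yielding it, so callers receive ready-to-use dicts instead of raw JSON strings. A minimal standalone sketch of that decoding pattern (illustration only, not the project's exact code; the sample payload is made up):

    import json
    from typing import Iterable, Iterator, Union

    def decode_stream(chunks: Iterable[str], as_json: bool = False) -> Iterator[Union[str, dict]]:
        # Mirrors the generator logic added above: empty chunks pass through
        # unchanged, non-empty chunks are JSON-decoded when as_json is set.
        for chunk in chunks:
            if as_json and chunk:
                yield json.loads(chunk)
            else:
                yield chunk

    # Example with a stream that emits one JSON document per chunk:
    for item in decode_stream(['{"answer": "hello"}', ''], as_json=True):
        print(item)   # {'answer': 'hello'}, then ''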
@@ -254,7 +249,7 @@ class ApiRequest:
         self,
         query: str,
         knowledge_base_name: str,
-        top_k: int = 3,
+        top_k: int = VECTOR_SEARCH_TOP_K,
         no_remote_api: bool = None,
     ):
         '''
@@ -266,20 +261,20 @@
         if no_remote_api:
             from server.chat.knowledge_base_chat import knowledge_base_chat
             response = knowledge_base_chat(query, knowledge_base_name, top_k)
-            return self._fastapi_stream2generator(response)
+            return self._fastapi_stream2generator(response, as_json=True)
         else:
             response = self.post(
                 "/chat/knowledge_base_chat",
                 json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
                 stream=True,
             )
-            return self._httpx_stream2generator(response)
+            return self._httpx_stream2generator(response, as_json=True)
 
     def search_engine_chat(
         self,
         query: str,
         search_engine_name: str,
-        top_k: int,
+        top_k: int = SEARCH_ENGINE_TOP_K,
         no_remote_api: bool = None,
     ):
         '''
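With as_json=True wired in here, a caller such as the web UI can treat knowledge_base_chat as a stream of dicts. A rough usage sketch (the ApiRequest constructor arguments, the server address, the knowledge base name and the keys of each streamed dict are assumptions, not shown in this diff):

    api = ApiRequest("http://127.0.0.1:7861")   # assumed base URL of the local api.py server

    # Each chunk has already been json.loads-ed into a dict, so fields can be read directly.
    for d in api.knowledge_base_chat("What is LangChain?", "samples"):
        print(d)   # e.g. {"answer": "...", "docs": [...]} -- key names assumed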
@@ -291,14 +286,14 @@
         if no_remote_api:
             from server.chat.search_engine_chat import search_engine_chat
             response = search_engine_chat(query, search_engine_name, top_k)
-            return self._fastapi_stream2generator(response)
+            return self._fastapi_stream2generator(response, as_json=True)
         else:
             response = self.post(
                 "/chat/search_engine_chat",
                 json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
                 stream=True,
             )
-            return self._httpx_stream2generator(response)
+            return self._httpx_stream2generator(response, as_json=True)
 
     # Knowledge-base-related operations
 
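search_engine_chat follows the same pattern. Continuing the sketch above (the engine name and how the None default of no_remote_api is resolved are assumptions):

    # With the default no_remote_api, the call is expected to go through
    # POST /chat/search_engine_chat and _httpx_stream2generator(as_json=True).
    for d in api.search_engine_chat("Weather in Shanghai today", "duckduckgo"):
        print(d)

    # With no_remote_api=True, server.chat.search_engine_chat is invoked in-process
    # and streamed through _fastapi_stream2generator(as_json=True) instead of HTTP.
    for d in api.search_engine_chat("Weather in Shanghai today", "duckduckgo", no_remote_api=True):
        print(d)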