make apirequest support streaming dict

This commit is contained in:
liunux4odoo 2023-08-04 12:49:39 +08:00
parent c4994e85df
commit 27d49be706
1 changed files with 25 additions and 30 deletions

View File

@ -8,12 +8,15 @@ from configs.model_config import (
KB_ROOT_PATH,
LLM_MODEL,
llm_model_dict,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
)
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
from fastapi.responses import StreamingResponse
import contextlib
import json
def set_httpx_timeout(timeout=60.0):
@ -30,26 +33,6 @@ KB_ROOT_PATH = Path(KB_ROOT_PATH)
set_httpx_timeout()
def get_kb_list() -> List[str]:
'''
获取知识库列表
'''
kb_list = os.listdir(KB_ROOT_PATH)
return [x for x in kb_list if (KB_ROOT_PATH / x).is_dir()]
def get_kb_files(kb: str) -> List[str]:
'''
获取某个知识库下包含的所有文件只包括根目录一级
'''
kb = KB_ROOT_PATH / kb / "content"
if kb.is_dir():
kb_files = os.listdir(kb)
return kb_files
else:
return []
def run_async(cor):
'''
在同步环境中运行异步代码.
@ -174,7 +157,7 @@ class ApiRequest:
except:
retry -= 1
def _fastapi_stream2generator(self, response: StreamingResponse):
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
'''
将api.py中视图函数返回的StreamingResponse转化为同步生成器
'''
@ -182,15 +165,27 @@ class ApiRequest:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
return iter_over_async(response.body_iterator, loop)
for chunk in iter_over_async(response.body_iterator, loop):
if as_json and chunk:
yield json.loads(chunk)
else:
yield chunk
def _httpx_stream2generator(self,response: contextlib._GeneratorContextManager):
def _httpx_stream2generator(
self,
response: contextlib._GeneratorContextManager,
as_json: bool = False,
):
'''
将httpx.stream返回的GeneratorContextManager转化为普通生成器
'''
with response as r:
for chunk in r.iter_text(None):
yield chunk
if as_json and chunk:
yield json.loads(chunk)
else:
yield chunk
# 对话相关操作
@ -254,7 +249,7 @@ class ApiRequest:
self,
query: str,
knowledge_base_name: str,
top_k: int = 3,
top_k: int = VECTOR_SEARCH_TOP_K,
no_remote_api: bool = None,
):
'''
@ -266,20 +261,20 @@ class ApiRequest:
if no_remote_api:
from server.chat.knowledge_base_chat import knowledge_base_chat
response = knowledge_base_chat(query, knowledge_base_name, top_k)
return self._fastapi_stream2generator(response)
return self._fastapi_stream2generator(response, as_json=True)
else:
response = self.post(
"/chat/knowledge_base_chat",
json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
stream=True,
)
return self._httpx_stream2generator(response)
return self._httpx_stream2generator(response, as_json=True)
def search_engine_chat(
self,
query: str,
search_engine_name: str,
top_k: int,
top_k: int = SEARCH_ENGINE_TOP_K,
no_remote_api: bool = None,
):
'''
@ -291,14 +286,14 @@ class ApiRequest:
if no_remote_api:
from server.chat.search_engine_chat import search_engine_chat
response = search_engine_chat(query, search_engine_name, top_k)
return self._fastapi_stream2generator(response)
return self._fastapi_stream2generator(response, as_json=True)
else:
response = self.post(
"/chat/search_engine_chat",
json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
stream=True,
)
return self._httpx_stream2generator(response)
return self._httpx_stream2generator(response, as_json=True)
# 知识库相关操作