make apirequest support streaming dict
This commit is contained in:
parent
c4994e85df
commit
27d49be706
|
|
@ -8,12 +8,15 @@ from configs.model_config import (
|
||||||
KB_ROOT_PATH,
|
KB_ROOT_PATH,
|
||||||
LLM_MODEL,
|
LLM_MODEL,
|
||||||
llm_model_dict,
|
llm_model_dict,
|
||||||
|
VECTOR_SEARCH_TOP_K,
|
||||||
|
SEARCH_ENGINE_TOP_K,
|
||||||
)
|
)
|
||||||
import httpx
|
import httpx
|
||||||
import asyncio
|
import asyncio
|
||||||
from server.chat.openai_chat import OpenAiChatMsgIn
|
from server.chat.openai_chat import OpenAiChatMsgIn
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
def set_httpx_timeout(timeout=60.0):
|
def set_httpx_timeout(timeout=60.0):
|
||||||
|
|
@ -30,26 +33,6 @@ KB_ROOT_PATH = Path(KB_ROOT_PATH)
|
||||||
set_httpx_timeout()
|
set_httpx_timeout()
|
||||||
|
|
||||||
|
|
||||||
def get_kb_list() -> List[str]:
|
|
||||||
'''
|
|
||||||
获取知识库列表
|
|
||||||
'''
|
|
||||||
kb_list = os.listdir(KB_ROOT_PATH)
|
|
||||||
return [x for x in kb_list if (KB_ROOT_PATH / x).is_dir()]
|
|
||||||
|
|
||||||
|
|
||||||
def get_kb_files(kb: str) -> List[str]:
|
|
||||||
'''
|
|
||||||
获取某个知识库下包含的所有文件(只包括根目录一级)
|
|
||||||
'''
|
|
||||||
kb = KB_ROOT_PATH / kb / "content"
|
|
||||||
if kb.is_dir():
|
|
||||||
kb_files = os.listdir(kb)
|
|
||||||
return kb_files
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def run_async(cor):
|
def run_async(cor):
|
||||||
'''
|
'''
|
||||||
在同步环境中运行异步代码.
|
在同步环境中运行异步代码.
|
||||||
|
|
@ -174,7 +157,7 @@ class ApiRequest:
|
||||||
except:
|
except:
|
||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
def _fastapi_stream2generator(self, response: StreamingResponse):
|
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
|
||||||
'''
|
'''
|
||||||
将api.py中视图函数返回的StreamingResponse转化为同步生成器
|
将api.py中视图函数返回的StreamingResponse转化为同步生成器
|
||||||
'''
|
'''
|
||||||
|
|
@ -182,15 +165,27 @@ class ApiRequest:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
except:
|
except:
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
return iter_over_async(response.body_iterator, loop)
|
|
||||||
|
for chunk in iter_over_async(response.body_iterator, loop):
|
||||||
|
if as_json and chunk:
|
||||||
|
yield json.loads(chunk)
|
||||||
|
else:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
def _httpx_stream2generator(self,response: contextlib._GeneratorContextManager):
|
def _httpx_stream2generator(
|
||||||
|
self,
|
||||||
|
response: contextlib._GeneratorContextManager,
|
||||||
|
as_json: bool = False,
|
||||||
|
):
|
||||||
'''
|
'''
|
||||||
将httpx.stream返回的GeneratorContextManager转化为普通生成器
|
将httpx.stream返回的GeneratorContextManager转化为普通生成器
|
||||||
'''
|
'''
|
||||||
with response as r:
|
with response as r:
|
||||||
for chunk in r.iter_text(None):
|
for chunk in r.iter_text(None):
|
||||||
yield chunk
|
if as_json and chunk:
|
||||||
|
yield json.loads(chunk)
|
||||||
|
else:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
# 对话相关操作
|
# 对话相关操作
|
||||||
|
|
||||||
|
|
@ -254,7 +249,7 @@ class ApiRequest:
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
top_k: int = 3,
|
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
|
|
@ -266,20 +261,20 @@ class ApiRequest:
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||||
response = knowledge_base_chat(query, knowledge_base_name, top_k)
|
response = knowledge_base_chat(query, knowledge_base_name, top_k)
|
||||||
return self._fastapi_stream2generator(response)
|
return self._fastapi_stream2generator(response, as_json=True)
|
||||||
else:
|
else:
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/chat/knowledge_base_chat",
|
"/chat/knowledge_base_chat",
|
||||||
json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
|
json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
return self._httpx_stream2generator(response)
|
return self._httpx_stream2generator(response, as_json=True)
|
||||||
|
|
||||||
def search_engine_chat(
|
def search_engine_chat(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
search_engine_name: str,
|
search_engine_name: str,
|
||||||
top_k: int,
|
top_k: int = SEARCH_ENGINE_TOP_K,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
|
|
@ -291,14 +286,14 @@ class ApiRequest:
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.chat.search_engine_chat import search_engine_chat
|
from server.chat.search_engine_chat import search_engine_chat
|
||||||
response = search_engine_chat(query, search_engine_name, top_k)
|
response = search_engine_chat(query, search_engine_name, top_k)
|
||||||
return self._fastapi_stream2generator(response)
|
return self._fastapi_stream2generator(response, as_json=True)
|
||||||
else:
|
else:
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/chat/search_engine_chat",
|
"/chat/search_engine_chat",
|
||||||
json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
|
json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
return self._httpx_stream2generator(response)
|
return self._httpx_stream2generator(response, as_json=True)
|
||||||
|
|
||||||
# 知识库相关操作
|
# 知识库相关操作
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue