wjs/yolo_safehat_det_flask_stream/req.py

313 lines
11 KiB
Python
Raw Permalink Normal View History

2024-05-15 14:46:56 +08:00
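# This file wraps the requests sent to api.py so they can be reused by different web UIs.
# Synchronous and asynchronous calls are meant to be supported through ApiRequest and
# AsyncApiRequest (AsyncApiRequest is currently commented out further down in this file).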
from typing import *
from pathlib import Path
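# The configuration imported here is that of the machine issuing the requests (e.g. the WebUI host);
# it is mainly used to set default values for the front end. In a distributed deployment it may
# differ from the configuration on the server.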
# from configs import (
# EMBEDDING_MODEL,
# DEFAULT_VS_TYPE,
# LLM_MODELS,
# TEMPERATURE,
# SCORE_THRESHOLD,
# CHUNK_SIZE,
# OVERLAP_SIZE,
# ZH_TITLE_ENHANCE,
# FIRST_VECTOR_SEARCH_TOP_K,
# VECTOR_SEARCH_TOP_K,
# SEARCH_ENGINE_TOP_K,
# HTTPX_DEFAULT_TIMEOUT,
# logger, log_verbose,
# )
import contextlib
import json
import logging
import os
from io import BytesIO
from pprint import pprint

import httpx
# from server.utils import set_httpx_config, api_address, get_httpx_client
# from langchain_core._api import deprecated
# set_httpx_config()
def get_httpx_client(
    use_async: bool = False,
    proxies: Union[str, Dict] = None,
    timeout: float = 300,
    **kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
    '''
    Build an httpx client.

    Args:
        use_async: when True an ``httpx.AsyncClient`` is returned, otherwise
            an ``httpx.Client``.
        proxies: optional proxy configuration. It is forwarded to httpx only
            when explicitly provided.
        timeout: timeout in seconds, forwarded to the client.
        **kwargs: any other keyword arguments accepted by the httpx client
            constructor (e.g. ``base_url``).

    Returns:
        The newly constructed client. The caller owns it and is responsible
        for closing it.
    '''
    # Previously `timeout` and `proxies` were accepted but silently dropped,
    # so every client ran with httpx's default timeout no matter what the
    # caller asked for.
    kwargs["timeout"] = timeout
    if proxies is not None:
        # NOTE(review): recent httpx releases replaced the `proxies` argument
        # with `proxy`/`mounts`; confirm against the installed httpx version.
        kwargs["proxies"] = proxies

    client_cls = httpx.AsyncClient if use_async else httpx.Client
    return client_cls(**kwargs)
_logger = logging.getLogger(__name__)


class ApiRequest:
    '''
    Synchronous wrapper around the HTTP endpoints exposed by api.py.

    The httpx client is created lazily (see ``client``) and re-created if it
    has been closed. ``_use_async`` selects between the synchronous and the
    asynchronous code paths of the response helpers; this class leaves it
    False.
    '''

    # Sentinel returned by _parse_json_chunk for chunks that produce no item
    # (SSE comment lines and chunks that are not valid JSON).
    _SKIP_CHUNK = object()

    def __init__(
        self,
        base_url: str = 'http://127.0.0.1:7861',
        timeout: float = 300,
    ):
        '''
        Args:
            base_url: root URL that request paths are resolved against.
            timeout: timeout in seconds handed to the httpx client.
        '''
        self.base_url = base_url
        self.timeout = timeout
        self._use_async = False
        self._client = None  # built on first access of `client`

    @property
    def client(self):
        '''Return the cached httpx client, creating a new one if there is none or it was closed.'''
        if self._client is None or self._client.is_closed:
            self._client = get_httpx_client(base_url=self.base_url,
                                            use_async=self._use_async,
                                            timeout=self.timeout)
        return self._client

    def get(
        self,
        url: str,
        params: Union[Dict, List[Tuple], bytes] = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        '''
        Issue a GET request, trying up to ``retry`` times.

        Args:
            url: request path or URL.
            params: query parameters.
            retry: maximum number of attempts; values <= 0 mean no attempt.
            stream: when True return ``client.stream(...)`` (a context
                manager) instead of a completed response.
            **kwargs: forwarded to the httpx call.

        Returns:
            The response (or stream context manager), or None when every
            attempt raised.
        '''
        for _ in range(retry):
            try:
                if stream:
                    return self.client.stream("GET", url, params=params, **kwargs)
                return self.client.get(url, params=params, **kwargs)
            except Exception as e:
                # Best effort: record the failure and try again.
                msg = f"error when get {url}: {e}"
                _logger.error("%s: %s", e.__class__.__name__, msg)
        return None

    def post(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        '''
        Issue a POST request, trying up to ``retry`` times.

        Args:
            url: request path or URL.
            data: form data for the request body.
            json: JSON-serialisable request body.
            retry: maximum number of attempts; values <= 0 mean no attempt.
            stream: when True return ``client.stream(...)`` (a context
                manager) instead of a completed response.
            **kwargs: forwarded to the httpx call.

        Returns:
            The response (or stream context manager), or None when every
            attempt raised.
        '''
        for _ in range(retry):
            try:
                if stream:
                    return self.client.stream("POST", url, data=data, json=json, **kwargs)
                return self.client.post(url, data=data, json=json, **kwargs)
            except Exception as e:
                # Best effort: record the failure and try again.
                msg = f"error when post {url}: {e}"
                _logger.error("%s: %s", e.__class__.__name__, msg)
        return None

    @classmethod
    def _parse_json_chunk(cls, chunk: str):
        '''Decode one non-empty streamed chunk as JSON; return ``_SKIP_CHUNK`` when it yields nothing.'''
        try:
            if chunk.startswith("data: "):
                # json.loads ignores surrounding whitespace, so the event
                # terminator does not have to be sliced off. The previous
                # fixed `[6:-2]` slice cut into the JSON whenever the chunk
                # did not end with exactly two terminator characters.
                return json.loads(chunk[len("data: "):])
            if chunk.startswith(":"):  # skip sse comment line
                return cls._SKIP_CHUNK
            return json.loads(chunk)
        except Exception as e:
            msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
            _logger.error("%s: %s", e.__class__.__name__, msg)
            return cls._SKIP_CHUNK

    @staticmethod
    def _stream_error(e: Exception) -> Dict:
        '''Map an exception raised while streaming to the ``{"code": 500, "msg": ...}`` error item.'''
        if isinstance(e, httpx.ConnectError):
            msg = f"无法连接API服务器请确认 api.py 已正常启动。({e})"
        elif isinstance(e, httpx.ReadTimeout):
            msg = f"API通信超时请确认已启动FastChat与API服务详见Wiki '5. 启动 API 服务或 Web UI')。({e}"
        else:
            msg = f"API通信遇到错误{e}"
        _logger.error("%s: %s", e.__class__.__name__, msg)
        return {"code": 500, "msg": msg}

    @staticmethod
    def _no_response_error() -> Dict:
        '''Error item used when there is no response object to stream from.'''
        reason = ": no response object (the request failed on every attempt)"
        msg = f"API通信遇到错误{reason}"
        _logger.error(msg)
        return {"code": 500, "msg": msg}

    def _httpx_stream2generator(
        self,
        response: contextlib._GeneratorContextManager,
        as_json: bool = False,
    ):
        '''
        Turn the context manager returned by ``httpx`` ``stream(...)`` into a
        plain generator (an async generator when ``self._use_async`` is set).

        Args:
            response: the stream context manager, or None when the request
                could not be issued (``get``/``post`` return None then).
            as_json: when True each chunk is decoded as JSON (an optional
                ``"data: "`` SSE prefix is stripped, SSE comment lines and
                undecodable chunks are skipped); otherwise the raw text
                chunks are yielded.

        Yields:
            Decoded items, or a single ``{"code": 500, "msg": ...}`` dict if
            communication fails.
        '''
        skip = self._SKIP_CHUNK

        def decode(chunk):
            if not chunk:  # fastchat api yield empty bytes on start and end
                return skip
            return self._parse_json_chunk(chunk) if as_json else chunk

        async def ret_async(response, as_json):
            if response is None:
                yield self._no_response_error()
                return
            try:
                async with response as r:
                    async for chunk in r.aiter_text(None):
                        item = decode(chunk)
                        if item is not skip:
                            yield item
            except Exception as e:
                yield self._stream_error(e)

        def ret_sync(response, as_json):
            if response is None:
                yield self._no_response_error()
                return
            try:
                with response as r:
                    for chunk in r.iter_text(None):
                        item = decode(chunk)
                        if item is not skip:
                            yield item
            except Exception as e:
                yield self._stream_error(e)

        if self._use_async:
            return ret_async(response, as_json)
        return ret_sync(response, as_json)

    def _get_response_value(
        self,
        response: httpx.Response,
        as_json: bool = False,
        value_func: Callable = None,
    ):
        '''
        Convert the response of a synchronous or asynchronous request.

        Args:
            response: the response (awaited first in async mode).
            as_json: when True the response is converted with ``.json()``
                before being handed to ``value_func``; if that fails a
                ``{"code": 500, "msg": ..., "data": None}`` dict is used
                instead.
            value_func: optional callable applied to the response or to the
                decoded JSON; defaults to the identity.

        Returns:
            The value produced by ``value_func`` (a coroutine producing it
            when ``self._use_async`` is set).
        '''
        def to_json(r):
            try:
                return r.json()
            except Exception as e:
                msg = "API未能返回正确的JSON。" + str(e)
                _logger.error("%s: %s", e.__class__.__name__, msg)
                return {"code": 500, "msg": msg, "data": None}

        if value_func is None:
            value_func = (lambda r: r)

        async def ret_async(response):
            if as_json:
                return value_func(to_json(await response))
            return value_func(await response)

        if self._use_async:
            return ret_async(response)
        if as_json:
            return value_func(to_json(response))
        return value_func(response)

    def knowledge_base_chat(
        self,
        query: str,
        knowledge_base_name: str,
        top_k: int = 1,
        score_threshold: float = 2,
        history: List[Dict] = None,
        stream: bool = True,
        model: str = "chatglm3-6b",
        temperature: float = 0.01,
        max_tokens: int = None,
        prompt_name: str = "default",
    ):
        '''
        Call the /chat/knowledge_base_chat endpoint of api.py.

        All arguments are sent in the JSON request body; ``model`` is sent
        under the key ``"model_name"`` and ``history`` defaults to an empty
        list. The HTTP response is always consumed as a stream, whatever the
        value of the ``stream`` field sent to the server.

        Returns:
            A generator of decoded JSON items (see
            ``_httpx_stream2generator``).
        '''
        # A fresh list per call instead of a shared mutable default argument.
        if history is None:
            history = []

        data = {
            "query": query,
            "knowledge_base_name": knowledge_base_name,
            "top_k": top_k,
            "score_threshold": score_threshold,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }
        response = self.post(
            "/chat/knowledge_base_chat",
            json=data,
            stream=True,
        )
        return self._httpx_stream2generator(response, as_json=True)
# class AsyncApiRequest(ApiRequest):
# def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
# super().__init__(base_url, timeout)
# self._use_async = True
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
    '''
    Extract the error message from an API result, if it carries one.

    Only dict results are inspected: the value stored under `key` wins when
    that key is present; otherwise, a dict whose "code" entry is present and
    differs from 200 yields its "msg" entry. Anything else gives "".
    '''
    if not isinstance(data, dict):
        return ""
    if key in data:
        return data[key]
    reports_failure = "code" in data and data["code"] != 200
    return data["msg"] if reports_failure else ""
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
    '''
    Extract the success message from an API result, if it carries one.

    The value stored under `key` is returned only when `data` is a dict that
    contains `key` and whose "code" entry equals 200. Anything else gives "".
    '''
    if not isinstance(data, dict):
        return ""
    succeeded = "code" in data and data["code"] == 200
    if succeeded and key in data:
        return data[key]
    return ""
if __name__ == "__main__":
api = ApiRequest()