323 lines
12 KiB
Python
323 lines
12 KiB
Python
import contextlib
|
||
import inspect
|
||
import json
|
||
import logging
|
||
import os
|
||
from typing import *
|
||
|
||
from open_chatcaht._constants import API_BASE_URI
|
||
from open_chatcaht.utils import set_httpx_config, get_httpx_client, get_variable, get_function_default_params, \
|
||
merge_dicts
|
||
from functools import wraps
|
||
from typing import Type, get_type_hints
|
||
|
||
import httpx
|
||
import requests
|
||
from pydantic import BaseModel
|
||
|
||
set_httpx_config()


def _env_number(name: str, default, cast):
    """Read a numeric client setting from the environment variable ``name``.

    Environment variables are always strings, so the value returned by
    ``get_variable`` is converted with ``cast`` (e.g. ``int``/``float``) before
    it is used in numeric comparisons or handed to httpx. A value that cannot
    be converted is logged and replaced by ``default``.
    """
    value = get_variable(os.getenv(name), default)
    try:
        return cast(value)
    except (TypeError, ValueError):
        logging.getLogger(__name__).warning(
            "Invalid value %r for environment variable %s; using %r", value, name, default
        )
        return default


# Server root used when ApiClient gets no base_url.
CHATCHAT_API_BASE = get_variable(os.getenv('CHATCHAT_API_BASE'), 'http://127.0.0.1:8000')
# Request timeout (seconds) used when ApiClient gets timeout=None.
CHATCHAT_CLIENT_TIME_OUT = _env_number('CHATCHAT_CLIENT_TIME_OUT', 60, float)
# Retry attempts used when ApiClient gets retry=None.
# NOTE: the environment variable is CHATCHAT_CLIENT_DEFAULT_RETRY (no "_COUNT" suffix).
CHATCHAT_CLIENT_DEFAULT_RETRY_COUNT = _env_number('CHATCHAT_CLIENT_DEFAULT_RETRY', 3, int)
# Seconds between retries used when ApiClient gets retry_interval=None.
CHATCHAT_CLIENT_DEFAULT_RETRY_INTERVAL = _env_number('CHATCHAT_CLIENT_DEFAULT_RETRY_INTERVAL', 60, float)
|
||
|
||
|
||
class ApiClient:
    """
    Wrapper around the api.py HTTP endpoints that simplifies calling them
    (synchronous mode by default).

    Requests go through a lazily created httpx client (see ``client``). When
    ``use_async`` is true that client is built with ``use_async=True``
    (presumably an ``httpx.AsyncClient``), and the request helpers then return
    awaitables / async context managers instead of responses.
    """

    def __init__(
            self,
            base_url: str = API_BASE_URI,
            timeout: float = 60,
            use_async: bool = False,
            use_proxy: bool = False,
            proxies=None,
            log_level: int = logging.INFO,
            retry: int = 3,
            retry_interval: int = 1,
    ):
        """
        Args:
            base_url: server root; passed through ``get_variable`` with
                ``CHATCHAT_API_BASE`` (presumably the fallback when None).
            timeout: client timeout; ``get_variable`` fallback is
                ``CHATCHAT_CLIENT_TIME_OUT``.
            use_async: build an async httpx client instead of a sync one.
            use_proxy: stored on the instance; not used by this class.
            proxies: stored on the instance (``{}`` when None); not used by this class.
            log_level: level set on this module's logger.
            retry: default number of attempts for ``_get``/``_post``/``_delete``;
                ``get_variable`` fallback is ``CHATCHAT_CLIENT_DEFAULT_RETRY_COUNT``.
            retry_interval: seconds to wait between failed sync attempts;
                ``get_variable`` fallback is ``CHATCHAT_CLIENT_DEFAULT_RETRY_INTERVAL``.
        """
        if proxies is None:
            proxies = {}
        self.base_url = get_variable(base_url, CHATCHAT_API_BASE)
        self.timeout = get_variable(timeout, CHATCHAT_CLIENT_TIME_OUT)
        self._use_async = use_async
        self.use_proxy = use_proxy
        self.default_retry_count = get_variable(retry, CHATCHAT_CLIENT_DEFAULT_RETRY_COUNT)
        self.default_retry_interval = get_variable(retry_interval, CHATCHAT_CLIENT_DEFAULT_RETRY_INTERVAL)
        self.proxies = proxies
        self._client = None  # created on first access of ``client``
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(log_level)

    @property
    def client(self):
        """The httpx client, (re)created when missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = get_httpx_client(
                base_url=self.base_url, use_async=self._use_async, timeout=self.timeout
            )
        return self._client

    def _request_with_retry(
            self,
            action: str,
            url: str,
            send: Callable[[], Any],
            retry: Optional[int],
    ):
        """Call ``send`` until it returns without raising or attempts run out.

        Args:
            action: verb used in the error log message ("get", "post", ...).
            url: request url, used only in the error log message.
            send: zero-argument callable performing one request attempt.
            retry: number of attempts; None means ``self.default_retry_count``.

        Returns:
            Whatever ``send`` returns, or None when every attempt raised.
        """
        import time

        attempts = self.default_retry_count if retry is None else retry
        while attempts > 0:
            try:
                return send()
            except Exception as e:
                msg = f"error when {action} {url}: {e}"
                self.logger.error(f"{e.__class__.__name__}: {msg}")
                attempts -= 1
                interval = self.default_retry_interval
                # Wait before the next attempt. Skipped in async mode, where a
                # blocking sleep would stall the event loop.
                if (attempts > 0 and not self._use_async
                        and isinstance(interval, (int, float)) and interval > 0):
                    time.sleep(interval)
        return None

    def _get(
            self,
            url: str,
            params: Union[Dict, List[Tuple], bytes] = None,
            retry: Optional[int] = None,
            stream: bool = False,
            **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """Send a GET request through ``self.client``, retrying when it raises.

        ``retry`` is the number of attempts (None: ``self.default_retry_count``).
        Returns the client's result (a response for a sync client), the
        ``client.stream`` context manager when ``stream`` is true, or None
        after all attempts failed. Errors raised later, when the stream is
        entered or an async result awaited, are not retried here.
        """
        def send():
            if stream:
                return self.client.stream("GET", url, params=params, **kwargs)
            return self.client.get(url, params=params, **kwargs)

        return self._request_with_retry("get", url, send, retry)

    def _post(
            self,
            url: str,
            data: Dict = None,
            json: Dict = None,
            retry: Optional[int] = None,
            stream: bool = False,
            **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """Send a POST request with form ``data`` and/or a ``json`` body.

        Retry and return semantics are the same as ``_get``.
        """
        def send():
            if stream:
                return self.client.stream("POST", url, data=data, json=json, **kwargs)
            self.logger.debug(f"post {url} with data: {data}")
            return self.client.post(url, data=data, json=json, **kwargs)

        return self._request_with_retry("post", url, send, retry)

    def _delete(
            self,
            url: str,
            data: Dict = None,
            json: Dict = None,
            retry: Optional[int] = None,
            stream: bool = False,
            **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """Send a DELETE request with form ``data`` and/or a ``json`` body.

        httpx's ``delete()`` shortcut accepts no request body, so the generic
        ``request("DELETE", ...)`` is used instead. Retry and return semantics
        are the same as ``_get``.
        """
        def send():
            if stream:
                return self.client.stream("DELETE", url, data=data, json=json, **kwargs)
            return self.client.request("DELETE", url, data=data, json=json, **kwargs)

        return self._request_with_retry("delete", url, send, retry)

    def _parse_stream_chunk(self, chunk: str, chunk_cache: str) -> Tuple[bool, Any, str]:
        """Decode one non-empty text chunk of a JSON / server-sent-events stream.

        ``chunk`` is an SSE line (``data: <json>`` or a ``:`` comment) or raw
        JSON text. A JSON document split across chunks is accumulated in
        ``chunk_cache`` until it parses.

        Returns:
            ``(has_data, data, chunk_cache)``. ``has_data`` is False for comment
            lines and for fragments that do not parse yet. The returned cache
            must be passed in with the next chunk.
        """
        if chunk.startswith(":"):  # skip sse comment line
            return False, None, chunk_cache
        if chunk.startswith("data: "):
            # Drop the SSE prefix and trailing line breaks. rstrip (instead of a
            # fixed [6:-2] slice) keeps the tail of a fragment that was cut off
            # before its terminating blank line.
            payload = chunk[6:].rstrip("\r\n")
        else:
            payload = chunk
        try:
            data = json.loads(chunk_cache + payload)
        except ValueError as e:
            msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
            self.logger.error(f"{e.__class__.__name__}: {msg}")
            return False, None, chunk_cache + payload
        return True, data, ""

    def _stream_error_payload(self, e: Exception) -> Dict[str, Any]:
        """Log a streaming failure and build the error item yielded to the consumer."""
        if isinstance(e, httpx.ConnectError):
            msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})"
            self.logger.error(msg)
        elif isinstance(e, httpx.ReadTimeout):
            msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({e})"
            self.logger.error(msg)
        else:
            msg = f"API通信遇到错误:{e}"
            self.logger.error(f"{e.__class__.__name__}: {msg}")
        return {"code": 500, "msg": msg}

    def _httpx_stream2generator(
            self,
            response: contextlib._GeneratorContextManager,
            as_json: bool = False,
    ):
        """
        Convert the context manager returned by ``client.stream`` into a plain
        generator (an async generator when the client is async).

        Yields raw text chunks, or decoded JSON objects when ``as_json`` is true.
        Any exception ends the stream with a final
        ``{"code": 500, "msg": ...}`` item.
        """

        async def ret_async(response, as_json):
            try:
                async with response as r:
                    chunk_cache = ""
                    async for chunk in r.aiter_text(None):
                        if not chunk:  # fastchat api yield empty bytes on start and end
                            continue
                        if not as_json:
                            yield chunk
                            continue
                        has_data, data, chunk_cache = self._parse_stream_chunk(chunk, chunk_cache)
                        if has_data:
                            yield data
            except Exception as e:
                yield self._stream_error_payload(e)

        def ret_sync(response, as_json):
            try:
                with response as r:
                    chunk_cache = ""
                    for chunk in r.iter_text(None):
                        if not chunk:  # fastchat api yield empty bytes on start and end
                            continue
                        if not as_json:
                            yield chunk
                            continue
                        has_data, data, chunk_cache = self._parse_stream_chunk(chunk, chunk_cache)
                        if has_data:
                            yield data
            except Exception as e:
                yield self._stream_error_payload(e)

        if self._use_async:
            return ret_async(response, as_json)
        else:
            return ret_sync(response, as_json)

    def _get_response_value(
            self,
            response: httpx.Response,
            as_json: bool = False,
            value_func: Callable = None,
    ):
        """
        Convert the response of a sync or async request.

        `as_json`: decode the body as JSON first; a decoding failure becomes
            ``{"code": 500, "msg": ..., "data": None}``.
        `value_func`: user hook that customizes the return value; it receives
            the response, or the decoded JSON when ``as_json`` is true.

        In async mode ``response`` is awaited and a coroutine is returned.
        """

        def to_json(r):
            try:
                return r.json()
            except Exception as e:
                msg = "API未能返回正确的JSON。" + str(e)
                self.logger.error(f"{e.__class__.__name__}: {msg}")
                return {"code": 500, "msg": msg, "data": None}

        if value_func is None:
            value_func = lambda r: r

        async def ret_async(response):
            if as_json:
                return value_func(to_json(await response))
            else:
                return value_func(await response)

        if self._use_async:
            return ret_async(response)
        else:
            if as_json:
                return value_func(to_json(response))
            else:
                return value_func(response)
|
||
|
||
|
||
def get_request_method(api_client_obj: ApiClient, method):
    """Return the retrying request helper of ``api_client_obj`` for ``method``.

    ``httpx.post`` maps to ``_post``, ``httpx.get`` to ``_get`` and
    ``httpx.delete`` to ``_delete``. Any other function (``httpx.put``
    included) has no helper, and None is returned.
    """
    routes = (
        (httpx.post, "_post"),
        (httpx.get, "_get"),
        (httpx.delete, "_delete"),
    )
    for http_func, helper_name in routes:
        if method is http_func:
            return getattr(api_client_obj, helper_name)
    return None
|
||
|
||
|
||
def http_request(method):
    """Build a decorator factory that turns a function stub into an HTTP call.

    ``http_request(httpx.post)("/path", ...)`` returns a decorator. The
    decorated function's body is never executed. Its keyword arguments are
    combined with its default parameter values via ``merge_dicts``; when
    ``body_model`` is given, the keyword arguments are validated through it
    instead. The result is sent as query parameters for GET and as a JSON body
    otherwise.

    When the first positional argument is an ``ApiClient``, the request goes
    through that client (its retrying helper when one exists for ``method``).
    Otherwise ``method`` is called directly with the full url.

    The decorated function returns the decoded JSON response, or None when the
    request fails (the error is logged).
    """
    logger = logging.getLogger(__name__)

    def decorator(url, base_url='', headers=None, body_model: Type[BaseModel] = None, **options):
        # ``options`` is accepted but currently unused.
        headers = headers or {}

        def _send(api_client_obj, full_url, param):
            """Dispatch one request and return the response (None if retries ran out)."""
            # httpx's get()/delete() shortcuts accept no request body: GET sends
            # its payload as query params, and DELETE uses the generic request().
            is_get = method is httpx.get
            if api_client_obj is not None:
                client_method = get_request_method(api_client_obj, method)
                if client_method is not None:
                    if is_get:
                        return client_method(full_url, headers=headers, params=param)
                    return client_method(full_url, headers=headers, json=param)
                # No retrying helper for this verb (e.g. httpx.put): use the client directly.
                verb = getattr(method, "__name__", "").upper()
                if not verb:
                    raise NotImplementedError(f"unsupported request method: {method!r}")
                return api_client_obj.client.request(verb, full_url, headers=headers, json=param)
            if is_get:
                return method(full_url, headers=headers, params=param)
            if method is httpx.delete:
                return httpx.request("DELETE", full_url, headers=headers, json=param)
            return method(full_url, headers=headers, json=param)

        def wrapper(func):
            @wraps(func)
            def inner(*args, **kwargs):
                # NOTE(review): only keyword arguments are forwarded; positional
                # arguments other than a leading ApiClient are ignored.
                # NOTE(review): with an async ApiClient the helpers return
                # coroutines, which this synchronous wrapper cannot await.
                try:
                    default_param: dict = get_function_default_params(func)
                    api_client_obj = args[0] if args and isinstance(args[0], ApiClient) else None
                    full_url = base_url + url
                    param = merge_dicts(kwargs, default_param)
                    if body_model is not None:
                        param = body_model(**kwargs).dict()
                    response = _send(api_client_obj, full_url, param)
                    if response is None:
                        # ApiClient helpers return None once every retry has failed.
                        logger.error(f"An error occurred: no response from {full_url}")
                        return None
                    response.raise_for_status()
                    return response.json()
                except (httpx.HTTPStatusError, requests.exceptions.HTTPError) as http_err:
                    # httpx responses raise HTTPStatusError; requests' HTTPError is
                    # kept for callers passing a requests function as ``method``.
                    logger.error(f"HTTP error occurred: {http_err}")
                except Exception as err:
                    logger.error(f"An error occurred: {err}")
                return None

            return inner

        return wrapper

    return decorator
|
||
|
||
|
||
# Ready-made decorator factories, one per HTTP verb, e.g. ``@post("/path")``.
post, get, delete, put = (
    http_request(http_func)
    for http_func in (httpx.post, httpx.get, httpx.delete, httpx.put)
)
|