323 lines
12 KiB
Python
323 lines
12 KiB
Python
|
|
import contextlib
|
|||
|
|
import inspect
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
import os
|
|||
|
|
from typing import *
|
|||
|
|
|
|||
|
|
from open_chatcaht._constants import API_BASE_URI
|
|||
|
|
from open_chatcaht.utils import set_httpx_config, get_httpx_client, get_variable, get_function_default_params, \
|
|||
|
|
merge_dicts
|
|||
|
|
from functools import wraps
|
|||
|
|
from typing import Type, get_type_hints
|
|||
|
|
|
|||
|
|
import httpx
|
|||
|
|
import requests
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
|
|||
|
|
# Runs at import time, before any client is built.
# Presumably installs process-wide httpx defaults — confirm in open_chatcaht.utils.
set_httpx_config()

# Module-level defaults. Each is resolved by get_variable() from the matching
# environment variable (None when unset) and a hard-coded fallback.
# NOTE(review): environment values are strings while the numeric fallbacks are
# ints — confirm that get_variable() or its callers normalise the type.
CHATCHAT_API_BASE = get_variable(
    os.environ.get('CHATCHAT_API_BASE'),
    'http://127.0.0.1:8000',
)
CHATCHAT_CLIENT_TIME_OUT = get_variable(
    os.environ.get('CHATCHAT_CLIENT_TIME_OUT'),
    60,
)
# The environment variable is named ..._RETRY, without the _COUNT suffix.
CHATCHAT_CLIENT_DEFAULT_RETRY_COUNT = get_variable(
    os.environ.get('CHATCHAT_CLIENT_DEFAULT_RETRY'),
    3,
)
CHATCHAT_CLIENT_DEFAULT_RETRY_INTERVAL = get_variable(
    os.environ.get('CHATCHAT_CLIENT_DEFAULT_RETRY_INTERVAL'),
    60,
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ApiClient:
    """
    Wrapper around the HTTP calls made to api.py that simplifies calling the API.

    The underlying httpx client is built lazily by ``get_httpx_client`` and is
    synchronous or asynchronous depending on ``use_async``.
    """

    def __init__(
        self,
        base_url: str = API_BASE_URI,
        timeout: float = 60,
        use_async: bool = False,
        use_proxy: bool = False,
        proxies=None,
        log_level: int = logging.INFO,
        retry: int = 3,
        retry_interval: int = 1,
    ):
        """
        Store the connection settings; no network activity happens here.

        Args:
            base_url: Base URL handed to the httpx client.
            timeout: Timeout handed to the httpx client.
            use_async: Build an asynchronous client instead of a synchronous one.
            use_proxy: Stored on the instance; not read inside this class.
            proxies: Stored on the instance (``{}`` when omitted); not read
                inside this class.
            log_level: Level applied to this module's logger.
            retry: Stored as ``default_retry_count``.
            retry_interval: Stored as ``default_retry_interval``.
        """
        if proxies is None:
            proxies = {}
        # Each setting is combined with its module-level default through
        # get_variable(); presumably the first usable value wins — confirm.
        self.base_url = get_variable(base_url, CHATCHAT_API_BASE)
        self.timeout = get_variable(timeout, CHATCHAT_CLIENT_TIME_OUT)
        self._use_async = use_async
        self.use_proxy = use_proxy
        self.default_retry_count = get_variable(retry, CHATCHAT_CLIENT_DEFAULT_RETRY_COUNT)
        self.default_retry_interval = get_variable(retry_interval, CHATCHAT_CLIENT_DEFAULT_RETRY_INTERVAL)
        self.proxies = proxies
        self._client = None
        # NOTE: this is the shared module logger, so the level set here applies
        # to every ApiClient instance.
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(log_level)

    @property
    def client(self):
        """Return the httpx client, creating it on first use or after it was closed."""
        if self._client is None or self._client.is_closed:
            self._client = get_httpx_client(
                base_url=self.base_url, use_async=self._use_async, timeout=self.timeout
            )
        return self._client

    def _request(
        self,
        method: str,
        url: str,
        retry: int,
        stream: bool,
        request_kwargs: Dict[str, Any],
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """
        Shared retry loop behind ``_get`` / ``_post`` / ``_delete``.

        Makes up to ``retry`` attempts. Any exception is logged and counts as a
        failed attempt; once the attempts are used up ``None`` is returned.
        With ``stream=True`` the un-entered ``client.stream(...)`` context
        manager is returned.

        ``request_kwargs`` is taken as a dict rather than ``**kwargs`` so that
        caller-supplied keys can never collide with this helper's parameters.
        """
        verb = method.lower()
        while retry > 0:
            try:
                if stream:
                    return self.client.stream(method, url, **request_kwargs)
                return self.client.request(method, url, **request_kwargs)
            except Exception as e:
                msg = f"error when {verb} {url}: {e}"
                self.logger.error(f"{e.__class__.__name__}: {msg}")
                retry -= 1
        return None

    def _get(
        self,
        url: str,
        params: Union[Dict, List[Tuple], bytes] = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """
        Send a GET request.

        Returns the client's result, a stream context manager when
        ``stream=True``, or ``None`` when every attempt raised.
        """
        return self._request("GET", url, retry, stream, dict(kwargs, params=params))

    def _post(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """
        Send a POST request with an optional form (``data``) or JSON body.

        Returns the client's result, a stream context manager when
        ``stream=True``, or ``None`` when every attempt raised.
        """
        if not stream:
            self.logger.debug(f"post {url} with data: {data}")
        return self._request("POST", url, retry, stream, dict(kwargs, data=data, json=json))

    def _delete(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """
        Send a DELETE request with an optional form (``data``) or JSON body.

        The request goes through the client's generic ``request`` API because
        httpx's ``delete()`` shortcut accepts neither ``data`` nor ``json``;
        calling it with a body raised ``TypeError`` on every attempt.

        Returns the client's result, a stream context manager when
        ``stream=True``, or ``None`` when every attempt raised.
        """
        return self._request("DELETE", url, retry, stream, dict(kwargs, data=data, json=json))

    @staticmethod
    def _sse_payload(chunk: str) -> Optional[str]:
        """
        Extract the JSON text carried by one streamed chunk.

        Returns ``None`` for SSE comment lines (leading ``:``), the text after
        the ``data: `` prefix for SSE data lines, and the chunk unchanged
        otherwise.
        """
        if chunk.startswith("data: "):
            # Drop only the line terminator. Slicing a fixed two characters off
            # the end destroys payload when a chunk is partial or ends in a
            # single newline.
            return chunk[6:].rstrip("\r\n")
        if chunk.startswith(":"):
            return None
        return chunk

    def _stream_error(self, exc: Exception) -> Dict[str, Any]:
        """Log a failure raised while streaming and build the error item to yield."""
        if isinstance(exc, httpx.ConnectError):
            msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({exc})"
            self.logger.error(msg)
        elif isinstance(exc, httpx.ReadTimeout):
            msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({exc})"
            self.logger.error(msg)
        else:
            msg = f"API通信遇到错误:{exc}"
            self.logger.error(f"{exc.__class__.__name__}: {msg}")
        return {"code": 500, "msg": msg}

    def _httpx_stream2generator(
        self,
        response: contextlib._GeneratorContextManager,
        as_json: bool = False,
    ):
        """
        Turn the context manager returned by ``client.stream`` into a plain generator.

        Returns an async generator when the client is asynchronous, otherwise a
        regular generator. With ``as_json=True`` each item is a decoded JSON
        value; text that does not parse yet is cached and retried together with
        the following chunk. Otherwise the raw text chunks are yielded.
        Failures are logged and yielded as ``{"code": 500, "msg": ...}``.
        """

        async def ret_async(response, as_json):
            try:
                async with response as r:
                    chunk_cache = ""
                    async for chunk in r.aiter_text(None):
                        if not chunk:  # fastchat api yields empty bytes on start and end
                            continue
                        if not as_json:
                            yield chunk
                            continue
                        payload = self._sse_payload(chunk)
                        if payload is None:  # skip sse comment line
                            continue
                        chunk_cache += payload
                        try:
                            data = json.loads(chunk_cache)
                        except Exception as e:
                            # Possibly an incomplete document: keep the cache
                            # and retry once the next chunk has been appended.
                            msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
                            self.logger.error(f"{e.__class__.__name__}: {msg}")
                            continue
                        chunk_cache = ""
                        yield data
            except Exception as e:
                yield self._stream_error(e)

        def ret_sync(response, as_json):
            try:
                with response as r:
                    chunk_cache = ""
                    for chunk in r.iter_text(None):
                        if not chunk:  # fastchat api yields empty bytes on start and end
                            continue
                        if not as_json:
                            yield chunk
                            continue
                        payload = self._sse_payload(chunk)
                        if payload is None:  # skip sse comment line
                            continue
                        chunk_cache += payload
                        try:
                            data = json.loads(chunk_cache)
                        except Exception as e:
                            # Possibly an incomplete document: keep the cache
                            # and retry once the next chunk has been appended.
                            msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
                            self.logger.error(f"{e.__class__.__name__}: {msg}")
                            continue
                        chunk_cache = ""
                        yield data
            except Exception as e:
                yield self._stream_error(e)

        if self._use_async:
            return ret_async(response, as_json)
        else:
            return ret_sync(response, as_json)

    def _get_response_value(
        self,
        response: httpx.Response,
        as_json: bool = False,
        value_func: Callable = None,
    ):
        """
        Convert the response returned by a synchronous or asynchronous request.

        Args:
            response: The response, or an awaitable producing it when the
                client is asynchronous.
            as_json: Decode the body as JSON before handing it to ``value_func``.
            value_func: Optional callable that customises the return value; it
                receives the response, or the decoded JSON when ``as_json`` is
                true. Defaults to the identity function.

        Returns:
            The value produced by ``value_func``; a coroutine resolving to it
            when the client is asynchronous.
        """

        def to_json(r):
            try:
                return r.json()
            except Exception as e:
                msg = "API未能返回正确的JSON。" + str(e)
                self.logger.error(f"{e.__class__.__name__}: {msg}")
                return {"code": 500, "msg": msg, "data": None}

        if value_func is None:
            value_func = lambda r: r

        async def ret_async(response):
            if as_json:
                return value_func(to_json(await response))
            else:
                return value_func(await response)

        if self._use_async:
            return ret_async(response)
        else:
            if as_json:
                return value_func(to_json(response))
            else:
                return value_func(response)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_request_method(api_client_obj: ApiClient, method):
    """
    Map a module-level httpx function to the matching ``ApiClient`` wrapper.

    ``httpx.post``, ``httpx.get`` and ``httpx.delete`` resolve to the client's
    bound ``_post``, ``_get`` and ``_delete`` methods. Any other callable
    (``httpx.put`` included) has no mapping and yields ``None``.
    """
    wrapper_names = (
        (httpx.post, "_post"),
        (httpx.get, "_get"),
        (httpx.delete, "_delete"),
    )
    for httpx_func, attr_name in wrapper_names:
        # Identity comparison: only the exact httpx functions are recognised.
        if method is httpx_func:
            return getattr(api_client_obj, attr_name)
    return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def http_request(method):
    """
    Build a decorator factory bound to one module-level httpx function.

    ``http_request(httpx.post)("/path")`` decorates a function whose own body
    is never executed: calling it sends the request and returns the decoded
    JSON response.

    Decorator arguments:
        url: Path appended to ``base_url``.
        base_url: Prefix for ``url`` (empty by default).
        headers: Optional request headers.
        body_model: Optional pydantic model; when given, the request
            parameters are ``body_model(**kwargs).dict()``.
        **options: Accepted but not used here.

    Call-time behaviour of the decorated function:
        * The parameters are the call's keyword arguments merged with the
          function's declared defaults (via ``merge_dicts``).
        * If the first positional argument is an ``ApiClient``, the request is
          sent through it; otherwise ``method`` is called directly.
        * GET sends the parameters as a query string (``None`` values are
          dropped); every other verb sends them as a JSON body. httpx's
          ``get``/``delete`` shortcuts reject ``json=``, so DELETE is sent via
          the generic ``request`` API when no ``ApiClient`` is used.
        * Failures are logged and ``None`` is returned; nothing is raised.
    """

    def decorator(url, base_url='', headers=None, body_model: Type[BaseModel] = None, **options):
        headers = headers or {}

        def wrapper(func):
            @wraps(func)
            def inner(*args, **kwargs):
                log = logging.getLogger(__name__)
                try:
                    default_param: dict = get_function_default_params(func)

                    api_client_obj: ApiClient = args[0] if args and isinstance(args[0], ApiClient) else None
                    full_url = base_url + url
                    param = merge_dicts(kwargs, default_param)
                    if body_model is not None:
                        param = body_model(**kwargs).dict()

                    # httpx's GET helpers take no request body.
                    if method is httpx.get:
                        payload = {"params": {k: v for k, v in param.items() if v is not None}}
                    else:
                        payload = {"json": param}

                    # Send the HTTP request
                    if api_client_obj is not None:
                        _method = get_request_method(api_client_obj, method)
                        if _method is None:
                            # No dedicated wrapper (e.g. PUT): use the client's
                            # generic request API with the verb taken from the
                            # httpx function's name.
                            response = api_client_obj.client.request(
                                method.__name__.upper(), full_url, headers=headers, **payload
                            )
                        else:
                            response = _method(full_url, headers=headers, **payload)
                    elif method is httpx.delete:
                        # httpx.delete() does not accept a body.
                        response = httpx.request("DELETE", full_url, headers=headers, **payload)
                    else:
                        response = method(full_url, headers=headers, **payload)

                    if response is None:
                        # The ApiClient wrappers return None once their retries
                        # are exhausted.
                        log.error(f"An error occurred: no response received from {full_url}")
                        return None

                    response.raise_for_status()
                    return response.json()
                except (httpx.HTTPStatusError, requests.exceptions.HTTPError) as http_err:
                    # httpx raises HTTPStatusError from raise_for_status(); the
                    # requests variant is kept for backward compatibility.
                    log.error(f"HTTP error occurred: {http_err}")
                except Exception as err:
                    log.error(f"An error occurred: {err}")
                return None

            return inner

        return wrapper

    return decorator
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Ready-made decorator factories, one per HTTP verb, created in the order
# post, get, delete, put.
post, get, delete, put = map(
    http_request,
    (httpx.post, httpx.get, httpx.delete, httpx.put),
)
|