Langchain-Chatchat/libs/python-sdk/open_chatcaht/api_client.py

323 lines
12 KiB
Python
Raw Normal View History

2024-12-20 16:04:03 +08:00
import contextlib
import inspect
import json
import logging
import os
from typing import *
from open_chatcaht._constants import API_BASE_URI
from open_chatcaht.utils import set_httpx_config, get_httpx_client, get_variable, get_function_default_params, \
merge_dicts
from functools import wraps
from typing import Type, get_type_hints
import httpx
import requests
from pydantic import BaseModel
# Runs set_httpx_config() as an import-time side effect; see open_chatcaht.utils
# for what it actually configures.
set_httpx_config()
# Module-level defaults for ApiClient. Each one is built from an environment
# variable plus a hard-coded fallback.
# NOTE(review): presumably get_variable returns its first argument when it is
# set and the second otherwise - confirm in open_chatcaht.utils.
# NOTE(review): os.getenv returns str, so a value supplied through the
# environment is a string while the fallbacks below are ints - confirm that
# consumers coerce the type.
CHATCHAT_API_BASE = get_variable(os.getenv('CHATCHAT_API_BASE'), 'http://127.0.0.1:8000')
CHATCHAT_CLIENT_TIME_OUT = get_variable(os.getenv('CHATCHAT_CLIENT_TIME_OUT'), 60)
# The environment variable read here is CHATCHAT_CLIENT_DEFAULT_RETRY (no _COUNT suffix).
CHATCHAT_CLIENT_DEFAULT_RETRY_COUNT = get_variable(os.getenv('CHATCHAT_CLIENT_DEFAULT_RETRY'), 3)
CHATCHAT_CLIENT_DEFAULT_RETRY_INTERVAL = get_variable(os.getenv('CHATCHAT_CLIENT_DEFAULT_RETRY_INTERVAL'), 60)
class ApiClient:
    """
    Wrapper around an httpx client that simplifies calling the api.py endpoints.

    The client is synchronous by default. When ``use_async`` is true the flag
    is forwarded to ``get_httpx_client`` and the response helpers
    (``_httpx_stream2generator`` / ``_get_response_value``) return an async
    generator / coroutine instead of a plain generator / value.
    """

    def __init__(
        self,
        base_url: str = API_BASE_URI,
        timeout: float = 60,
        use_async: bool = False,
        use_proxy: bool = False,
        proxies=None,
        log_level: int = logging.INFO,
        retry: int = 3,
        retry_interval: int = 1,
    ):
        """
        Store the connection settings; the httpx client itself is created lazily.

        Args:
            base_url: Base URL handed to the httpx client (resolved through
                ``get_variable`` with ``CHATCHAT_API_BASE`` as the alternative).
            timeout: Request timeout handed to the httpx client (resolved through
                ``get_variable`` with ``CHATCHAT_CLIENT_TIME_OUT``).
            use_async: Create the client in async mode.
            use_proxy: Stored on the instance; not read elsewhere in this class.
            proxies: Stored on the instance (``{}`` when None); not read
                elsewhere in this class.
            log_level: Level applied to this module's logger.
            retry: Stored as ``default_retry_count``.
            retry_interval: Stored as ``default_retry_interval``.
        """
        if proxies is None:
            proxies = {}
        self.base_url = get_variable(base_url, CHATCHAT_API_BASE)
        self.timeout = get_variable(timeout, CHATCHAT_CLIENT_TIME_OUT)
        self._use_async = use_async
        self.use_proxy = use_proxy
        self.default_retry_count = get_variable(retry, CHATCHAT_CLIENT_DEFAULT_RETRY_COUNT)
        self.default_retry_interval = get_variable(retry_interval, CHATCHAT_CLIENT_DEFAULT_RETRY_INTERVAL)
        self.proxies = proxies
        self._client = None
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(log_level)

    @property
    def client(self):
        """Return the httpx client, (re)creating it when missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = get_httpx_client(
                base_url=self.base_url, use_async=self._use_async, timeout=self.timeout
            )
        return self._client

    def _get(
        self,
        url: str,
        params: Union[Dict, List[Tuple], bytes] = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """
        Issue a GET request, retrying up to ``retry`` times on any exception.

        Returns the httpx response (or the ``client.stream`` context manager
        when ``stream`` is true), or None once every attempt has failed. Each
        failure is logged at error level.
        """
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("GET", url, params=params, **kwargs)
                else:
                    return self.client.get(url, params=params, **kwargs)
            except Exception as e:
                msg = f"error when get {url}: {e}"
                self.logger.error(f"{e.__class__.__name__}: {msg}")
                retry -= 1
        return None

    def _post(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """
        Issue a POST request, retrying up to ``retry`` times on any exception.

        Returns the httpx response (or the ``client.stream`` context manager
        when ``stream`` is true), or None once every attempt has failed. Each
        failure is logged at error level.
        """
        while retry > 0:
            try:
                if stream:
                    return self.client.stream(
                        "POST", url, data=data, json=json, **kwargs
                    )
                else:
                    self.logger.debug(f"post {url} with data: {data}")
                    return self.client.post(url, data=data, json=json, **kwargs)
            except Exception as e:
                msg = f"error when post {url}: {e}"
                self.logger.error(f"{e.__class__.__name__}: {msg}")
                retry -= 1
        return None

    def _delete(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """
        Issue a DELETE request, retrying up to ``retry`` times on any exception.

        Returns the httpx response (or the ``client.stream`` context manager
        when ``stream`` is true), or None once every attempt has failed. Each
        failure is logged at error level.
        """
        while retry > 0:
            try:
                if stream:
                    return self.client.stream(
                        "DELETE", url, data=data, json=json, **kwargs
                    )
                else:
                    # The httpx ``delete()`` shortcut takes no ``data``/``json``
                    # arguments, so passing them raised TypeError on every
                    # attempt. ``request()`` accepts a body for any verb.
                    return self.client.request(
                        "DELETE", url, data=data, json=json, **kwargs
                    )
            except Exception as e:
                msg = f"error when delete {url}: {e}"
                self.logger.error(f"{e.__class__.__name__}: {msg}")
                retry -= 1
        return None

    def _decode_json_chunk(self, chunk: str, chunk_cache: str):
        """
        Try to decode one streamed text chunk as JSON.

        Returns ``(ready, data, chunk_cache)``: ``ready`` tells whether ``data``
        holds a decoded object, and ``chunk_cache`` is the text to prepend to
        the next chunk (emptied after a successful decode, extended after a
        failed one so that split JSON documents can be reassembled).
        """
        if chunk.startswith(":"):  # skip sse comment line
            return False, None, chunk_cache
        # Drop the "data: " prefix and the last two characters of an SSE data chunk.
        payload = chunk[6:-2] if chunk.startswith("data: ") else chunk
        try:
            return True, json.loads(chunk_cache + payload), ""
        except Exception as e:
            msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
            self.logger.error(f"{e.__class__.__name__}: {msg}")
            return False, None, chunk_cache + payload

    def _stream_error(self, exc: Exception) -> Dict:
        """Log a streaming failure and return the ``{"code": 500, "msg": ...}`` item to yield."""
        if isinstance(exc, httpx.ConnectError):
            msg = f"无法连接API服务器请确认 api.py 已正常启动。({exc})"
            self.logger.error(msg)
        elif isinstance(exc, httpx.ReadTimeout):
            msg = f"API通信超时请确认已启动FastChat与API服务详见Wiki '5. 启动 API 服务或 Web UI')。({exc}"
            self.logger.error(msg)
        else:
            msg = f"API通信遇到错误{exc}"
            self.logger.error(f"{exc.__class__.__name__}: {msg}")
        return {"code": 500, "msg": msg}

    def _httpx_stream2generator(
        self,
        response: contextlib._GeneratorContextManager,
        as_json: bool = False,
    ):
        """
        Convert the context manager returned by ``httpx`` ``stream()`` into a
        plain generator (an async generator when the client is in async mode).

        With ``as_json`` false the raw text chunks are yielded. With ``as_json``
        true each chunk is decoded as JSON (SSE ``data: `` prefixes are removed,
        SSE comment lines are skipped, undecodable text is buffered and retried
        together with the next chunk). Communication errors do not propagate;
        they are logged and yielded as ``{"code": 500, "msg": ...}``.
        """

        async def ret_async(response, as_json):
            try:
                async with response as r:
                    chunk_cache = ""
                    async for chunk in r.aiter_text(None):
                        if not chunk:  # fastchat api yield empty bytes on start and end
                            continue
                        if not as_json:
                            yield chunk
                            continue
                        ready, data, chunk_cache = self._decode_json_chunk(chunk, chunk_cache)
                        if ready:
                            yield data
            except Exception as e:
                yield self._stream_error(e)

        def ret_sync(response, as_json):
            try:
                with response as r:
                    chunk_cache = ""
                    for chunk in r.iter_text(None):
                        if not chunk:  # fastchat api yield empty bytes on start and end
                            continue
                        if not as_json:
                            yield chunk
                            continue
                        ready, data, chunk_cache = self._decode_json_chunk(chunk, chunk_cache)
                        if ready:
                            yield data
            except Exception as e:
                yield self._stream_error(e)

        if self._use_async:
            return ret_async(response, as_json)
        else:
            return ret_sync(response, as_json)

    def _get_response_value(
        self,
        response: httpx.Response,
        as_json: bool = False,
        value_func: Callable = None,
    ):
        """
        Convert the response of a synchronous or asynchronous request.

        `as_json`: decode the response as JSON first; when decoding fails a
        ``{"code": 500, "msg": ..., "data": None}`` dict is used instead.
        `value_func`: optional callable producing the final return value; it
        receives the response, or the decoded JSON when `as_json` is true.

        In async mode `response` is awaited and a coroutine is returned.
        """

        def to_json(r):
            try:
                return r.json()
            except Exception as e:
                msg = "API未能返回正确的JSON。" + str(e)
                self.logger.error(f"{e.__class__.__name__}: {msg}")
                return {"code": 500, "msg": msg, "data": None}

        if value_func is None:
            value_func = lambda r: r

        async def ret_async(response):
            if as_json:
                return value_func(to_json(await response))
            else:
                return value_func(await response)

        if self._use_async:
            return ret_async(response)
        else:
            if as_json:
                return value_func(to_json(response))
            else:
                return value_func(response)
def get_request_method(api_client_obj: ApiClient, method):
    """
    Map a module-level httpx helper to the ApiClient method for the same verb.

    Returns the bound ``_post``, ``_get`` or ``_delete`` method of
    ``api_client_obj`` for ``httpx.post``, ``httpx.get`` and ``httpx.delete``
    respectively, and None for any other ``method`` (``httpx.put`` included).
    """
    verb_table = (
        (httpx.post, "_post"),
        (httpx.get, "_get"),
        (httpx.delete, "_delete"),
    )
    # Identity comparison, first match wins.
    for httpx_helper, attr_name in verb_table:
        if method is httpx_helper:
            return getattr(api_client_obj, attr_name)
    return None
def http_request(method):
    """
    Build a decorator factory that turns a stub function into an HTTP call.

    ``method`` is normally one of the module-level httpx helpers
    (``httpx.get``, ``httpx.post``, ``httpx.put``, ``httpx.delete``); its name
    selects the HTTP verb. Any other callable is invoked directly as
    ``method(url, headers=..., json=...)``.

    The returned ``decorator(url, base_url='', headers=None, body_model=None,
    **options)`` produces a function decorator. The decorated function's own
    body is never executed. Calling it:

    * builds the parameters from the keyword arguments combined (via
      ``merge_dicts``) with ``get_function_default_params(func)``, or from
      ``body_model(**kwargs).dict()`` when a ``body_model`` is given;
    * sends them to ``base_url + url`` - through the ``ApiClient`` passed as
      first positional argument when there is one, otherwise through httpx
      directly;
    * returns the decoded JSON body, or None when the request fails (the
      error is printed, never raised).

    ``options`` is accepted for compatibility and not used.
    """
    verb = getattr(method, "__name__", "").upper()
    # True when `method` is the httpx helper named after its verb (httpx.get, ...).
    is_httpx_helper = getattr(httpx, verb.lower(), None) is method

    def decorator(url, base_url='', headers=None, body_model: Type[BaseModel] = None, **options):
        headers = headers or {}

        def wrapper(func):
            @wraps(func)
            def inner(*args, **kwargs):
                try:
                    default_param: dict = get_function_default_params(func)
                    api_client_obj = args[0] if args and isinstance(args[0], ApiClient) else None
                    full_url = base_url + url
                    param = merge_dicts(kwargs, default_param)
                    if body_model is not None:
                        param = body_model(**kwargs).dict()

                    if is_httpx_helper and verb == "GET":
                        # httpx's get() takes no `json` argument (it raised
                        # TypeError), so GET parameters travel in the query
                        # string. None values are dropped because httpx would
                        # otherwise send them as empty strings.
                        payload = {"params": {k: v for k, v in param.items() if v is not None}}
                    else:
                        payload = {"json": param}

                    # Send the HTTP request
                    client_method = (
                        get_request_method(api_client_obj, method)
                        if api_client_obj is not None
                        else None
                    )
                    if client_method is not None:
                        response = client_method(full_url, headers=headers, **payload)
                    elif api_client_obj is not None and is_httpx_helper:
                        # Verb without a dedicated ApiClient method (e.g. PUT):
                        # still go through the client so its base_url/timeout apply.
                        response = api_client_obj.client.request(
                            verb, full_url, headers=headers, **payload
                        )
                    elif is_httpx_helper:
                        # httpx.request() accepts a JSON body for every verb,
                        # unlike the httpx.get()/httpx.delete() shortcuts.
                        response = httpx.request(verb, full_url, headers=headers, **payload)
                    else:
                        response = method(full_url, headers=headers, **payload)

                    if response is None:
                        # The ApiClient methods return None once their retries are exhausted.
                        print(f"An error occurred: no response received from {full_url}")
                        return None
                    response.raise_for_status()
                    return response.json()
                except httpx.HTTPStatusError as http_err:
                    # raise_for_status() is httpx's and raises this, not requests' HTTPError.
                    print(f"HTTP error occurred: {http_err}")
                except Exception as err:
                    print(f"An error occurred: {err}")
                return None

            return inner

        return wrapper

    return decorator
# Decorator factories, one per HTTP verb, each built by passing the matching
# module-level httpx function to http_request.
post = http_request(httpx.post)
get = http_request(httpx.get)
delete = http_request(httpx.delete)
put = http_request(httpx.put)