254 lines
7.0 KiB
Python
254 lines
7.0 KiB
Python
import base64
|
||
import inspect
|
||
import os
|
||
from io import BytesIO
|
||
from pathlib import Path
|
||
from typing import Union, List, Dict
|
||
|
||
import httpx
|
||
from pydantic import BaseModel
|
||
from typing_extensions import TypeGuard
|
||
|
||
from open_chatcaht._constants import HTTPX_TIMEOUT
|
||
|
||
|
||
def get_httpx_client(
    use_async: bool = False,
    proxies: Union[str, Dict] = None,
    timeout: float = HTTPX_TIMEOUT,
    unused_proxies: Union[List[str], None] = None,
    **kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
    """
    Build an httpx client whose proxy map bypasses local addresses.

    The proxy map is assembled in the following order, later entries
    overriding earlier ones on key collisions:

    1. bypass entries for ``127.0.0.1`` and ``localhost``;
    2. bypass entries derived from ``unused_proxies``;
    3. the ``http_proxy`` / ``https_proxy`` / ``all_proxy`` environment
       variables (unset or blank values become ``None``);
    4. bypass entries for every host listed in the ``no_proxy``
       environment variable;
    5. the caller supplied ``proxies``.

    Args:
        use_async: when true an ``httpx.AsyncClient`` is returned, otherwise
            an ``httpx.Client``.
        proxies: either a single proxy URL (mapped to ``"all://"``) or a
            dict that is merged on top of the defaults.
        timeout: timeout handed to the client constructor.
        unused_proxies: server addresses that must not go through a proxy.
            Only the first two colon-separated parts of each entry are kept,
            so ``"http://host:8000"`` becomes ``"http://host"``.
        **kwargs: extra keyword arguments forwarded to the client
            constructor; ``timeout`` and ``proxies`` are overwritten.

    Returns:
        The constructed httpx client.
    """

    def _env_proxy(name: str):
        # Look the variable up once; unset, empty and whitespace-only values
        # all mean "no proxy configured".
        value = os.environ.get(name)
        return value if value and value.strip() else None

    default_proxies = {
        # do not use proxy for localhost
        "all://127.0.0.1": None,
        "all://localhost": None,
    }

    # do not use proxy for user deployed fastchat servers
    for address in unused_proxies or []:
        host = ":".join(address.split(":")[:2])
        default_proxies[host] = None

    # get proxies from the system environment
    default_proxies.update(
        {
            "http://": _env_proxy("http_proxy"),
            "https://": _env_proxy("https_proxy"),
            "all://": _env_proxy("all_proxy"),
        }
    )
    for host in os.environ.get("no_proxy", "").split(","):
        if host := host.strip():
            # PR 1838 fix: httpx raises an error if the key has no
            # "all://" prefix.
            default_proxies["all://" + host] = None

    # merge default proxies with user provided proxies
    if isinstance(proxies, str):
        proxies = {"all://": proxies}
    if isinstance(proxies, dict):
        default_proxies.update(proxies)

    # construct the client
    # NOTE(review): recent httpx releases appear to have dropped the
    # ``proxies`` keyword in favour of ``proxy`` / ``mounts`` -- confirm the
    # pinned httpx version before upgrading.
    kwargs.update(timeout=timeout, proxies=default_proxies)

    if use_async:
        return httpx.AsyncClient(**kwargs)
    return httpx.Client(**kwargs)
|
||
|
||
|
||
def set_httpx_config(
    timeout: float = HTTPX_TIMEOUT,
    proxy: Union[str, Dict] = None,
    unused_proxies: Union[List[str], None] = None,
):
    """
    Raise httpx's default timeouts and configure process-wide proxy settings.

    Translated from the original note: httpx's default timeout is 5 seconds,
    which is not enough when waiting for an LLM answer. Services belonging to
    this project are added to the no-proxy list to avoid request errors
    against fastchat servers (reported as having no effect on Windows).
    Online APIs such as ChatGPT need their proxy configured manually; how to
    handle proxies for search engines is still undecided.

    Side effects:
        * overwrites the connect/read/write values of httpx's default
          timeout configuration;
        * sets the ``http_proxy`` / ``https_proxy`` / ``all_proxy``
          environment variables derived from ``proxy``;
        * sets the ``NO_PROXY`` environment variable;
        * replaces ``urllib.request.getproxies`` with a function returning
          the proxy dict built here.

    Args:
        timeout: value applied to httpx's default connect, read and write
            timeouts.
        proxy: a single proxy URL used for http, https and all, or a dict
            keyed by ``"http"`` / ``"https"`` / ``"all"`` (or the same names
            with a ``"_proxy"`` suffix).
        unused_proxies: server addresses to add to the no-proxy list. Only
            the first two colon-separated parts of each entry are kept.
    """
    import os
    import urllib.request

    import httpx

    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout

    # set the system level proxies for the whole process
    proxies = {}
    if isinstance(proxy, str):
        for scheme in ["http", "https", "all"]:
            proxies[scheme + "_proxy"] = proxy
    elif isinstance(proxy, dict):
        for scheme in ["http", "https", "all"]:
            # the bare scheme key wins over the "<scheme>_proxy" key
            if p := proxy.get(scheme):
                proxies[scheme + "_proxy"] = p
            elif p := proxy.get(scheme + "_proxy"):
                proxies[scheme + "_proxy"] = p

    for k, v in proxies.items():
        os.environ[k] = v

    # set hosts that bypass the proxy
    # NOTE(review): the list is read from lowercase ``no_proxy`` but written
    # to uppercase ``NO_PROXY``; on case-sensitive platforms these are
    # different variables -- confirm this is intended.
    no_proxy = [
        x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip()
    ]
    no_proxy += [
        # do not use proxy for localhost
        "http://127.0.0.1",
        "http://localhost",
    ]
    # do not use proxy for user deployed fastchat servers
    for address in unused_proxies or []:
        host = ":".join(address.split(":")[:2])
        if host not in no_proxy:
            no_proxy.append(host)
    os.environ["NO_PROXY"] = ",".join(no_proxy)

    def _get_proxies():
        return proxies

    # NOTE(review): the replacement returns keys such as "http_proxy",
    # whereas the stock ``getproxies`` is documented to map scheme names to
    # proxy URLs -- verify that consumers expect this shape.
    urllib.request.getproxies = _get_proxies
|
||
|
||
|
||
def get_img_base64(file_path: str) -> str:
    """
    Encode an image file as a base64 data URI (used in streamlit).

    The file is read in binary mode and its content is base64 encoded. The
    returned URI always declares the ``image/png`` media type.
    """
    # read the image
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes).decode()
    return f"data:image/png;base64,{encoded}"
|
||
|
||
|
||
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
    """
    Return the success message carried by an API response.

    ``data[key]`` is returned only when ``data`` is a dict that contains both
    ``key`` and ``"code"`` and the code equals 200. In every other case an
    empty string is returned.
    """
    if not isinstance(data, dict):
        return ""
    if key not in data or "code" not in data:
        return ""
    if data["code"] != 200:
        return ""
    return data[key]
|
||
|
||
|
||
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
    """
    Return the error message carried by an API response, if any.

    For a dict response the lookup order is:

    1. ``data[key]`` when ``key`` is present;
    2. ``data["msg"]`` when a ``"code"`` entry is present and is not 200.

    A non-200 response without a ``"msg"`` entry yields a generic fallback
    message instead of raising ``KeyError``, so the failure is still
    reported to the caller. Non-dict input and successful responses yield
    an empty string.
    """
    if isinstance(data, dict):
        if key in data:
            return data[key]
        if "code" in data and data["code"] != 200:
            # Fall back only when "msg" is absent; an existing value (even an
            # empty one) is returned untouched.
            return data.get("msg", f"API request failed with code {data['code']}")
    return ""
|
||
|
||
|
||
def get_variable(*args):
    """
    Return the first truthy positional argument.

    ``None`` is returned when no argument is given or every argument is
    falsy.
    """
    return next((candidate for candidate in args if candidate), None)
|
||
|
||
|
||
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
    """
    Tell whether ``obj`` is a ``dict`` instance.

    The ``TypeGuard`` return annotation lets static type checkers narrow
    ``obj`` to a dict in the branch where this returns true.
    """
    matches = isinstance(obj, dict)
    return matches
|
||
|
||
|
||
def model_to_dict(model: BaseModel) -> dict[str, object]:
    """
    Convert a pydantic model into a plain dictionary.

    ``model_dump()`` is used when the model provides it, because ``dict()``
    is deprecated in pydantic v2. Models without that method fall back to
    ``dict()``.
    """
    dump = getattr(model, "model_dump", None)
    if callable(dump):
        return dump()
    return model.dict()
|
||
|
||
|
||
def get_function_default_params(func) -> dict:
    """
    Collect a function's parameters together with their default values.

    Args:
        func (function): the function to inspect.

    Returns:
        dict: maps each parameter name to its default value; parameters
        without a default are mapped to ``None``.
    """
    no_default = inspect.Parameter.empty
    return {
        name: (None if parameter.default is no_default else parameter.default)
        for name, parameter in inspect.signature(func).parameters.items()
    }
|
||
|
||
|
||
def merge_dicts(dict1, dict2) -> dict:
    """
    Merge two dictionaries, preferring truthy values from the first one.

    Args:
        dict1 (dict): the first dictionary.
        dict2 (dict): the second dictionary.

    Returns:
        dict: the merged dictionary. For every key found in either input,
        the value from ``dict1`` is used when it is truthy; otherwise the
        value from ``dict2`` is used (``None`` when ``dict2`` lacks the key).
    """
    # walk over the union of both key sets
    every_key = set(dict1.keys()).union(set(dict2.keys()))
    return {name: dict1.get(name) or dict2.get(name) for name in every_key}
|
||
|
||
|
||
def convert_file(file, filename=None):
    """
    Normalize an upload source into a ``(filename, file_object)`` pair.

    Args:
        file: raw ``bytes``, an object exposing ``read`` (file-like), or a
            local path.
        filename: explicit file name; when given and truthy it always wins.

    Returns:
        tuple: ``(filename, file_object)``.

        * bytes are wrapped in a ``BytesIO``; ``filename`` is returned as
          passed in (possibly ``None``);
        * file-like objects are returned unchanged; the name falls back to
          the object's ``name`` attribute, or ``None`` when it has none;
        * a path is opened in binary mode and the name falls back to the
          last path component. The caller owns the opened handle and is
          responsible for closing it.
    """
    if isinstance(file, bytes):  # raw bytes
        file = BytesIO(file)
    elif hasattr(file, "read"):  # a file-like object
        # In-memory streams such as BytesIO have no ``name`` attribute; fall
        # back to None instead of raising AttributeError.
        filename = filename or getattr(file, "name", None)
    else:  # a local path
        file = Path(file).absolute().open("rb")
        filename = filename or os.path.split(file.name)[-1]
    return filename, file