Langchain-Chatchat/libs/python-sdk/open_chatcaht/utils.py

254 lines
7.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import inspect
import os
from io import BytesIO
from pathlib import Path
from typing import Union, List, Dict
import httpx
from pydantic import BaseModel
from typing_extensions import TypeGuard
from open_chatcaht._constants import HTTPX_TIMEOUT
def get_httpx_client(
use_async: bool = False,
proxies: Union[str, Dict] = None,
timeout: float = HTTPX_TIMEOUT,
unused_proxies: List[str] = [],
**kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
"""
helper to get httpx client with default proxies that bypass local addesses.
"""
default_proxies = {
# do not use proxy for locahost
"all://127.0.0.1": None,
"all://localhost": None,
}
# do not use proxy for user deployed fastchat servers
for x in unused_proxies:
host = ":".join(x.split(":")[:2])
default_proxies.update({host: None})
# get proxies from system envionrent
# proxy not str empty string, None, False, 0, [] or {}
default_proxies.update(
{
"http://": (
os.environ.get("http_proxy")
if os.environ.get("http_proxy")
and len(os.environ.get("http_proxy").strip())
else None
),
"https://": (
os.environ.get("https_proxy")
if os.environ.get("https_proxy")
and len(os.environ.get("https_proxy").strip())
else None
),
"all://": (
os.environ.get("all_proxy")
if os.environ.get("all_proxy")
and len(os.environ.get("all_proxy").strip())
else None
),
}
)
for host in os.environ.get("no_proxy", "").split(","):
if host := host.strip():
# default_proxies.update({host: None}) # Origin code
default_proxies.update(
{"all://" + host: None}
) # PR 1838 fix, if not add 'all://', httpx will raise error
# merge default proxies with user provided proxies
if isinstance(proxies, str):
proxies = {"all://": proxies}
if isinstance(proxies, dict):
default_proxies.update(proxies)
# construct Client
kwargs.update(timeout=timeout, proxies=default_proxies)
if use_async:
return httpx.AsyncClient(**kwargs)
else:
return httpx.Client(**kwargs)
def set_httpx_config(
timeout: float = HTTPX_TIMEOUT,
proxy: Union[str, Dict] = None,
unused_proxies: List[str] = [],
):
"""
设置httpx默认timeout。httpx默认timeout是5秒在请求LLM回答时不够用。
将本项目相关服务加入无代理列表避免fastchat的服务器请求错误。(windows下无效)
对于chatgpt等在线API如要使用代理需要手动配置。搜索引擎的代理如何处置还需考虑。
"""
import os
import httpx
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
# 在进程范围内设置系统级代理
proxies = {}
if isinstance(proxy, str):
for n in ["http", "https", "all"]:
proxies[n + "_proxy"] = proxy
elif isinstance(proxy, dict):
for n in ["http", "https", "all"]:
if p := proxy.get(n):
proxies[n + "_proxy"] = p
elif p := proxy.get(n + "_proxy"):
proxies[n + "_proxy"] = p
for k, v in proxies.items():
os.environ[k] = v
# set host to bypass proxy
no_proxy = [
x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip()
]
no_proxy += [
# do not use proxy for locahost
"http://127.0.0.1",
"http://localhost",
]
# do not use proxy for user deployed fastchat servers
for x in unused_proxies:
host = ":".join(x.split(":")[:2])
if host not in no_proxy:
no_proxy.append(host)
os.environ["NO_PROXY"] = ",".join(no_proxy)
def _get_proxies():
return proxies
import urllib.request
urllib.request.getproxies = _get_proxies
def get_img_base64(file_path: str) -> str:
"""
get_img_base64 used in streamlit.
"""
image = file_path
# 读取图片
with open(image, "rb") as f:
buffer = BytesIO(f.read())
base_str = base64.b64encode(buffer.getvalue()).decode()
return f"data:image/png;base64,{base_str}"
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
"""
return error message if error occured when requests API
"""
if (
isinstance(data, dict)
and key in data
and "code" in data
and data["code"] == 200
):
return data[key]
return ""
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
"""
return error message if error occured when requests API
"""
if isinstance(data, dict):
if key in data:
return data[key]
if "code" in data and data["code"] != 200:
return data["msg"]
return ""
def get_variable(*args):
for var in args:
if var:
return var
return None
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
return isinstance(obj, dict)
def model_to_dict(model: BaseModel) -> dict[str, object]:
return model.dict()
def get_function_default_params(func) -> dict:
"""
获取函数的参数及其默认值。
参数:
func (function): 要分析的函数。
返回:
dict: 一个包含参数名称及其默认值的字典。
"""
signature = inspect.signature(func)
params = signature.parameters
params_dict = {}
for param_name, param in params.items():
if param.default is inspect.Parameter.empty:
params_dict[param_name] = None
else:
params_dict[param_name] = param.default
return params_dict
def merge_dicts(dict1, dict2) -> dict:
"""
合并两个字典,优先使用第一个字典中的非空值。
参数:
dict1 (dict): 第一个字典。
dict2 (dict): 第二个字典。
返回:
dict: 合并后的字典。
"""
merged_dict = {}
# 遍历两个字典的键集合
all_keys = set(dict1.keys()).union(set(dict2.keys()))
for key in all_keys:
value1 = dict1.get(key)
value2 = dict2.get(key)
# 如果第一个字典中的值不为空,使用第一个字典的值
if value1:
merged_dict[key] = value1
else:
# 否则使用第二个字典中的值
merged_dict[key] = value2
return merged_dict
def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1]
return filename, file