将所有httpx请求改为使用Client,提高效率,方便以后设置代理等。 (#1554)

将所有httpx请求改为使用Client,提高效率,方便以后设置代理等。

将本项目相关服务加入无代理列表,避免fastchat的服务器请求错误。(windows下无效)
This commit is contained in:
liunux4odoo 2023-09-21 15:19:51 +08:00 committed by GitHub
parent 818cb1a491
commit e4a927c5d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 213 additions and 102 deletions

View File

@ -1,7 +1,7 @@
from fastapi import Body from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from server.utils import BaseResponse, fschat_controller_address, list_llm_models from server.utils import BaseResponse, fschat_controller_address, list_llm_models, get_httpx_client
import httpx
def list_running_models( def list_running_models(
@ -13,7 +13,8 @@ def list_running_models(
''' '''
try: try:
controller_address = controller_address or fschat_controller_address() controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models") with get_httpx_client() as client:
r = client.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"]) return BaseResponse(data=r.json()["models"])
except Exception as e: except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}', logger.error(f'{e.__class__.__name__}: {e}',
@ -41,7 +42,8 @@ def stop_llm_model(
''' '''
try: try:
controller_address = controller_address or fschat_controller_address() controller_address = controller_address or fschat_controller_address()
r = httpx.post( with get_httpx_client() as client:
r = client.post(
controller_address + "/release_worker", controller_address + "/release_worker",
json={"model_name": model_name}, json={"model_name": model_name},
) )
@ -64,7 +66,8 @@ def change_llm_model(
''' '''
try: try:
controller_address = controller_address or fschat_controller_address() controller_address = controller_address or fschat_controller_address()
r = httpx.post( with get_httpx_client() as client:
r = client.post(
controller_address + "/release_worker", controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name}, json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model

View File

@ -2,7 +2,7 @@ from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv from fastchat import conversation as conv
import sys import sys
import json import json
import httpx from server.utils import get_httpx_client
from pprint import pprint from pprint import pprint
from typing import List, Dict from typing import List, Dict
@ -63,7 +63,8 @@ class MiniMaxWorker(ApiModelWorker):
} }
print("request data sent to minimax:") print("request data sent to minimax:")
pprint(data) pprint(data)
response = httpx.stream("POST", with get_httpx_client() as client:
response = client.stream("POST",
self.BASE_URL.format(pro=pro, group_id=group_id), self.BASE_URL.format(pro=pro, group_id=group_id),
headers=headers, headers=headers,
json=data) json=data)

View File

@ -5,7 +5,7 @@ import sys
import json import json
import httpx import httpx
from cachetools import cached, TTLCache from cachetools import cached, TTLCache
from server.utils import get_model_worker_config from server.utils import get_model_worker_config, get_httpx_client
from typing import List, Literal, Dict from typing import List, Literal, Dict
@ -54,7 +54,8 @@ def get_baidu_access_token(api_key: str, secret_key: str) -> str:
url = "https://aip.baidubce.com/oauth/2.0/token" url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
try: try:
return httpx.get(url, params=params).json().get("access_token") with get_httpx_client() as client:
return client.get(url, params=params).json().get("access_token")
except Exception as e: except Exception as e:
print(f"failed to get token from baidu: {e}") print(f"failed to get token from baidu: {e}")
@ -91,7 +92,8 @@ def request_qianfan_api(
'Accept': 'application/json', 'Accept': 'application/json',
} }
with httpx.stream("POST", url, headers=headers, json=payload) as response: with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=payload) as response:
for line in response.iter_lines(): for line in response.iter_lines():
if not line.strip(): if not line.strip():
continue continue

View File

@ -7,11 +7,12 @@ import asyncio
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE, from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
logger, log_verbose, logger, log_verbose,
FSCHAT_MODEL_WORKERS) FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os import os
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
thread_pool = ThreadPoolExecutor(os.cpu_count()) thread_pool = ThreadPoolExecutor(os.cpu_count())
@ -376,19 +377,63 @@ def get_prompt_template(name: str) -> Optional[str]:
return prompt_config.PROMPT_TEMPLATES.get(name) return prompt_config.PROMPT_TEMPLATES.get(name)
def set_httpx_timeout(timeout: float = None): def set_httpx_config(
timeout: float = HTTPX_DEFAULT_TIMEOUT,
proxy: Union[str, Dict] = None,
):
''' '''
设置httpx默认timeout 设置httpx默认timeouthttpx默认timeout是5秒在请求LLM回答时不够用
httpx默认timeout是5秒在请求LLM回答时不够用 将本项目相关服务加入无代理列表避免fastchat的服务器请求错误(windows下无效)
对于chatgpt等在线API如要使用代理需要手动配置搜索引擎的代理如何处置还需考虑
''' '''
import httpx import httpx
from configs.server_config import HTTPX_DEFAULT_TIMEOUT import os
timeout = timeout or HTTPX_DEFAULT_TIMEOUT
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
# 在进程范围内设置系统级代理
proxies = {}
if isinstance(proxy, str):
for n in ["http", "https", "all"]:
proxies[n + "_proxy"] = proxy
elif isinstance(proxy, dict):
for n in ["http", "https", "all"]:
if p:= proxy.get(n):
proxies[n + "_proxy"] = p
elif p:= proxy.get(n + "_proxy"):
proxies[n + "_proxy"] = p
for k, v in proxies.items():
os.environ[k] = v
# set host to bypass proxy
no_proxy = [x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip()]
no_proxy += [
# do not use proxy for localhost
"http://127.0.0.1",
"http://localhost",
]
# do not use proxy for user deployed fastchat servers
for x in [
fschat_controller_address(),
fschat_model_worker_address(),
fschat_openai_api_address(),
]:
host = ":".join(x.split(":")[:2])
if host not in no_proxy:
no_proxy.append(host)
os.environ["NO_PROXY"] = ",".join(no_proxy)
# TODO: 简单的清除系统代理不是个好的选择影响太多。似乎修改代理服务器的bypass列表更好。
# patch requests to use custom proxies instead of system settings
# def _get_proxies():
# return {}
# import urllib.request
# urllib.request.getproxies = _get_proxies
# 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch # 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
def detect_device() -> Literal["cuda", "mps", "cpu"]: def detect_device() -> Literal["cuda", "mps", "cpu"]:
@ -436,3 +481,51 @@ def run_in_thread_pool(
for obj in as_completed(tasks): for obj in as_completed(tasks):
yield obj.result() yield obj.result()
def get_httpx_client(
use_async: bool = False,
proxies: Union[str, Dict] = None,
timeout: float = HTTPX_DEFAULT_TIMEOUT,
**kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
'''
helper to get httpx client with default proxies that bypass local addresses.
'''
default_proxies = {
# do not use proxy for localhost
"all://127.0.0.1": None,
"all://localhost": None,
}
# do not use proxy for user deployed fastchat servers
for x in [
fschat_controller_address(),
fschat_model_worker_address(),
fschat_openai_api_address(),
]:
host = ":".join(x.split(":")[:2])
default_proxies.update({host: None})
# get proxies from system environment
default_proxies.update({
"http://": os.environ.get("http_proxy"),
"https://": os.environ.get("https_proxy"),
"all://": os.environ.get("all_proxy"),
})
for host in os.environ.get("no_proxy", "").split(","):
if host := host.strip():
default_proxies.update({host: None})
# merge default proxies with user provided proxies
if isinstance(proxies, str):
proxies = {"all://": proxies}
if isinstance(proxies, dict):
default_proxies.update(proxies)
# construct Client
kwargs.update(timeout=timeout, proxies=default_proxies)
if use_async:
return httpx.AsyncClient(**kwargs)
else:
return httpx.Client(**kwargs)

View File

@ -31,7 +31,7 @@ from configs import (
HTTPX_DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT,
) )
from server.utils import (fschat_controller_address, fschat_model_worker_address, from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout, fschat_openai_api_address, set_httpx_config, get_httpx_client,
get_model_worker_config, get_all_model_worker_configs, get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device) MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse import argparse
@ -203,7 +203,6 @@ def create_openai_api_app(
def _set_app_event(app: FastAPI, started_event: mp.Event = None): def _set_app_event(app: FastAPI, started_event: mp.Event = None):
@app.on_event("startup") @app.on_event("startup")
async def on_startup(): async def on_startup():
set_httpx_timeout()
if started_event is not None: if started_event is not None:
started_event.set() started_event.set()
@ -214,6 +213,8 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
from fastapi import Body from fastapi import Body
import time import time
import sys import sys
from server.utils import set_httpx_config
set_httpx_config()
app = create_controller_app( app = create_controller_app(
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"), dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
@ -251,7 +252,8 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
logger.error(msg) logger.error(msg)
return {"code": 500, "msg": msg} return {"code": 500, "msg": msg}
r = httpx.post(worker_address + "/release", with get_httpx_client() as client:
r = client.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin}) json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200: if r.status_code != 200:
msg = f"failed to release model: {model_name}" msg = f"failed to release model: {model_name}"
@ -299,6 +301,8 @@ def run_model_worker(
import uvicorn import uvicorn
from fastapi import Body from fastapi import Body
import sys import sys
from server.utils import set_httpx_config
set_httpx_config()
kwargs = get_model_worker_config(model_name) kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host") host = kwargs.pop("host")
@ -337,6 +341,8 @@ def run_model_worker(
def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None): def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
import uvicorn import uvicorn
import sys import sys
from server.utils import set_httpx_config
set_httpx_config()
controller_addr = fschat_controller_address() controller_addr = fschat_controller_address()
app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet. app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet.
@ -353,6 +359,8 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
def run_api_server(started_event: mp.Event = None): def run_api_server(started_event: mp.Event = None):
from server.api import create_app from server.api import create_app
import uvicorn import uvicorn
from server.utils import set_httpx_config
set_httpx_config()
app = create_app() app = create_app()
_set_app_event(app, started_event) _set_app_event(app, started_event)
@ -364,6 +372,9 @@ def run_api_server(started_event: mp.Event = None):
def run_webui(started_event: mp.Event = None): def run_webui(started_event: mp.Event = None):
from server.utils import set_httpx_config
set_httpx_config()
host = WEBUI_SERVER["host"] host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"] port = WEBUI_SERVER["port"]

View File

@ -26,7 +26,7 @@ import contextlib
import json import json
import os import os
from io import BytesIO from io import BytesIO
from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address from server.utils import run_async, iter_over_async, set_httpx_config, api_address, get_httpx_client
from configs.model_config import NLTK_DATA_PATH from configs.model_config import NLTK_DATA_PATH
import nltk import nltk
@ -35,7 +35,7 @@ from pprint import pprint
KB_ROOT_PATH = Path(KB_ROOT_PATH) KB_ROOT_PATH = Path(KB_ROOT_PATH)
set_httpx_timeout() set_httpx_config()
class ApiRequest: class ApiRequest:
@ -53,6 +53,8 @@ class ApiRequest:
self.base_url = base_url self.base_url = base_url
self.timeout = timeout self.timeout = timeout
self.no_remote_api = no_remote_api self.no_remote_api = no_remote_api
self._client = get_httpx_client()
self._aclient = get_httpx_client(use_async=True)
if no_remote_api: if no_remote_api:
logger.warn("将来可能取消对no_remote_api的支持更新版本时请注意。") logger.warn("将来可能取消对no_remote_api的支持更新版本时请注意。")
@ -79,9 +81,9 @@ class ApiRequest:
while retry > 0: while retry > 0:
try: try:
if stream: if stream:
return httpx.stream("GET", url, params=params, **kwargs) return self._client.stream("GET", url, params=params, **kwargs)
else: else:
return httpx.get(url, params=params, **kwargs) return self._client.get(url, params=params, **kwargs)
except Exception as e: except Exception as e:
msg = f"error when get {url}: {e}" msg = f"error when get {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
@ -98,13 +100,13 @@ class ApiRequest:
) -> Union[httpx.Response, None]: ) -> Union[httpx.Response, None]:
url = self._parse_url(url) url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout) kwargs.setdefault("timeout", self.timeout)
async with httpx.AsyncClient() as client:
while retry > 0: while retry > 0:
try: try:
if stream: if stream:
return await client.stream("GET", url, params=params, **kwargs) return await self._aclient.stream("GET", url, params=params, **kwargs)
else: else:
return await client.get(url, params=params, **kwargs) return await self._aclient.get(url, params=params, **kwargs)
except Exception as e: except Exception as e:
msg = f"error when aget {url}: {e}" msg = f"error when aget {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
@ -124,11 +126,10 @@ class ApiRequest:
kwargs.setdefault("timeout", self.timeout) kwargs.setdefault("timeout", self.timeout)
while retry > 0: while retry > 0:
try: try:
# return requests.post(url, data=data, json=json, stream=stream, **kwargs)
if stream: if stream:
return httpx.stream("POST", url, data=data, json=json, **kwargs) return self._client.stream("POST", url, data=data, json=json, **kwargs)
else: else:
return httpx.post(url, data=data, json=json, **kwargs) return self._client.post(url, data=data, json=json, **kwargs)
except Exception as e: except Exception as e:
msg = f"error when post {url}: {e}" msg = f"error when post {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
@ -146,13 +147,13 @@ class ApiRequest:
) -> Union[httpx.Response, None]: ) -> Union[httpx.Response, None]:
url = self._parse_url(url) url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout) kwargs.setdefault("timeout", self.timeout)
async with httpx.AsyncClient() as client:
while retry > 0: while retry > 0:
try: try:
if stream: if stream:
return await client.stream("POST", url, data=data, json=json, **kwargs) return await self._client.stream("POST", url, data=data, json=json, **kwargs)
else: else:
return await client.post(url, data=data, json=json, **kwargs) return await self._client.post(url, data=data, json=json, **kwargs)
except Exception as e: except Exception as e:
msg = f"error when apost {url}: {e}" msg = f"error when apost {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
@ -173,9 +174,9 @@ class ApiRequest:
while retry > 0: while retry > 0:
try: try:
if stream: if stream:
return httpx.stream("DELETE", url, data=data, json=json, **kwargs) return self._client.stream("DELETE", url, data=data, json=json, **kwargs)
else: else:
return httpx.delete(url, data=data, json=json, **kwargs) return self._client.delete(url, data=data, json=json, **kwargs)
except Exception as e: except Exception as e:
msg = f"error when delete {url}: {e}" msg = f"error when delete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
@ -193,13 +194,13 @@ class ApiRequest:
) -> Union[httpx.Response, None]: ) -> Union[httpx.Response, None]:
url = self._parse_url(url) url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout) kwargs.setdefault("timeout", self.timeout)
async with httpx.AsyncClient() as client:
while retry > 0: while retry > 0:
try: try:
if stream: if stream:
return await client.stream("DELETE", url, data=data, json=json, **kwargs) return await self._aclient.stream("DELETE", url, data=data, json=json, **kwargs)
else: else:
return await client.delete(url, data=data, json=json, **kwargs) return await self._aclient.delete(url, data=data, json=json, **kwargs)
except Exception as e: except Exception as e:
msg = f"error when adelete {url}: {e}" msg = f"error when adelete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',