Fix the issue where the proxy is empty

glide-the 2023-10-04 14:02:57 +08:00
parent db3efb306a
commit 84a4141dec
1 changed file with 46 additions and 36 deletions


@@ -7,14 +7,13 @@ import asyncio
 from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
                      MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
                      logger, log_verbose,
                      FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
 thread_pool = ThreadPoolExecutor(os.cpu_count())
@@ -33,12 +32,12 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
 def get_ChatOpenAI(
         model_name: str,
         temperature: float,
         streaming: bool = True,
         callbacks: List[Callable] = [],
         verbose: bool = True,
         **kwargs: Any,
 ) -> ChatOpenAI:
     config = get_model_worker_config(model_name)
     model = ChatOpenAI(
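For orientation, a factory like get_ChatOpenAI typically just forwards the worker config into the ChatOpenAI constructor. A minimal sketch of the truncated call above, assuming the config dict exposes api_base_url and api_key (both key names are assumptions, not shown in this diff):

    # Sketch only: the config key names are assumed, not confirmed by the diff.
    model = ChatOpenAI(
        streaming=streaming,
        verbose=verbose,
        callbacks=callbacks,
        openai_api_key=config.get("api_key", "EMPTY"),   # assumed config key
        openai_api_base=config.get("api_base_url"),      # assumed config key
        model_name=model_name,
        temperature=temperature,
        **kwargs,
    )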
@@ -68,6 +67,7 @@ class BaseResponse(BaseModel):
         }
     }
+
 class ListResponse(BaseResponse):
     data: List[str] = pydantic.Field(..., description="List of names")
@@ -115,6 +115,7 @@ class ChatMessage(BaseModel):
         }
     }
+
 def torch_gc():
     import torch
     if torch.cuda.is_available():
@@ -126,11 +127,12 @@ def torch_gc():
             from torch.mps import empty_cache
             empty_cache()
         except Exception as e:
-            msg=("If you are using macOS, it is recommended to upgrade pytorch to 2.0.0 or higher "
-                 "to support timely cleanup of the memory occupied by torch.")
+            msg = ("If you are using macOS, it is recommended to upgrade pytorch to 2.0.0 or higher "
+                   "to support timely cleanup of the memory occupied by torch.")
             logger.error(f'{e.__class__.__name__}: {msg}',
                          exc_info=e if log_verbose else None)
+
 def run_async(cor):
     '''
     Run async code in a synchronous environment.
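run_async is the synchronous entry point for one-off coroutines. A usage sketch (the coroutine is illustrative):

    async def ping():
        # illustrative coroutine
        return "pong"

    result = run_async(ping())  # blocks until the coroutine completes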
@@ -147,12 +149,14 @@ def iter_over_async(ait, loop):
     Wrap an async generator into a synchronous generator.
     '''
     ait = ait.__aiter__()
+
     async def get_next():
         try:
             obj = await ait.__anext__()
             return False, obj
         except StopAsyncIteration:
             return True, None
+
     while True:
         done, obj = loop.run_until_complete(get_next())
         if done:
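To make the contract of iter_over_async concrete, a small usage sketch (the async generator and event loop are illustrative):

    import asyncio

    async def agen():
        # illustrative async generator
        for i in range(3):
            yield i

    loop = asyncio.new_event_loop()
    for item in iter_over_async(agen(), loop):
        print(item)  # prints 0, 1, 2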
@@ -161,11 +165,11 @@ def iter_over_async(ait, loop):
 def MakeFastAPIOffline(
         app: FastAPI,
-        static_dir = Path(__file__).parent / "static",
-        static_url = "/static-offline-docs",
+        static_dir=Path(__file__).parent / "static",
+        static_url="/static-offline-docs",
         docs_url: Optional[str] = "/docs",
         redoc_url: Optional[str] = "/redoc",
 ) -> None:
     """patch the FastAPI obj so that the documentation page doesn't rely on a CDN"""
     from fastapi import Request
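Typical usage is to patch the app right after it is created, for example:

    from fastapi import FastAPI

    app = FastAPI()
    MakeFastAPIOffline(app)  # /docs and /redoc now load their JS/CSS from the local static dir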
@@ -245,6 +249,7 @@ def list_embed_models() -> List[str]:
     '''
     return list(MODEL_PATH["embed_model"])
+
 def list_llm_models() -> Dict[str, List[str]]:
     '''
     get names of configured llm models with different types.
@@ -268,23 +273,23 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
     for v in MODEL_PATH.values():
         paths.update(v)
     if path_str := paths.get(model_name):  # taking "chatglm-6b": "THUDM/chatglm-6b-new" as an example, all of the paths below are supported
         path = Path(path_str)
         if path.is_dir():  # any absolute path
             return str(path)
         root_path = Path(MODEL_ROOT_PATH)
         if root_path.is_dir():
             path = root_path / model_name
             if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
                 return str(path)
             path = root_path / path_str
             if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
                 return str(path)
             path = root_path / path_str.split("/")[-1]
             if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
                 return str(path)
         return path_str  # THUDM/chatglm-6b-new
 # Get service information from server_config
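The lookup order above can be made concrete with a worked example. Assuming MODEL_ROOT_PATH = "/models" and the mapping "chatglm-6b": "THUDM/chatglm-6b-new" (values are illustrative), get_model_path("chatglm-6b") tries:

    from pathlib import Path

    candidates = [
        Path("THUDM/chatglm-6b-new"),               # 1. the value as a path (covers absolute paths)
        Path("/models") / "chatglm-6b",             # 2. {MODEL_ROOT_PATH}/<key>
        Path("/models") / "THUDM/chatglm-6b-new",   # 3. {MODEL_ROOT_PATH}/<value>
        Path("/models") / "chatglm-6b-new",         # 4. {MODEL_ROOT_PATH}/<last segment of value>
    ]
    path = next((str(p) for p in candidates if p.is_dir()),
                "THUDM/chatglm-6b-new")             # 5. fall back to the raw value (e.g. a HF repo id)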
@@ -372,7 +377,7 @@ def get_prompt_template(name: str) -> Optional[str]:
     '''
     from configs import prompt_config
     import importlib
-    importlib.reload(prompt_config) # TODO: only reload when configs/prompt_config.py has been modified
+    importlib.reload(prompt_config)  # TODO: only reload when configs/prompt_config.py has been modified
     return prompt_config.PROMPT_TEMPLATES.get(name)
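The reload means edits to configs/prompt_config.py take effect without restarting the server. The same pattern in isolation (the template name is illustrative):

    import importlib
    from configs import prompt_config

    importlib.reload(prompt_config)  # re-executes the module, picking up edits on disk
    template = prompt_config.PROMPT_TEMPLATES.get("llm_chat")  # illustrative template name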
@@ -380,7 +385,7 @@ def get_prompt_template(name: str) -> Optional[str]:
 def set_httpx_config(
         timeout: float = HTTPX_DEFAULT_TIMEOUT,
         proxy: Union[str, Dict] = None,
 ):
     '''
     Set the default httpx timeout. The httpx default of 5 seconds is not enough when requesting an LLM answer.
     Add this project's services to the no-proxy list to avoid request errors from the fastchat server (not effective on Windows).
@@ -400,9 +405,9 @@ def set_httpx_config(
             proxies[n + "_proxy"] = proxy
     elif isinstance(proxy, dict):
         for n in ["http", "https", "all"]:
-            if p:= proxy.get(n):
+            if p := proxy.get(n):
                 proxies[n + "_proxy"] = p
-            elif p:= proxy.get(n + "_proxy"):
+            elif p := proxy.get(n + "_proxy"):
                 proxies[n + "_proxy"] = p
     for k, v in proxies.items():
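Judging from the branches above, the proxy argument is normalized into *_proxy entries whether it arrives as a single URL or as a dict. A usage sketch (addresses are illustrative):

    # A single URL is applied to http, https and all:
    set_httpx_config(proxy="http://127.0.0.1:7890")

    # A dict may use either bare scheme names or full *_proxy keys:
    set_httpx_config(proxy={"http": "http://127.0.0.1:7890",
                            "https_proxy": "http://127.0.0.1:7890"})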
@@ -463,9 +468,9 @@ def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
 def run_in_thread_pool(
         func: Callable,
         params: List[Dict] = [],
         pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
     Run tasks in batch in a thread pool, and yield the results as a generator.
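A consumption sketch for run_in_thread_pool, assuming each dict in params is expanded into keyword arguments for func (which matches the signature shown):

    def add(a: int, b: int) -> int:
        # illustrative task
        return a + b

    tasks = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
    for result in run_in_thread_pool(add, params=tasks):
        print(result)  # results arrive as workers finish, not necessarily in submit order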
@@ -483,10 +488,10 @@ def run_in_thread_pool(
 def get_httpx_client(
         use_async: bool = False,
         proxies: Union[str, Dict] = None,
         timeout: float = HTTPX_DEFAULT_TIMEOUT,
         **kwargs,
 ) -> Union[httpx.Client, httpx.AsyncClient]:
     '''
     helper to get httpx client with default proxies that bypass local addresses.
@@ -508,9 +513,15 @@ def get_httpx_client(
     # get proxies from the system environment
     # a proxy counts as unset if it is an empty string, None, False, 0, [] or {}
     default_proxies.update({
-        "http://": os.environ.get("http_proxy") if len(os.environ.get("http_proxy").strip()) > 0 else None,
-        "https://": os.environ.get("https_proxy") if len(os.environ.get("https_proxy").strip()) > 0 else None,
-        "all://": os.environ.get("all_proxy") if len(os.environ.get("all_proxy").strip()) > 0 else None,
+        "http://": (os.environ.get("http_proxy")
+                    if os.environ.get("http_proxy") and len(os.environ.get("http_proxy").strip())
+                    else None),
+        "https://": (os.environ.get("https_proxy")
+                     if os.environ.get("https_proxy") and len(os.environ.get("https_proxy").strip())
+                     else None),
+        "all://": (os.environ.get("all_proxy")
+                   if os.environ.get("all_proxy") and len(os.environ.get("all_proxy").strip())
+                   else None),
     })
     for host in os.environ.get("no_proxy", "").split(","):
         if host := host.strip():
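This hunk is the fix named in the commit message: when a variable such as http_proxy is unset, os.environ.get(...) returns None, and the old code crashed on None.strip() with an AttributeError. The new conditionals check for None and empty strings before calling .strip(). The same guard could be factored into a helper; a sketch (env_proxy is a hypothetical name, not part of the commit):

    import os

    def env_proxy(name: str):
        # Return the proxy URL from the environment, or None when the variable
        # is unset or contains only whitespace -- the case the fix guards against.
        value = os.environ.get(name)
        return value if value and value.strip() else None

    default_proxies.update({
        "http://": env_proxy("http_proxy"),
        "https://": env_proxy("https_proxy"),
        "all://": env_proxy("all_proxy"),
    })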
@@ -530,4 +541,3 @@ def get_httpx_client(
         return httpx.AsyncClient(**kwargs)
     else:
         return httpx.Client(**kwargs)
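The returned object is then used like any httpx client, for example (the URL is illustrative):

    # synchronous use
    with get_httpx_client() as client:
        resp = client.get("http://127.0.0.1:7861/docs")  # a local service, so the proxy is bypassed

    # asynchronous use
    async def fetch():
        async with get_httpx_client(use_async=True) as client:
            return await client.get("https://example.com")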