Merge pull request #1654 from glide-the/master

修复代理为空的问题
This commit is contained in:
glide-the 2023-10-04 14:04:11 +08:00 committed by GitHub
commit 295783b549
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 46 additions and 36 deletions

View File

@ -14,7 +14,6 @@ from langchain.chat_models import ChatOpenAI
import httpx import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
thread_pool = ThreadPoolExecutor(os.cpu_count()) thread_pool = ThreadPoolExecutor(os.cpu_count())
@ -68,6 +67,7 @@ class BaseResponse(BaseModel):
} }
} }
class ListResponse(BaseResponse): class ListResponse(BaseResponse):
data: List[str] = pydantic.Field(..., description="List of names") data: List[str] = pydantic.Field(..., description="List of names")
@ -115,6 +115,7 @@ class ChatMessage(BaseModel):
} }
} }
def torch_gc(): def torch_gc():
import torch import torch
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -126,11 +127,12 @@ def torch_gc():
from torch.mps import empty_cache from torch.mps import empty_cache
empty_cache() empty_cache()
except Exception as e: except Exception as e:
msg=("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本," msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
"以支持及时清理 torch 产生的内存占用。") "以支持及时清理 torch 产生的内存占用。")
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
def run_async(cor): def run_async(cor):
''' '''
在同步环境中运行异步代码. 在同步环境中运行异步代码.
@ -147,12 +149,14 @@ def iter_over_async(ait, loop):
将异步生成器封装成同步生成器. 将异步生成器封装成同步生成器.
''' '''
ait = ait.__aiter__() ait = ait.__aiter__()
async def get_next(): async def get_next():
try: try:
obj = await ait.__anext__() obj = await ait.__anext__()
return False, obj return False, obj
except StopAsyncIteration: except StopAsyncIteration:
return True, None return True, None
while True: while True:
done, obj = loop.run_until_complete(get_next()) done, obj = loop.run_until_complete(get_next())
if done: if done:
@ -162,8 +166,8 @@ def iter_over_async(ait, loop):
def MakeFastAPIOffline( def MakeFastAPIOffline(
app: FastAPI, app: FastAPI,
static_dir = Path(__file__).parent / "static", static_dir=Path(__file__).parent / "static",
static_url = "/static-offline-docs", static_url="/static-offline-docs",
docs_url: Optional[str] = "/docs", docs_url: Optional[str] = "/docs",
redoc_url: Optional[str] = "/redoc", redoc_url: Optional[str] = "/redoc",
) -> None: ) -> None:
@ -245,6 +249,7 @@ def list_embed_models() -> List[str]:
''' '''
return list(MODEL_PATH["embed_model"]) return list(MODEL_PATH["embed_model"])
def list_llm_models() -> Dict[str, List[str]]: def list_llm_models() -> Dict[str, List[str]]:
''' '''
get names of configured llm models with different types. get names of configured llm models with different types.
@ -380,7 +385,7 @@ def get_prompt_template(name: str) -> Optional[str]:
def set_httpx_config( def set_httpx_config(
timeout: float = HTTPX_DEFAULT_TIMEOUT, timeout: float = HTTPX_DEFAULT_TIMEOUT,
proxy: Union[str, Dict] = None, proxy: Union[str, Dict] = None,
): ):
''' '''
设置httpx默认timeouthttpx默认timeout是5秒在请求LLM回答时不够用 设置httpx默认timeouthttpx默认timeout是5秒在请求LLM回答时不够用
将本项目相关服务加入无代理列表避免fastchat的服务器请求错误(windows下无效) 将本项目相关服务加入无代理列表避免fastchat的服务器请求错误(windows下无效)
@ -400,9 +405,9 @@ def set_httpx_config(
proxies[n + "_proxy"] = proxy proxies[n + "_proxy"] = proxy
elif isinstance(proxy, dict): elif isinstance(proxy, dict):
for n in ["http", "https", "all"]: for n in ["http", "https", "all"]:
if p:= proxy.get(n): if p := proxy.get(n):
proxies[n + "_proxy"] = p proxies[n + "_proxy"] = p
elif p:= proxy.get(n + "_proxy"): elif p := proxy.get(n + "_proxy"):
proxies[n + "_proxy"] = p proxies[n + "_proxy"] = p
for k, v in proxies.items(): for k, v in proxies.items():
@ -508,9 +513,15 @@ def get_httpx_client(
# get proxies from system envionrent # get proxies from system envionrent
# proxy not str empty string, None, False, 0, [] or {} # proxy not str empty string, None, False, 0, [] or {}
default_proxies.update({ default_proxies.update({
"http://": os.environ.get("http_proxy") if len(os.environ.get("http_proxy").strip()) > 0 else None, "http://": (os.environ.get("http_proxy")
"https://": os.environ.get("https_proxy") if len(os.environ.get("https_proxy").strip()) > 0 else None, if os.environ.get("http_proxy") and len(os.environ.get("http_proxy").strip())
"all://": os.environ.get("all_proxy") if len(os.environ.get("all_proxy").strip()) > 0 else None, else None),
"https://": (os.environ.get("https_proxy")
if os.environ.get("https_proxy") and len(os.environ.get("https_proxy").strip())
else None),
"all://": (os.environ.get("all_proxy")
if os.environ.get("all_proxy") and len(os.environ.get("all_proxy").strip())
else None),
}) })
for host in os.environ.get("no_proxy", "").split(","): for host in os.environ.get("no_proxy", "").split(","):
if host := host.strip(): if host := host.strip():
@ -530,4 +541,3 @@ def get_httpx_client(
return httpx.AsyncClient(**kwargs) return httpx.AsyncClient(**kwargs)
else: else:
return httpx.Client(**kwargs) return httpx.Client(**kwargs)