From e7bba6bd0a636ed6cafec5092f8b7e6236c8cdb5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=AB=98=E5=8E=89=E5=AE=B3?= <1019933576@qq.com>
Date: Thu, 11 Jan 2024 18:16:31 +0800
Subject: [PATCH] fix: automatically replace unsupported torch device (#2514)

---
 server/utils.py | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/server/utils.py b/server/utils.py
index 21b1baf..376ddd8 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -498,23 +498,39 @@ def set_httpx_config(
 
 
 # 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
+def is_mps_available():
+    import torch
+    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+
+def is_cuda_available():
+    import torch
+    return torch.cuda.is_available()
 
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
     try:
-        import torch
-        if torch.cuda.is_available():
+        if is_cuda_available():
             return "cuda"
-        if torch.backends.mps.is_available():
+        if is_mps_available():
             return "mps"
     except:
         pass
     return "cpu"
 
 
 def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
     device = device or LLM_DEVICE
+
+    # fallback to available device if specified device is not available
+    if device == 'cuda' and not is_cuda_available() and is_mps_available():
+        logging.warning("cuda is not available, fallback to mps")
+        return "mps"
+    if device == 'mps' and not is_mps_available() and is_cuda_available():
+        logging.warning("mps is not available, fallback to cuda")
+        return "cuda"
+
+    # auto detect device if not specified
     if device not in ["cuda", "mps", "cpu"]:
-        device = detect_device()
-    return device
+        return detect_device()
+    return device
 
 