# Automatically detect which torch device is available. In a distributed
# deployment, machines that do not run the LLM may not have torch installed,
# so every probe below must tolerate a missing torch package.
def is_mps_available() -> bool:
    """Return True if torch can be imported and reports a usable MPS backend.

    Returns False when torch cannot be imported or when this torch build has
    no ``torch.backends.mps`` attribute.
    """
    try:
        import torch
    except (ImportError, OSError):
        # torch is absent (or its native libraries failed to load).
        return False
    mps_backend = getattr(torch.backends, "mps", None)
    return mps_backend is not None and bool(mps_backend.is_available())


def is_cuda_available() -> bool:
    """Return True if torch can be imported and reports CUDA as available.

    Returns False when torch cannot be imported.
    """
    try:
        import torch
    except (ImportError, OSError):
        # torch is absent (or its native libraries failed to load).
        return False
    return bool(torch.cuda.is_available())


def detect_device() -> Literal["cuda", "mps", "cpu"]:
    """Pick the best available device, preferring cuda, then mps, then cpu.

    Detection is best-effort: any unexpected error raised while probing is
    logged and treated as "no accelerator", so the result is "cpu".
    """
    try:
        if is_cuda_available():
            return "cuda"
        if is_mps_available():
            return "mps"
    except Exception:
        # Deliberately best-effort, but do not swallow KeyboardInterrupt /
        # SystemExit the way a bare ``except:`` would.
        logging.warning("device detection failed, fallback to cpu", exc_info=True)
    return "cpu"


def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
    """Resolve the device name to use for the LLM.

    Args:
        device: Requested device name. When falsy, the module-level
            ``LLM_DEVICE`` setting is used instead.

    Returns:
        - "mps" if "cuda" was requested, cuda is unavailable and mps is
          available (a warning is logged); likewise "cuda" if "mps" was
          requested, mps is unavailable and cuda is available.
        - The result of ``detect_device()`` if the requested name is not one
          of "cuda", "mps" or "cpu".
        - Otherwise the requested name unchanged. In particular, a requested
          accelerator is returned as-is when no alternative accelerator is
          available (including when torch is not installed).
    """
    device = device or LLM_DEVICE

    # Fall back to the other accelerator if the requested one is unavailable.
    # The probes return False instead of raising when torch is missing, so
    # this is safe on machines without torch.
    if device == "cuda" and not is_cuda_available() and is_mps_available():
        logging.warning("cuda is not available, fallback to mps")
        return "mps"
    if device == "mps" and not is_mps_available() and is_cuda_available():
        logging.warning("mps is not available, fallback to cuda")
        return "cuda"

    # Auto-detect the device when the requested name is not a known one.
    if device not in ["cuda", "mps", "cpu"]:
        return detect_device()

    return device