fix: automatically replace unsupported torch device (#2514)

This commit is contained in:
高厉害 2024-01-11 18:16:31 +08:00 committed by GitHub
parent b5064813af
commit e7bba6bd0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 5 deletions

View File

@ -498,23 +498,39 @@ def set_httpx_config(
# 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
def is_mps_available():
import torch
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
def is_cuda_available():
import torch
return torch.cuda.is_available()
def detect_device() -> Literal["cuda", "mps", "cpu"]:
try:
import torch
if torch.cuda.is_available():
if is_cuda_available():
return "cuda"
if torch.backends.mps.is_available():
if is_mps_available():
return "mps"
except:
pass
return "cpu"
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
device = device or LLM_DEVICE
# fallback to available device if specified device is not available
if device == 'cuda' and not is_cuda_available() and is_mps_available():
logging.warning("cuda is not available, fallback to mps")
return "mps"
if device == 'mps' and not is_mps_available() and is_cuda_available():
logging.warning("mps is not available, fallback to cuda")
return "cuda"
# auto detect device if not specified
if device not in ["cuda", "mps", "cpu"]:
device = detect_device()
return detect_device()
return device