fix: automatically replace unsupported torch device (#2514)
This commit is contained in:
parent
b5064813af
commit
e7bba6bd0a
|
|
@ -498,23 +498,39 @@ def set_httpx_config(
|
|||
|
||||
# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
|
||||
|
||||
def is_mps_available():
|
||||
import torch
|
||||
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
|
||||
def is_cuda_available():
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
|
||||
def detect_device() -> Literal["cuda", "mps", "cpu"]:
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
if is_cuda_available():
|
||||
return "cuda"
|
||||
if torch.backends.mps.is_available():
|
||||
if is_mps_available():
|
||||
return "mps"
|
||||
except:
|
||||
pass
|
||||
return "cpu"
|
||||
|
||||
|
||||
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
|
||||
device = device or LLM_DEVICE
|
||||
|
||||
# fallback to available device if specified device is not available
|
||||
if device == 'cuda' and not is_cuda_available() and is_mps_available():
|
||||
logging.warning("cuda is not available, fallback to mps")
|
||||
return "mps"
|
||||
if device == 'mps' and not is_mps_available() and is_cuda_available():
|
||||
logging.warning("mps is not available, fallback to cuda")
|
||||
return "cuda"
|
||||
|
||||
# auto detect device if not specified
|
||||
if device not in ["cuda", "mps", "cpu"]:
|
||||
device = detect_device()
|
||||
return detect_device()
|
||||
|
||||
return device
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue