Fix device detection and fallback logic and add 'xpu' (#2570)
Co-authored-by: zR <2448370773@qq.com>
This commit is contained in:
parent
e7bba6bd0a
commit
b653c25fbc
|
|
@ -10,10 +10,11 @@ from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
|
||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
from langchain.llms import OpenAI, AzureOpenAI, Anthropic
|
from langchain.llms import OpenAI
|
||||||
import httpx
|
import httpx
|
||||||
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
|
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
|
||||||
import logging
|
import logging
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||||
|
|
@ -58,6 +59,7 @@ def get_ChatOpenAI(
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def get_OpenAI(
|
def get_OpenAI(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
|
|
@ -151,7 +153,6 @@ class ChatMessage(BaseModel):
|
||||||
|
|
||||||
def torch_gc():
|
def torch_gc():
|
||||||
try:
|
try:
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
# with torch.cuda.device(DEVICE):
|
# with torch.cuda.device(DEVICE):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
@ -498,17 +499,18 @@ def set_httpx_config(
|
||||||
|
|
||||||
# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
|
# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
|
||||||
|
|
||||||
|
|
||||||
def is_mps_available():
|
def is_mps_available():
|
||||||
import torch
|
|
||||||
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||||
|
|
||||||
|
|
||||||
def is_cuda_available():
|
def is_cuda_available():
|
||||||
import torch
|
|
||||||
return torch.cuda.is_available()
|
return torch.cuda.is_available()
|
||||||
|
|
||||||
|
|
||||||
def detect_device() -> Literal["cuda", "mps", "cpu"]:
|
def detect_device() -> Literal["cuda", "mps", "cpu"]:
|
||||||
try:
|
try:
|
||||||
if is_cuda_available():
|
if torch.cuda.is_available():
|
||||||
return "cuda"
|
return "cuda"
|
||||||
if is_mps_available():
|
if is_mps_available():
|
||||||
return "mps"
|
return "mps"
|
||||||
|
|
@ -516,11 +518,31 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
|
||||||
pass
|
pass
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
|
|
||||||
device = device or LLM_DEVICE
|
|
||||||
|
|
||||||
# fallback to available device if specified device is not available
|
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
|
||||||
if device == 'cuda' and not is_cuda_available() and is_mps_available():
|
device = device or LLM_DEVICE
|
||||||
|
if device not in ["cuda", "mps", "cpu", "xpu"]:
|
||||||
|
logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
|
||||||
|
device = detect_device()
|
||||||
|
elif device == 'cuda' and not is_cuda_available() and is_mps_available():
|
||||||
|
logging.warning("cuda is not available, fallback to mps")
|
||||||
|
return "mps"
|
||||||
|
if device == 'mps' and not is_mps_available() and is_cuda_available():
|
||||||
|
logging.warning("mps is not available, fallback to cuda")
|
||||||
|
return "cuda"
|
||||||
|
|
||||||
|
# auto detect device if not specified
|
||||||
|
if device not in ["cuda", "mps", "cpu", "xpu"]:
|
||||||
|
return detect_device()
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
|
||||||
|
device = device or LLM_DEVICE
|
||||||
|
if device not in ["cuda", "mps", "cpu"]:
|
||||||
|
logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
|
||||||
|
device = detect_device()
|
||||||
|
elif device == 'cuda' and not is_cuda_available() and is_mps_available():
|
||||||
logging.warning("cuda is not available, fallback to mps")
|
logging.warning("cuda is not available, fallback to mps")
|
||||||
return "mps"
|
return "mps"
|
||||||
if device == 'mps' and not is_mps_available() and is_cuda_available():
|
if device == 'mps' and not is_mps_available() and is_cuda_available():
|
||||||
|
|
@ -530,14 +552,6 @@ def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
|
||||||
# auto detect device if not specified
|
# auto detect device if not specified
|
||||||
if device not in ["cuda", "mps", "cpu"]:
|
if device not in ["cuda", "mps", "cpu"]:
|
||||||
return detect_device()
|
return detect_device()
|
||||||
|
|
||||||
return device
|
|
||||||
|
|
||||||
|
|
||||||
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
|
|
||||||
device = device or EMBEDDING_DEVICE
|
|
||||||
if device not in ["cuda", "mps", "cpu"]:
|
|
||||||
device = detect_device()
|
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue