Fix device detection and fallback logic and add 'xpu' (#2570)

Co-authored-by: zR <2448370773@qq.com>
This commit is contained in:
chatgpt-bot 2024-01-11 18:36:38 +08:00 committed by GitHub
parent e7bba6bd0a
commit b653c25fbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 31 additions and 17 deletions

View File

@ -10,10 +10,11 @@ from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
import os import os
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI, AzureOpenAI, Anthropic from langchain.llms import OpenAI
import httpx import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
import logging import logging
import torch
async def wrap_done(fn: Awaitable, event: asyncio.Event): async def wrap_done(fn: Awaitable, event: asyncio.Event):
@ -58,6 +59,7 @@ def get_ChatOpenAI(
) )
return model return model
def get_OpenAI( def get_OpenAI(
model_name: str, model_name: str,
temperature: float, temperature: float,
@ -151,7 +153,6 @@ class ChatMessage(BaseModel):
def torch_gc(): def torch_gc():
try: try:
import torch
if torch.cuda.is_available(): if torch.cuda.is_available():
# with torch.cuda.device(DEVICE): # with torch.cuda.device(DEVICE):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -498,17 +499,18 @@ def set_httpx_config(
# 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch # 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
def is_mps_available() -> bool:
    """Return True when torch exposes an MPS backend and reports it usable.

    torch is imported inside the function so this helper does not itself
    depend on a module-level ``import torch``; the surrounding section is
    documented as supporting machines where torch is not installed.

    Raises:
        ImportError: if torch is not installed.
    """
    import torch

    # Presumably the hasattr guard covers torch builds that predate
    # ``torch.backends.mps`` -- confirm against the minimum supported version.
    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
def is_cuda_available() -> bool:
    """Return whatever ``torch.cuda.is_available()`` reports.

    torch is imported inside the function so this helper does not itself
    depend on a module-level ``import torch``; the surrounding section is
    documented as supporting machines where torch is not installed.

    Raises:
        ImportError: if torch is not installed.
    """
    import torch

    return torch.cuda.is_available()
def detect_device() -> Literal["cuda", "mps", "cpu"]: def detect_device() -> Literal["cuda", "mps", "cpu"]:
try: try:
if is_cuda_available(): if torch.cuda.is_available():
return "cuda" return "cuda"
if is_mps_available(): if is_mps_available():
return "mps" return "mps"
@ -516,11 +518,31 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
pass pass
return "cpu" return "cpu"
def llm_device(device: Optional[str] = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
    """Resolve the device name to use for the LLM.

    Args:
        device: Requested device name. When falsy, the configured
            ``LLM_DEVICE`` is used instead.

    Returns:
        The resolved device name:

        * an unrecognised name is logged and replaced by ``detect_device()``;
        * ``"cuda"`` becomes ``"mps"`` when CUDA is unavailable but MPS is;
        * ``"mps"`` becomes ``"cuda"`` when MPS is unavailable but CUDA is;
        * anything else (including ``"cpu"`` and ``"xpu"``, and a requested
          accelerator with no available alternative) is returned unchanged.
    """
    device = device or LLM_DEVICE

    if device not in ("cuda", "mps", "cpu", "xpu"):
        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
        # The detected device needs no further validation, so return it
        # directly instead of re-checking membership afterwards.
        return detect_device()

    if device == "cuda" and not is_cuda_available() and is_mps_available():
        logging.warning("cuda is not available, fallback to mps")
        return "mps"

    if device == "mps" and not is_mps_available() and is_cuda_available():
        logging.warning("mps is not available, fallback to cuda")
        return "cuda"

    # NOTE: "xpu" is accepted as-is; no availability probe exists for it here.
    return device
def embedding_device(device: Optional[str] = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
    """Resolve the device name to use for the embedding model.

    Args:
        device: Requested device name. When falsy, the configured
            ``EMBEDDING_DEVICE`` is used instead (not ``LLM_DEVICE``, so the
            embedding model can be placed independently of the LLM).

    Returns:
        The resolved device name:

        * an unrecognised name is logged and replaced by ``detect_device()``;
        * ``"cuda"`` becomes ``"mps"`` when CUDA is unavailable but MPS is;
        * ``"mps"`` becomes ``"cuda"`` when MPS is unavailable but CUDA is;
        * anything else (including ``"cpu"`` and ``"xpu"``, and a requested
          accelerator with no available alternative) is returned unchanged.
    """
    device = device or EMBEDDING_DEVICE

    # "xpu" is part of the accepted set, matching both the return annotation
    # and the warning text below.
    if device not in ("cuda", "mps", "cpu", "xpu"):
        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
        return detect_device()

    if device == "cuda" and not is_cuda_available() and is_mps_available():
        logging.warning("cuda is not available, fallback to mps")
        return "mps"

    if device == "mps" and not is_mps_available() and is_cuda_available():
        logging.warning("mps is not available, fallback to cuda")
        return "cuda"

    # NOTE: "xpu" is accepted as-is; no availability probe exists for it here.
    return device