Fix device detection and fallback logic and add 'xpu' (#2570)

Co-authored-by: zR <2448370773@qq.com>
This commit is contained in:
chatgpt-bot 2024-01-11 18:36:38 +08:00 committed by GitHub
parent e7bba6bd0a
commit b653c25fbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 31 additions and 17 deletions

View File

@ -10,10 +10,11 @@ from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI, AzureOpenAI, Anthropic
from langchain.llms import OpenAI
import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
import logging
import torch
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@ -58,6 +59,7 @@ def get_ChatOpenAI(
)
return model
def get_OpenAI(
model_name: str,
temperature: float,
@ -151,7 +153,6 @@ class ChatMessage(BaseModel):
def torch_gc():
try:
import torch
if torch.cuda.is_available():
# with torch.cuda.device(DEVICE):
torch.cuda.empty_cache()
@ -498,17 +499,18 @@ def set_httpx_config(
# 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
def is_mps_available():
import torch
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
def is_cuda_available():
import torch
return torch.cuda.is_available()
def detect_device() -> Literal["cuda", "mps", "cpu"]:
try:
if is_cuda_available():
if torch.cuda.is_available():
return "cuda"
if is_mps_available():
return "mps"
@ -516,11 +518,31 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
pass
return "cpu"
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
device = device or LLM_DEVICE
# fallback to available device if specified device is not available
if device == 'cuda' and not is_cuda_available() and is_mps_available():
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
device = device or LLM_DEVICE
if device not in ["cuda", "mps", "cpu", "xpu"]:
logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
device = detect_device()
elif device == 'cuda' and not is_cuda_available() and is_mps_available():
logging.warning("cuda is not available, fallback to mps")
return "mps"
if device == 'mps' and not is_mps_available() and is_cuda_available():
logging.warning("mps is not available, fallback to cuda")
return "cuda"
# auto detect device if not specified
if device not in ["cuda", "mps", "cpu", "xpu"]:
return detect_device()
return device
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
device = device or LLM_DEVICE
if device not in ["cuda", "mps", "cpu"]:
logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
device = detect_device()
elif device == 'cuda' and not is_cuda_available() and is_mps_available():
logging.warning("cuda is not available, fallback to mps")
return "mps"
if device == 'mps' and not is_mps_available() and is_cuda_available():
@ -530,14 +552,6 @@ def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
# auto detect device if not specified
if device not in ["cuda", "mps", "cpu"]:
return detect_device()
return device
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
device = device or EMBEDDING_DEVICE
if device not in ["cuda", "mps", "cpu"]:
device = detect_device()
return device