diff --git a/server/utils.py b/server/utils.py
index 376ddd8..5512ee6 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -10,10 +10,11 @@ from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI
-from langchain.llms import OpenAI, AzureOpenAI, Anthropic
+from langchain.llms import OpenAI
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
 import logging
+import torch
 
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -58,6 +59,7 @@ def get_ChatOpenAI(
     )
     return model
 
+
 def get_OpenAI(
     model_name: str,
     temperature: float,
@@ -151,7 +153,6 @@ class ChatMessage(BaseModel):
 
 def torch_gc():
     try:
-        import torch
         if torch.cuda.is_available():
             # with torch.cuda.device(DEVICE):
             torch.cuda.empty_cache()
@@ -498,17 +499,18 @@ def set_httpx_config(
 
 # Automatically detect the device available to torch. In a distributed deployment, machines that do not run the LLM can skip installing torch.
 
+
 def is_mps_available():
-    import torch
     return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
 
+
 def is_cuda_available():
-    import torch
     return torch.cuda.is_available()
 
+
 def detect_device() -> Literal["cuda", "mps", "cpu"]:
     try:
-        if is_cuda_available():
+        if torch.cuda.is_available():
             return "cuda"
         if is_mps_available():
             return "mps"
@@ -516,11 +518,31 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
         pass
     return "cpu"
 
 
-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
-    device = device or LLM_DEVICE
-    # fallback to available device if specified device is not available
-    if device == 'cuda' and not is_cuda_available() and is_mps_available():
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
+    device = device or LLM_DEVICE
+    if device not in ["cuda", "mps", "cpu", "xpu"]:
+        logging.warning(f"device not in ['cuda', 'mps', 'cpu', 'xpu'], device = {device}")
+        device = detect_device()
+    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
+        logging.warning("cuda is not available, fallback to mps")
+        return "mps"
+    if device == 'mps' and not is_mps_available() and is_cuda_available():
+        logging.warning("mps is not available, fallback to cuda")
+        return "cuda"
+
+    # auto detect device if not specified
+    if device not in ["cuda", "mps", "cpu", "xpu"]:
+        return detect_device()
+    return device
+
+
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+    device = device or EMBEDDING_DEVICE
+    if device not in ["cuda", "mps", "cpu"]:
+        logging.warning(f"device not in ['cuda', 'mps', 'cpu'], device = {device}")
+        device = detect_device()
+    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
         logging.warning("cuda is not available, fallback to mps")
         return "mps"
     if device == 'mps' and not is_mps_available() and is_cuda_available():
@@ -530,14 +552,6 @@ def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
     # auto detect device if not specified
     if device not in ["cuda", "mps", "cpu"]:
         return detect_device()
-
-    return device
-
-
-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
-    device = device or EMBEDDING_DEVICE
-    if device not in ["cuda", "mps", "cpu"]:
-        device = detect_device()
     return device
 
 
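Usage sketch (not part of the patch): with `torch` now imported at module level, `torch_gc()` can be called directly after dropping model references. The `Linear` stand-in model and the `server.utils` import path are illustrative assumptions.

```python
import torch
from server.utils import torch_gc

if torch.cuda.is_available():
    model = torch.nn.Linear(1024, 1024).to("cuda")  # stand-in for a real model
    del model    # release the last reference so the allocator can reclaim it
    torch_gc()   # empties the CUDA cache via torch.cuda.empty_cache()
```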
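And a minimal sketch of the fallback behavior the new `llm_device` / `embedding_device` implement, assuming the repo root is on `sys.path`; the printed values depend on the host's hardware and the configured defaults.

```python
from server.utils import llm_device, embedding_device, detect_device

# No argument: fall back to the configured LLM_DEVICE / EMBEDDING_DEVICE,
# then to auto-detection if that value is not a recognized device string.
print(llm_device())          # e.g. "cuda" on a GPU host, "mps" on Apple silicon

# An unrecognized string logs a warning and defers to detect_device().
print(llm_device("auto"))    # warning, then whatever detect_device() returns

# Requesting cuda without CUDA falls back to mps when mps is available,
# and vice versa.
print(embedding_device("cuda"))  # "mps" on Apple silicon without CUDA
print(detect_device())           # "cuda", "mps", or "cpu"
```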