diff --git a/utils/__init__.py b/utils/__init__.py index 8508c7d..37deccf 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,5 +1,4 @@ import torch.cuda -import torch.mps import torch.backends def torch_gc(DEVICE): @@ -8,4 +7,9 @@ def torch_gc(DEVICE): torch.cuda.empty_cache() torch.cuda.ipc_collect() elif torch.backends.mps.is_available(): - torch.mps.empty_cache() \ No newline at end of file + try: + import torch.mps + torch.mps.empty_cache() + except Exception as e: + print(e) + print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。") \ No newline at end of file