diff --git a/utils/__init__.py b/utils/__init__.py index 37deccf..6cde439 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,5 +1,4 @@ -import torch.cuda -import torch.backends +import torch def torch_gc(DEVICE): if torch.cuda.is_available(): @@ -8,7 +7,6 @@ def torch_gc(DEVICE): torch.cuda.ipc_collect() elif torch.backends.mps.is_available(): try: - import torch.mps torch.mps.empty_cache() except Exception as e: print(e)