|
import torch.cuda
|
|
import torch.mps
|
|
import torch.backends
|
|
|
|
def torch_gc(DEVICE):
    """Ask the active PyTorch accelerator backend to empty its cache.

    CUDA takes precedence: when ``torch.cuda.is_available()`` is true,
    ``torch.cuda.empty_cache()`` and ``torch.cuda.ipc_collect()`` are
    called inside a ``torch.cuda.device(DEVICE)`` context. Otherwise, if
    the MPS backend reports itself available, ``torch.mps.empty_cache()``
    is called. When neither backend is available nothing happens.

    Args:
        DEVICE: Passed unchanged to ``torch.cuda.device``. It is only
            consulted on the CUDA path; the MPS path does not use it.

    Returns:
        None.
    """
    if not torch.cuda.is_available():
        # No CUDA: fall back to MPS if present, otherwise do nothing.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        return

    # Make DEVICE the current CUDA device for the duration of both calls.
    with torch.cuda.device(DEVICE):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()