14 lines
512 B
Python
14 lines
512 B
Python
import torch
|
|
|
|
def torch_gc():
|
|
if torch.cuda.is_available():
|
|
# with torch.cuda.device(DEVICE):
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.ipc_collect()
|
|
elif torch.backends.mps.is_available():
|
|
try:
|
|
from torch.mps import empty_cache
|
|
empty_cache()
|
|
except Exception as e:
|
|
print(e)
|
|
print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。") |