From 871a871651aeeebae4c942c65e647a1327398009 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sun, 21 May 2023 15:30:52 +0800 Subject: [PATCH] update model_loader --- models/extensions/callback.py | 17 ++++++++++------- models/loader/loader.py | 6 +++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/models/extensions/callback.py b/models/extensions/callback.py index 097caa8..7a17455 100644 --- a/models/extensions/callback.py +++ b/models/extensions/callback.py @@ -1,9 +1,9 @@ -import gc +# import gc import traceback from queue import Queue -from threading import Thread -import threading -from typing import Optional, List, Dict, Any +# from threading import Thread +# import threading +from typing import Optional, List, Dict, Any, TypeVar, Deque from collections import deque import torch import transformers @@ -12,13 +12,16 @@ from models.extensions.thread_with_exception import ThreadWithException import models.shared as shared -class LimitedLengthDict(dict): +K = TypeVar('K') +V = TypeVar('V') + +class LimitedLengthDict(Dict[K, V]): def __init__(self, maxlen=None, *args, **kwargs): self.maxlen = maxlen - self._keys = deque() + self._keys: Deque[K] = deque() super().__init__(*args, **kwargs) - def __setitem__(self, key, value): + def __setitem__(self, key: K, value: V): if key not in self: if self.maxlen is not None and len(self) >= self.maxlen: oldest_key = self._keys.popleft() diff --git a/models/loader/loader.py b/models/loader/loader.py index b6608f3..fffc66e 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -139,9 +139,9 @@ class LoaderCheckPoint: model = dispatch_model(model, device_map=self.device_map) else: - print( - "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been " - "detected.\nFalling back to CPU mode.\n") + # print( + # "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been " + # "detected.\nFalling back to CPU mode.\n") model = ( AutoModel.from_pretrained( checkpoint,