update model_loader

This commit is contained in:
imClumsyPanda 2023-05-21 15:30:52 +08:00
parent 3712eec6a9
commit 871a871651
2 changed files with 13 additions and 10 deletions

View File

@ -1,9 +1,9 @@
import gc
# import gc
import traceback
from queue import Queue
from threading import Thread
import threading
from typing import Optional, List, Dict, Any
# from threading import Thread
# import threading
from typing import Optional, List, Dict, Any, TypeVar, Deque
from collections import deque
import torch
import transformers
@ -12,13 +12,16 @@ from models.extensions.thread_with_exception import ThreadWithException
import models.shared as shared
class LimitedLengthDict(dict):
K = TypeVar('K')
V = TypeVar('V')
class LimitedLengthDict(Dict[K, V]):
def __init__(self, maxlen=None, *args, **kwargs):
self.maxlen = maxlen
self._keys = deque()
self._keys: Deque[K] = deque()
super().__init__(*args, **kwargs)
def __setitem__(self, key, value):
def __setitem__(self, key: K, value: V):
if key not in self:
if self.maxlen is not None and len(self) >= self.maxlen:
oldest_key = self._keys.popleft()

View File

@ -139,9 +139,9 @@ class LoaderCheckPoint:
model = dispatch_model(model, device_map=self.device_map)
else:
print(
"Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "
"detected.\nFalling back to CPU mode.\n")
# print(
# "Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been "
# "detected.\nFalling back to CPU mode.\n")
model = (
AutoModel.from_pretrained(
checkpoint,