221 lines
6.7 KiB
Python
221 lines
6.7 KiB
Python
|
|
import gc
|
|||
|
|
import traceback
|
|||
|
|
from queue import Queue
|
|||
|
|
from threading import Thread
|
|||
|
|
import threading
|
|||
|
|
from typing import Optional, List, Dict, Any
|
|||
|
|
from collections import deque
|
|||
|
|
import torch
|
|||
|
|
import transformers
|
|||
|
|
|
|||
|
|
from models.extensions.thread_with_exception import ThreadWithException
|
|||
|
|
import models.shared as shared
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LimitedLengthDict(dict):
    """Dictionary that holds at most ``maxlen`` entries, evicting the oldest.

    "Oldest" means first-inserted: eviction follows the dict's own insertion
    order, so it stays correct however entries were added or removed
    (constructor arguments, ``update``, ``del``, ``pop`` ...). Overwriting an
    existing key keeps its original position and never triggers eviction.

    Args:
        maxlen: Maximum number of entries, or ``None`` for no limit. Like
            ``collections.deque``, a limit of ``0`` silently discards every
            inserted item and a negative limit raises ``ValueError``.
        *args, **kwargs: Forwarded unchanged to ``dict``.

    Raises:
        ValueError: If ``maxlen`` is negative.
    """

    def __init__(self, maxlen=None, *args, **kwargs):
        if maxlen is not None and maxlen < 0:
            raise ValueError("maxlen must be non-negative")
        self.maxlen = maxlen
        super().__init__(*args, **kwargs)
        # dict.__init__ does not route through __setitem__, so initial data
        # has to be trimmed to the limit explicitly.
        if maxlen is not None:
            self._evict_down_to(maxlen)

    def _evict_down_to(self, limit):
        """Remove first-inserted entries until at most ``limit`` remain."""
        while self and len(self) > limit:
            super().__delitem__(next(iter(self)))

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``, making room for a new key if needed."""
        if self.maxlen is not None and key not in self:
            if self.maxlen == 0:
                # Nothing can ever be stored; mirror deque(maxlen=0).
                return
            # Leave exactly one free slot for the incoming key.
            self._evict_down_to(self.maxlen - 1)
        super().__setitem__(key, value)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FixedLengthQueue:
    """Collect indexed text fragments and look for stop strings near the tail.

    Fragments are stored with ``add``, cleaned into the output mapping by
    ``contains_replace_sequence`` and then scanned by
    ``contains_stop_sequence``, which only inspects the last ``max_length``
    characters of the cleaned output.

    Args:
        stop_sequence: ``None`` (no stop strings), a single stop string, or
            an iterable of stop strings.
    """

    # Stop strings to search for.
    stop_sequence: List[str]
    # Width, in characters, of the tail window that is searched.
    max_length: int = 0
    # Holds the characters of the tail window from the most recent scan.
    queue: Optional[deque] = None
    # Raw fragments, keyed by the index passed to ``add``.
    queue_in: Dict[int, str]
    # Cleaned fragments, keyed like ``queue_in``.
    queue_out: Dict[int, str]

    def __init__(self, stop_sequence):
        if stop_sequence is None:
            self.stop_sequence = []
        elif isinstance(stop_sequence, str):
            self.stop_sequence = [stop_sequence]
        else:
            self.stop_sequence = stop_sequence

        # The window must be able to hold any stop string in full; the
        # combined length of all stop strings is used as its width.
        self.max_length = len(''.join(self.stop_sequence))

        self.queue = deque(maxlen=self.max_length)
        # Per-instance containers: sharing them at class level would let one
        # instance wipe another instance's fragments.
        self.queue_in = {}
        self.queue_out = {}

    def add(self, index, item):
        """Store the raw fragment ``item`` under ``index``."""
        self.queue_in[index] = item

    def _add_out(self, index, item):
        """Store the cleaned fragment ``item`` under ``index``."""
        self.queue_out[index] = item

    def put_replace_out(self, index):
        """Return the cleaned fragment stored under ``index``.

        Raises:
            KeyError: If no cleaned fragment exists for ``index``.
        """
        return self.queue_out[index]

    def contains_replace_sequence(self):
        """Clean every raw fragment and copy it into the output mapping.

        Square brackets are stripped from each fragment. Nothing is returned;
        results are read back with ``put_replace_out``.
        """
        for key, value in self.queue_in.items():
            # NOTE(review): as written this colon replacement maps ":" to
            # itself and changes nothing; presumably it was meant to
            # normalise a different colon character - confirm.
            cleaned = value.replace(":", ":")
            cleaned = cleaned.replace("[", "").replace("]", "")
            self._add_out(key, cleaned)

    def contains_stop_sequence(self):
        """Search the tail of the cleaned output for any stop string.

        The last ``max_length`` cleaned fragments are joined and truncated to
        their final ``max_length`` characters; that window is left in
        ``self.queue``.

        Returns:
            The highest ``str.rfind`` position of any stop string inside the
            window, or ``-1`` when none of them occurs.
        """
        width = self.max_length
        self.queue.clear()
        if width:
            recent = list(self.queue_out.values())[-width:]
            # The bounded deque keeps only the final ``width`` characters.
            self.queue.extend(''.join(recent))

        tail = ''.join(self.queue)
        return max(
            (tail.rfind(stop_word) for stop_word in self.stop_sequence),
            default=-1,
        )

    def __repr__(self):
        return str(self.queue)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Copied from https://github.com/PygmalionAI/gradio-ui/
|
|||
|
|
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    """Stop generation once any sentinel token sequence has been produced.

    Args:
        sentinel_token_ids: Token-id tensors to look for. Each is indexed
            with ``[0]`` and measured with ``shape[-1]`` below, so a shape of
            ``(1, n)`` is assumed - TODO confirm against the caller.
        starting_idx: Position in each sample from which to start searching;
            everything before it is ignored.
    """

    def __init__(self, sentinel_token_ids: list, starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
        """Return ``True`` if any sample contains a sentinel sequence.

        Extra keyword arguments are accepted and ignored, matching
        ``Stream.__call__`` in this module.
        """
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]

            for sentinel in self.sentinel_token_ids:
                width = sentinel.shape[-1]
                # Can't unfold, output is still too tiny. Skip.
                if trimmed_sample.shape[-1] < width:
                    continue
                # Slide a window of the sentinel's length over the sample.
                for window in trimmed_sample.unfold(0, width, 1):
                    if torch.all(torch.eq(sentinel[0], window)):
                        return True
        return False
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Stream(transformers.StoppingCriteria):
    """Stopping criterion that reports progress instead of stopping.

    Each call hands the first sequence of ``input_ids`` to ``callback_func``
    (when one was given) and returns ``False``. If
    ``shared.stop_everything`` is set, ``ValueError`` is raised instead.
    """

    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Abort first, so the callback is not invoked after a stop request.
        if shared.stop_everything:
            raise ValueError

        notify = self.callback_func
        if notify is not None:
            notify(input_ids[0])

        # This criterion never requests a stop by itself.
        return False
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).

    ``func`` is started immediately on a worker thread and called as
    ``func(callback=..., **kwargs)``. Every value it passes to the callback
    is queued and handed out by iterating over this object; iteration ends
    once ``func`` returns or raises.

    Args:
        func: Callable accepting a ``callback`` keyword argument.
        kwargs: Extra keyword arguments for ``func``; ``None`` means none.
        callback: Stored as ``c_callback``; not used by this class itself.
    """

    # Worker thread running ``func``; set in ``__init__``.
    thread: Optional["ThreadWithException"] = None

    def __init__(self, func, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        # Unique marker placed on the queue after the last produced value.
        self.sentinel = object()
        # A fresh dict per instance instead of a shared mutable default.
        self.kwargs = {} if kwargs is None else kwargs

        def _callback(val):
            # Raising here unwinds ``func`` when a stop has been requested.
            if shared.stop_everything:
                raise ValueError
            self.q.put(val)

        def gen():
            try:
                self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                print("print(ValueError)")
            except Exception:
                traceback.print_exc()
                print("traceback.print_exc()")
            finally:
                # Always release a consumer blocked in ``__next__``, even
                # when the worker is torn down by a non-Exception error.
                self.q.put(self.sentinel)

        self.thread = ThreadWithException(target=gen)
        self.thread.start()

    def __iter__(self):
        # Starting an iteration clears any earlier stop request.
        shared.stop_everything = False
        return self

    def __next__(self):
        """Block until the next value is available and return it.

        Raises:
            StopIteration: When the worker has finished.
        """
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        return obj

    def __del__(self):
        shared.stop_everything = False
        # The former ``self.q.empty()`` call here only returned a bool and
        # cleared nothing, so it has been dropped.
        shared.loaderCheckPoint.clear_torch_cache()

    def __enter__(self):
        shared.stop_everything = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Ask the producer to stop, release cached memory, then interrupt
        # the worker thread (presumably by injecting an exception into it -
        # confirm against ThreadWithException).
        shared.stop_everything = True
        shared.loaderCheckPoint.clear_torch_cache()
        self.thread.raise_exception()
|