Remove AnswerResultStream and the generate_with_callback collector

This commit is contained in:
glide-the 2023-05-25 21:07:40 +08:00
parent e7b06a9072
commit c4ee36b8ac
7 changed files with 21 additions and 204 deletions
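After this change, each model wrapper implements generatorAnswer as a plain Python generator that yields AnswerResult objects directly, instead of pushing them through an AnswerResultStream callback wrapped by Iteratorize. A minimal sketch of the resulting interface (EchoLLM is illustrative only; AnswerResult, generatorAnswer, history and llm_output are the names kept by this commit):

from typing import List, Optional


class AnswerResult:
    """Message entity kept by this commit (the listenerToken field is removed)."""
    history: List[List[str]] = []
    llm_output: Optional[dict] = None


class EchoLLM:
    """Stand-in for the real wrappers (ChatGLM, FastChatLLM, LLamaLLM, MOSSLLM)."""

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):
        # Yield results directly; no callback collector or worker thread is needed.
        response = "echo: " + prompt
        answer_result = AnswerResult()
        answer_result.history = history + [[prompt, response]]
        answer_result.llm_output = {"answer": response}
        yield answer_result


# Callers simply iterate over the generator:
for answer_result in EchoLLM().generatorAnswer("hello"):
    print(answer_result.llm_output["answer"])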

View File

@@ -12,9 +12,7 @@ from tqdm import tqdm
 from pypinyin import lazy_pinyin
 from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint
 import models.shared as shared

View File

@@ -10,142 +10,12 @@ import transformers
 from models.loader import LoaderCheckPoint
 
-
-class ListenerToken:
-    """
-    Observed result
-    """
-    input_ids: torch.LongTensor
-    _scores: torch.FloatTensor
-
-    def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
-        self.input_ids = input_ids
-        self._scores = _scores
-
 
 class AnswerResult:
     """
     Message entity
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
-    listenerToken: ListenerToken = None
-
-
-class AnswerResultStream:
-    def __init__(self, callback_func=None):
-        self.callback_func = callback_func
-
-    def __call__(self, answerResult: AnswerResult):
-        if self.callback_func is not None:
-            self.callback_func(answerResult)
-
-
-class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
-    """
-    Defines a stopping_criteria listener for the model; on every generation step it mirrors the queued data into AnswerResult.
-    This listener exists because different models may not expose vector information in their prediction output; the HF framework
-    allows a custom transformers.StoppingCriteria to receive the Tensor and scores of every prediction step.
-    A StoppingCriteriaList tells the model when to stop generating an answer; each StoppingCriteria object represents one stop condition.
-    On every prediction step all StoppingCriteria receive the same prediction result, and the concrete subclass decides whether to stop.
-    The captured values can be passed as custom observation parameters to generatorAnswer / generate_with_streaming for finer-grained control.
-    """
-    listenerQueue: deque = deque(maxlen=1)
-
-    def __init__(self):
-        transformers.StoppingCriteria.__init__(self)
-
-    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
-        """
-        Append the data to the listener queue on every generation step
-        :param input_ids:
-        :param _scores:
-        :param kwargs:
-        :return:
-        """
-        self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
-        return False
-
-
-class Iteratorize:
-    """
-    Transforms a function that takes a callback
-    into a lazy iterator (generator).
-    """
-
-    def __init__(self, func, kwargs={}):
-        self.mfunc = func
-        self.q = Queue()
-        self.sentinel = object()
-        self.kwargs = kwargs
-        self.stop_now = False
-
-        def _callback(val):
-            """
-            Collects the model's prediction output.
-            A generate_with_callback collector (AnswerResultStream) gathers the AnswerResult responses wrapped by the model,
-            and the downstream implementation decides when the run ends. The end conditions are:
-            1. The model finishes predicting and the collector queue self.q receives the self.sentinel marker.
-            2. The loop consuming the iterator breaks, which raises StopIteration.
-            3. The model prediction raises an error.
-            Because this class is an iterator, __exit__ runs after a break inside "for ... in"; stop_now is then updated
-            and an exception is raised to end the prediction.
-            The collection flow is:
-            1. Create an Iteratorize iterator object.
-            2. Define the generate_with_callback collector (AnswerResultStream).
-            3. Start a thread that asynchronously calls the upstream checkpoint implementation _generate_answer.
-            4. _generate_answer uses the generate_with_callback collector to gather the AnswerResult messages wrapped by the upstream checkpoint.
-            5. Because self.q blocks, each prediction is consumed before the next one runs, so generate_with_callback blocks as well.
-            6. The main thread's Iteratorize.__next__ fetches and consumes the blocked message:
-               - an AnswerResult wrapped by the upstream checkpoint is returned downstream;
-               - the self.sentinel marker raises StopIteration.
-            7. The main thread's Iteratorize.__exit__ receives the event and updates stop_now.
-            8. The async thread detects the updated stop_now, raises, and the prediction ends.
-            :param val:
-            :return:
-            """
-            if self.stop_now:
-                raise ValueError
-            self.q.put(val)
-
-        def gen():
-            try:
-                ret = self.mfunc(callback=_callback, **self.kwargs)
-            except ValueError:
-                pass
-            except:
-                traceback.print_exc()
-                pass
-            self.q.put(self.sentinel)
-
-        self.thread = Thread(target=gen)
-        self.thread.start()
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        obj = self.q.get(True, None)
-        if obj is self.sentinel:
-            raise StopIteration
-        else:
-            return obj
-
-    def __del__(self):
-        """
-        Not implemented
-        :return:
-        """
-        pass
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        """ Runs after break """
-        self.stop_now = True
-
 
 class BaseAnswer(ABC):
@@ -168,22 +38,4 @@ class BaseAnswer(ABC):
     def generatorAnswer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False):
-        def generate_with_callback(callback=None, **kwargs):
-            kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
-            self._generate_answer(**kwargs)
-
-        def generate_with_streaming(**kwargs):
-            return Iteratorize(generate_with_callback, kwargs)
-
-        with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
-            for answerResult in generator:
-                if answerResult.listenerToken:
-                    output = answerResult.listenerToken.input_ids
-                yield answerResult
-
-    @abstractmethod
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
         pass
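For context, the Iteratorize helper deleted above implemented exactly the pattern its docstring describes: a worker thread runs the callback-driven function, each callback value is pushed into a blocking queue, and the main thread iterates over that queue until a sentinel arrives. A stripped-down sketch of that pattern, for illustration only (CallbackToIterator is a hypothetical name, not the removed class verbatim):

import traceback
from queue import Queue
from threading import Thread


class CallbackToIterator:
    """Wrap a function that reports results via a callback as a lazy iterator."""

    def __init__(self, func, kwargs=None):
        self.q = Queue()
        self.sentinel = object()
        self.stop_now = False

        def _callback(val):
            if self.stop_now:
                raise ValueError  # unwind the producer once the consumer has stopped
            self.q.put(val)

        def _worker():
            try:
                func(callback=_callback, **(kwargs or {}))
            except ValueError:
                pass  # consumer stopped early
            except Exception:
                traceback.print_exc()
            finally:
                self.q.put(self.sentinel)  # signal end of stream

        Thread(target=_worker, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get()  # blocks until the worker produces a value
        if obj is self.sentinel:
            raise StopIteration
        return obj

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True  # reached after break / StopIteration / error

After this commit none of that machinery is needed: the wrappers are generators themselves, so Python's iteration protocol replaces the queue, the sentinel and the worker thread.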

View File

@@ -5,9 +5,7 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import transformers
@@ -43,15 +41,9 @@ class ChatGLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
 
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
-        # Create the StoppingCriteriaList with the stopping strings
-        stopping_criteria_list = transformers.StoppingCriteriaList()
-        # Define the model's stopping_criteria queue; on every step it mirrors the torch.LongTensor, torch.FloatTensor into AnswerResult
-        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
-        stopping_criteria_list.append(listenerQueue)
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         if streaming:
             history += [[]]
@@ -60,34 +52,27 @@ class ChatGLM(BaseAnswer, LLM, ABC):
                     prompt,
                     history=history[-self.history_len:-1] if self.history_len > 0 else [],
                     max_length=self.max_token,
-                    temperature=self.temperature,
-                    stopping_criteria=stopping_criteria_list
+                    temperature=self.temperature
             )):
                 # self.checkPoint.clear_torch_cache()
                 history[-1] = [prompt, stream_resp]
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                if listenerQueue.listenerQueue.__len__() > 0:
-                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
-                generate_with_callback(answer_result)
+                yield answer_result
         else:
             response, _ = self.checkPoint.model.chat(
                 self.checkPoint.tokenizer,
                 prompt,
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
-                temperature=self.temperature,
-                stopping_criteria=stopping_criteria_list
+                temperature=self.temperature
             )
             self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
-            generate_with_callback(answer_result)
+            yield answer_result
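The lines removed above also drop the custom stopping-criteria hook that mirrored each generation step's tensors into the result. For reference, the listener pattern it used can be expressed with transformers.StoppingCriteria roughly as follows (a hedged sketch; TokenListener is a hypothetical name, and whether a model's chat()/stream_chat() accepts a stopping_criteria argument depends on that model's remote code):

from collections import deque

import torch
import transformers


class TokenListener(transformers.StoppingCriteria):
    """Record the latest (input_ids, scores) pair without ever stopping generation."""

    def __init__(self):
        super().__init__()
        self.latest = deque(maxlen=1)

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        self.latest.append((input_ids, scores))
        return False  # only observe; never request a stop


listener = TokenListener()
stopping_criteria = transformers.StoppingCriteriaList([listener])
# stopping_criteria would then be passed to the model's generate()/chat() call.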

View File

@@ -5,9 +5,7 @@ from langchain.llms.base import LLM
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 
 
 class FastChatLLM(BaseAnswer, LLM, ABC):
@@ -40,10 +38,9 @@ class FastChatLLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
 
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         response = "fastchat 响应结果"
         history += [[prompt, response]]
@@ -51,4 +48,4 @@ class FastChatLLM(BaseAnswer, LLM, ABC):
         answer_result.history = history
         answer_result.llm_output = {"answer": response}
-        generate_with_callback(answer_result)
+        yield answer_result

View File

@@ -9,9 +9,7 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from typing import Optional, List, Dict, Any
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):
@@ -178,23 +176,15 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         self.history = self.history + [[None, reply]]
         return reply
 
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         if history:
             self.history = history
-        # Create the StoppingCriteriaList with the stopping strings
-        self.stopping_criteria = transformers.StoppingCriteriaList()
-        # Define the model's stopping_criteria queue; on every step it mirrors the torch.LongTensor, torch.FloatTensor into AnswerResult
-        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
-        self.stopping_criteria.append(listenerQueue)
         # TODO: A chat dialogue module and attention handling still need to be implemented. _call is currently the API extended from langchain's LLM and defaults to prompt-free mode; see chat_glm's implementation if the attention model needs to be manipulated.
         softprompt = self.generate_softprompt_history_tensors(prompt)
         response = self._call(prompt=softprompt, stop=['\n###'])
         answer_result = AnswerResult()
         answer_result.history = self.history
-        if listenerQueue.listenerQueue.__len__() > 0:
-            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
         answer_result.llm_output = {"answer": response}
-        generate_with_callback(answer_result)
+        yield answer_result

View File

@@ -3,9 +3,7 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import torch
@@ -53,10 +51,9 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
 
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+    def generatorAnswer(self, prompt: str,
+                        history: List[List[str]] = [],
+                        streaming: bool = False):
         if len(history) > 0:
             history = history[-self.history_len:-1] if self.history_len > 0 else []
             prompt_w_history = str(history)
@@ -86,6 +83,6 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
         answer_result.history = history
         answer_result.llm_output = {"answer": response}
-        generate_with_callback(answer_result)
+        yield answer_result

View File

@@ -6,9 +6,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import models.shared as shared
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint