diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index ecd5e88..7bde8a0 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -12,9 +12,7 @@ from tqdm import tqdm
 from pypinyin import lazy_pinyin
 from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint
 import models.shared as shared
diff --git a/models/base.py b/models/base.py
index 648b087..b0fb498 100644
--- a/models/base.py
+++ b/models/base.py
@@ -10,142 +10,12 @@ import transformers
 from models.loader import LoaderCheckPoint
 
 
-class ListenerToken:
-    """
-    观测结果
-    """
-
-    input_ids: torch.LongTensor
-    _scores: torch.FloatTensor
-
-    def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
-        self.input_ids = input_ids
-        self._scores = _scores
-
-
 class AnswerResult:
     """
     消息实体
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
-    listenerToken: ListenerToken = None
-
-
-class AnswerResultStream:
-    def __init__(self, callback_func=None):
-        self.callback_func = callback_func
-    def __call__(self, answerResult: AnswerResult):
-        if self.callback_func is not None:
-            self.callback_func(answerResult)
-
-
-class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
-    """
-    定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult
-    实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数,
-    通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件
-    当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束
-    输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制
-    """
-
-    listenerQueue: deque = deque(maxlen=1)
-
-    def __init__(self):
-        transformers.StoppingCriteria.__init__(self)
-
-    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
-        """
-        每次响应时将数据添加到响应队列
-        :param input_ids:
-        :param _scores:
-        :param kwargs:
-        :return:
-        """
-        self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
-        return False
-
-
-class Iteratorize:
-
-    """
-    Transforms a function that takes a callback
-    into a lazy iterator (generator).
-    """
-
-    def __init__(self, func, kwargs={}):
-        self.mfunc = func
-        self.q = Queue()
-        self.sentinel = object()
-        self.kwargs = kwargs
-        self.stop_now = False
-
-        def _callback(val):
-            """
-            模型输出预测结果收集
-            通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束
-            结束条件包含如下
-            1、模型预测结束、收集器self.q队列收到 self.sentinel标识
-            2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件
-            3、模型预测出错
-            因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为
-            迭代器收集的行为如下
-            创建Iteratorize迭代对象,
-            定义generate_with_callback收集器AnswerResultStream
-            启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
-            _generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体
-            由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测
-            这时generate_with_callback会被阻塞
-            主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
-            1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理
-            2、消息为self.sentinel标识,抛出StopIteration异常
-            主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新
-            异步线程检测stop_now属性被更新,抛出异常结束预测行为
-            迭代行为结束
-            :param val:
-            :return:
-            """
-            if self.stop_now:
-                raise ValueError
-            self.q.put(val)
-
-        def gen():
-            try:
-                ret = self.mfunc(callback=_callback, **self.kwargs)
-            except ValueError:
-                pass
-            except:
-                traceback.print_exc()
-                pass
-
-            self.q.put(self.sentinel)
-
-        self.thread = Thread(target=gen)
-        self.thread.start()
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        obj = self.q.get(True, None)
-        if obj is self.sentinel:
-            raise StopIteration
-        else:
-            return obj
-
-    def __del__(self):
-        """
-        暂无实现
-        :return:
-        """
-        pass
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        """ break 后会执行 """
-        self.stop_now = True
 
 
 class BaseAnswer(ABC):
@@ -168,22 +38,4 @@ class BaseAnswer(ABC):
     def generatorAnswer(self, prompt: str,
                          history: List[List[str]] = [],
                          streaming: bool = False):
-        def generate_with_callback(callback=None, **kwargs):
-            kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
-            self._generate_answer(**kwargs)
-
-        def generate_with_streaming(**kwargs):
-            return Iteratorize(generate_with_callback, kwargs)
-
-        with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
-            for answerResult in generator:
-                if answerResult.listenerToken:
-                    output = answerResult.listenerToken.input_ids
-                yield answerResult
-
-    @abstractmethod
-    def _generate_answer(self, prompt: str,
-                         history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
         pass
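With `ListenerToken`, `AnswerResultStream`, `AnswerResultQueueSentinelTokenListenerQueue` and `Iteratorize` removed, `BaseAnswer.generatorAnswer` stops being a callback-plus-queue bridge: each backend now overrides it as an ordinary Python generator that yields `AnswerResult` objects. A minimal caller-side sketch of the new contract (the `llm` variable and the prompt are placeholders, not taken from this patch):

```python
# `llm` is any loaded BaseAnswer subclass (ChatGLM, MOSSLLM, ...); prompt/history are placeholders.
history = []
for answer_result in llm.generatorAnswer("你好", history=history, streaming=True):
    history = answer_result.history                 # running conversation history
    print(answer_result.llm_output["answer"])       # latest (partial or final) answer text
```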
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index ec8d052..09970b3 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -5,9 +5,7 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import transformers
 
 
@@ -43,15 +41,9 @@ class ChatGLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
 
-    def _generate_answer(self, prompt: str,
+    def generatorAnswer(self, prompt: str,
                          history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
-        # Create the StoppingCriteriaList with the stopping strings
-        stopping_criteria_list = transformers.StoppingCriteriaList()
-        # 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
-        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
-        stopping_criteria_list.append(listenerQueue)
+                         streaming: bool = False):
 
         if streaming:
             history += [[]]
@@ -60,34 +52,27 @@ class ChatGLM(BaseAnswer, LLM, ABC):
                     prompt,
                     history=history[-self.history_len:-1] if self.history_len > 0 else [],
                     max_length=self.max_token,
-                    temperature=self.temperature,
-                    stopping_criteria=stopping_criteria_list
+                    temperature=self.temperature
             )):
                 # self.checkPoint.clear_torch_cache()
                 history[-1] = [prompt, stream_resp]
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                if listenerQueue.listenerQueue.__len__() > 0:
-                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
-                generate_with_callback(answer_result)
+                yield answer_result
         else:
             response, _ = self.checkPoint.model.chat(
                 self.checkPoint.tokenizer,
                 prompt,
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
-                temperature=self.temperature,
-                stopping_criteria=stopping_criteria_list
+                temperature=self.temperature
             )
             self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
-
-            generate_with_callback(answer_result)
+            yield answer_result
 
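One review note on the `yield`-based rewrite above: `generatorAnswer` is now a generator function, so calling it only creates a generator object and none of its body runs, including the non-streaming `model.chat` branch, until the caller iterates. Code that previously relied on the callback firing as a side effect must consume the generator even when only one result is expected. A hedged sketch (`llm`, `prompt` and `history` are placeholders):

```python
# Non-streaming call still has to be iterated; next() pulls the single yielded AnswerResult.
answer_result = next(llm.generatorAnswer(prompt, history=history, streaming=False))
response = answer_result.llm_output["answer"]
```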
diff --git a/models/fastchat_llm.py b/models/fastchat_llm.py
index 4ee136b..3356a26 100644
--- a/models/fastchat_llm.py
+++ b/models/fastchat_llm.py
@@ -5,9 +5,7 @@ from langchain.llms.base import LLM
 
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 
 
 class FastChatLLM(BaseAnswer, LLM, ABC):
@@ -40,10 +38,9 @@ class FastChatLLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
 
-    def _generate_answer(self, prompt: str,
+    def generatorAnswer(self, prompt: str,
                          history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+                         streaming: bool = False):
 
         response = "fastchat 响应结果"
         history += [[prompt, response]]
@@ -51,4 +48,4 @@ class FastChatLLM(BaseAnswer, LLM, ABC):
         answer_result.history = history
         answer_result.llm_output = {"answer": response}
 
-        generate_with_callback(answer_result)
+        yield answer_result
diff --git a/models/llama_llm.py b/models/llama_llm.py
index 41bc9ed..400f704 100644
--- a/models/llama_llm.py
+++ b/models/llama_llm.py
@@ -9,9 +9,7 @@ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaL
 from typing import Optional, List, Dict, Any
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):
@@ -178,23 +176,15 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         self.history = self.history + [[None, reply]]
         return reply
 
-    def _generate_answer(self, prompt: str,
+    def generatorAnswer(self, prompt: str,
                          history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+                         streaming: bool = False):
         if history:
             self.history = history
-        # Create the StoppingCriteriaList with the stopping strings
-        self.stopping_criteria = transformers.StoppingCriteriaList()
-        # 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
-        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
-        self.stopping_criteria.append(listenerQueue)
         # TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
         softprompt = self.generate_softprompt_history_tensors(prompt)
         response = self._call(prompt=softprompt, stop=['\n###'])
         answer_result = AnswerResult()
         answer_result.history = self.history
-        if listenerQueue.listenerQueue.__len__() > 0:
-            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
         answer_result.llm_output = {"answer": response}
-        generate_with_callback(answer_result)
+        yield answer_result
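The listener removed above (`AnswerResultQueueSentinelTokenListenerQueue`) was itself just a `transformers.StoppingCriteria` that recorded the latest `input_ids`/`_scores` and always returned `False`, so dropping it also drops `AnswerResult.listenerToken`. If a backend still wants that token-level observation, it can register its own criteria when it calls `generate`/`stream_chat`; a sketch outside the scope of this patch, with the class name invented for illustration:

```python
import torch
import transformers


class TokenRecorder(transformers.StoppingCriteria):
    """Keeps the latest input_ids/scores for inspection; never requests a stop."""

    def __init__(self):
        super().__init__()
        self.last_input_ids = None
        self.last_scores = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        self.last_input_ids = input_ids
        self.last_scores = scores
        return False


# usage (backend-specific, illustrative only):
# recorder = TokenRecorder()
# model.generate(**inputs, stopping_criteria=transformers.StoppingCriteriaList([recorder]))
```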
diff --git a/models/moss_llm.py b/models/moss_llm.py
index be9711b..084e761 100644
--- a/models/moss_llm.py
+++ b/models/moss_llm.py
@@ -3,9 +3,7 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import torch
 
 
@@ -53,10 +51,9 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         pass
 
-    def _generate_answer(self, prompt: str,
+    def generatorAnswer(self, prompt: str,
                          history: List[List[str]] = [],
-                         streaming: bool = False,
-                         generate_with_callback: AnswerResultStream = None) -> None:
+                         streaming: bool = False):
         if len(history) > 0:
             history = history[-self.history_len:-1] if self.history_len > 0 else []
             prompt_w_history = str(history)
@@ -86,6 +83,6 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
         answer_result.history = history
         answer_result.llm_output = {"answer": response}
 
-        generate_with_callback(answer_result)
+        yield answer_result
 
 
diff --git a/webui.py b/webui.py
index ea1fe03..cfe7b58 100644
--- a/webui.py
+++ b/webui.py
@@ -6,9 +6,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
 from models.base import (BaseAnswer,
-                         AnswerResult,
-                         AnswerResultStream,
-                         AnswerResultQueueSentinelTokenListenerQueue)
+                         AnswerResult)
 import models.shared as shared
 from models.loader.args import parser
 from models.loader import LoaderCheckPoint
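For anyone adding a new backend after this refactor, the contract is the one the `FastChatLLM` stub above already follows: subclass `BaseAnswer`, override `generatorAnswer`, build an `AnswerResult` and `yield` it (once per partial response when streaming, once overall otherwise). A hedged skeleton, with the class name and the model call invented for illustration:

```python
from typing import List

from models.base import BaseAnswer, AnswerResult


class EchoLLM(BaseAnswer):  # hypothetical backend, not part of this patch
    # NOTE: the remaining BaseAnswer members (checkpoint / history-length handling) are omitted here.
    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):
        response = "echo: " + prompt          # stand-in for a real model call
        history += [[prompt, response]]
        answer_result = AnswerResult()
        answer_result.history = history
        answer_result.llm_output = {"answer": response}
        yield answer_result
```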