From 5524c476450c2bf08dce25af508e8921464014be Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Fri, 19 May 2023 17:32:38 +0800
Subject: [PATCH] Update moss_llm.py

---
 models/moss_llm.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/models/moss_llm.py b/models/moss_llm.py
index c958baa..e7a4d0b 100644
--- a/models/moss_llm.py
+++ b/models/moss_llm.py
@@ -58,6 +58,11 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
                          history: List[List[str]] = [],
                          streaming: bool = False,
                          generate_with_callback: AnswerResultStream = None) -> None:
+        # Create the StoppingCriteriaList with the stopping strings
+        stopping_criteria_list = transformers.StoppingCriteriaList()
+        # Define the model's stopping_criteria queue; on each response, sync the torch.LongTensor / torch.FloatTensor pair into AnswerResult
+        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
+        stopping_criteria_list.append(listenerQueue)
         if len(history) > 0:
             history = history[-self.history_len:-1] if self.history_len > 0 else []
         prompt_w_history = str(history)
@@ -83,6 +88,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
         response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
         self.checkPoint.clear_torch_cache()
         history += [[prompt, response]]
-        yield response, history
+        answer_result = AnswerResult()
+        answer_result.history = history
+        answer_result.llm_output = {"answer": response}
+        if listenerQueue.listenerQueue.__len__() > 0:
+            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+
+        generate_with_callback(answer_result)
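
AnswerResultQueueSentinelTokenListenerQueue is defined elsewhere in the
repository; this patch only shows how it is wired in. As a rough sketch of
what such a listener could look like on top of transformers.StoppingCriteria
(the class body below is an assumption for illustration, not the project's
actual implementation; only the class name, the listenerQueue attribute, and
the StoppingCriteria.__call__ signature come from the patch and the
transformers API):

    import torch
    import transformers

    class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
        # Hypothetical sketch: record each generation step's tensors so the
        # caller can attach the latest pair to AnswerResult.listenerToken.
        def __init__(self):
            self.listenerQueue = []

        def __call__(self, input_ids: torch.LongTensor,
                     scores: torch.FloatTensor, **kwargs) -> bool:
            # Keep only the most recent (input_ids, scores) pair and never
            # request a stop; generation still ends on EOS or max length.
            self.listenerQueue = [(input_ids, scores)]
            return False

For the criteria to fire, stopping_criteria_list would have to be passed to
self.model.generate(..., stopping_criteria=stopping_criteria_list) in the
part of the method not shown in these hunks.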