diff --git a/models/moss_llm.py b/models/moss_llm.py
index c958baa..e7a4d0b 100644
--- a/models/moss_llm.py
+++ b/models/moss_llm.py
@@ -58,6 +58,11 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
                         history: List[List[str]] = [],
                         streaming: bool = False,
                         generate_with_callback: AnswerResultStream = None) -> None:
+        # Build the StoppingCriteriaList that will carry the token listener
+        stopping_criteria_list = transformers.StoppingCriteriaList()
+        # Register a listener on the model's stopping_criteria queue; on each response it syncs the torch.LongTensor / torch.FloatTensor pair into the AnswerResult
+        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
+        stopping_criteria_list.append(listenerQueue)
         if len(history) > 0:
             history = history[-self.history_len:-1] if self.history_len > 0 else []
         prompt_w_history = str(history)
@@ -83,6 +88,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
         response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
         self.checkPoint.clear_torch_cache()
         history += [[prompt, response]]
-        yield response, history
+        answer_result = AnswerResult()
+        answer_result.history = history
+        answer_result.llm_output = {"answer": response}
+        if len(listenerQueue.listenerQueue) > 0:
+            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+
+        generate_with_callback(answer_result)
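
For context: the patch assumes `AnswerResultQueueSentinelTokenListenerQueue` behaves as a `transformers.StoppingCriteria` that records generation state without ever halting decoding, exposing it via a `listenerQueue` attribute that the diff pops into `answer_result.listenerToken`. The real class is defined elsewhere in this repo; the sketch below is only a minimal illustration of that contract. The attribute name `listenerQueue` and the `(input_ids, scores)` pair come from the diff and the `StoppingCriteria.__call__` signature; the rest is assumption.

```python
import torch
import transformers


class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
    """Minimal sketch: a StoppingCriteria that never stops generation; it only
    snapshots the latest (input_ids, scores) pair so the caller can pop it into
    AnswerResult.listenerToken. The real implementation lives in this repo."""

    def __init__(self):
        # The diff reads this attribute as `listenerQueue` and pops from it.
        self.listenerQueue: list = []

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        # Keep only the most recent snapshot so the queue stays bounded.
        self.listenerQueue = [(input_ids, scores)]
        return False  # never request that generation stop
```

Appending this listener to the `StoppingCriteriaList` passed to `model.generate(...)` lets the callback-based `generatorAnswer` path surface the final token tensors alongside the decoded answer, which is what replaces the old `yield response, history`.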