Update moss_llm.py
parent 7c74933285
commit 5524c47645
@@ -58,6 +58,11 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
                          history: List[List[str]] = [],
                          streaming: bool = False,
                          generate_with_callback: AnswerResultStream = None) -> None:
+        # Create the StoppingCriteriaList with the stopping strings
+        stopping_criteria_list = transformers.StoppingCriteriaList()
+        # Define the model's stopping_criteria queue; on each response it syncs the torch.LongTensor / torch.FloatTensor pair into the AnswerResult
+        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
+        stopping_criteria_list.append(listenerQueue)
         if len(history) > 0:
             history = history[-self.history_len:-1] if self.history_len > 0 else []
         prompt_w_history = str(history)
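For context, the listener registered above plugs into the standard transformers stopping-criteria hook. Below is a minimal sketch of how such a listener queue could look, assuming it only records token tensors and never halts generation; the class is a hypothetical stand-in, not the project's actual AnswerResultQueueSentinelTokenListenerQueue:

import torch
import transformers

class ListenerQueueSketch(transformers.StoppingCriteria):
    # Hypothetical stand-in for AnswerResultQueueSentinelTokenListenerQueue.
    def __init__(self):
        self.listenerQueue = []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Record the latest (input_ids, scores) pair so the generate loop can
        # later attach it to an AnswerResult; returning False never stops generation.
        self.listenerQueue.append((input_ids, scores))
        return False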
@@ -83,6 +88,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
         response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
         self.checkPoint.clear_torch_cache()
         history += [[prompt, response]]
         yield response, history
+        answer_result = AnswerResult()
+        answer_result.history = history
+        answer_result.llm_output = {"answer": response}
+        if listenerQueue.listenerQueue.__len__() > 0:
+            answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+        generate_with_callback(answer_result)
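The callback added above hands the populated AnswerResult to the caller in addition to the existing yield. A minimal usage sketch, assuming only the fields set in this diff; the consumer function name is hypothetical:

def collect_answer(answer_result):
    # Hypothetical consumer passed in as generate_with_callback: it reads the
    # fields populated in the hunk above.
    print(answer_result.llm_output["answer"])
    print(answer_result.history)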