Update moss_llm.py
parent 7c74933285
commit 5524c47645
@@ -58,6 +58,11 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
                          history: List[List[str]] = [],
                          streaming: bool = False,
                          generate_with_callback: AnswerResultStream = None) -> None:
+        # Create the StoppingCriteriaList with the stopping strings
+        stopping_criteria_list = transformers.StoppingCriteriaList()
+        # Define the model's stopping_criteria queue; on each response, sync the torch.LongTensor and torch.FloatTensor into AnswerResult
+        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
+        stopping_criteria_list.append(listenerQueue)
         if len(history) > 0:
             history = history[-self.history_len:-1] if self.history_len > 0 else []
             prompt_w_history = str(history)
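For context on what the first hunk wires up: AnswerResultQueueSentinelTokenListenerQueue is project code not shown in this diff. Below is a minimal sketch, assuming it is a transformers.StoppingCriteria subclass that records each generation step's tensors into a listenerQueue without ever halting generation; only the class name and the listenerQueue attribute come from the diff, the body is an assumption.

import torch
import transformers

class ListenerQueueSketch(transformers.StoppingCriteria):
    # Assumed shape of AnswerResultQueueSentinelTokenListenerQueue.
    def __init__(self):
        self.listenerQueue = []  # filled on every generation step

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        # Record the current tokens/scores so the caller can sync them
        # into AnswerResult later; returning False never stops generation.
        self.listenerQueue.append((input_ids, scores))
        return False

stopping_criteria_list = transformers.StoppingCriteriaList()
listenerQueue = ListenerQueueSketch()
stopping_criteria_list.append(listenerQueue)
# model.generate(..., stopping_criteria=stopping_criteria_list)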
@@ -83,6 +88,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
             response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
             self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
-            yield response, history
+            answer_result = AnswerResult()
+            answer_result.history = history
+            answer_result.llm_output = {"answer": response}
+            if listenerQueue.listenerQueue.__len__() > 0:
+                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
+
+            generate_with_callback(answer_result)
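The second hunk drops the generator-style yield response, history in favor of a callback: the response and history are packed into an AnswerResult and handed to generate_with_callback, presumably so the same method can serve streaming and non-streaming callers through one AnswerResultStream interface. A minimal sketch of that contract follows; the AnswerResult field names are read off the diff, while SimpleAnswerResult and collect_results are illustrative stand-ins, not project code.

from typing import Callable, List

class SimpleAnswerResult:
    # Stand-in for the project's AnswerResult; field names mirror the diff.
    def __init__(self):
        self.history: List[List[str]] = []
        self.llm_output: dict = {}
        self.listenerToken = None  # last item popped from listenerQueue, if any

def collect_results(sink: list) -> Callable[[SimpleAnswerResult], None]:
    # Build a callback in the shape of AnswerResultStream: each finished
    # response is appended to `sink` instead of being yielded to the caller.
    def callback(answer_result: SimpleAnswerResult) -> None:
        sink.append(answer_result)
    return callback

results: list = []
generate_with_callback = collect_results(results)

answer_result = SimpleAnswerResult()
answer_result.history = [["prompt", "response"]]
answer_result.llm_output = {"answer": "response"}
generate_with_callback(answer_result)  # mirrors the diff's final call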