From 3b4b660d3c1776e6a65ea79568208d8c60b6291f Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Sun, 23 Apr 2023 22:13:20 +0800
Subject: [PATCH] update chatglm_llm.py

---
 models/chatglm_llm.py | 40 +++++++++++++++++++++++++++-------------
 1 file changed, 27 insertions(+), 13 deletions(-)

diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 5608cbb..7243967 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -68,19 +68,33 @@ class ChatGLM(LLM):
 
     def _call(self,
               prompt: str,
-              stop: Optional[List[str]] = None) -> str:
-        response, _ = self.model.chat(
-            self.tokenizer,
-            prompt,
-            history=self.history[-self.history_len:] if self.history_len > 0 else [],
-            max_length=self.max_token,
-            temperature=self.temperature,
-        )
-        torch_gc()
-        if stop is not None:
-            response = enforce_stop_tokens(response, stop)
-        self.history = self.history + [[None, response]]
-        return response
+              stop: Optional[List[str]] = None,
+              stream=True) -> str:
+        if stream:
+            self.history = self.history + [[None, ""]]
+            response, _ = self.model.stream_chat(
+                self.tokenizer,
+                prompt,
+                history=self.history[-self.history_len:] if self.history_len > 0 else [],
+                max_length=self.max_token,
+                temperature=self.temperature,
+            )
+            torch_gc()
+            self.history[-1][-1] = response
+            yield response
+        else:
+            response, _ = self.model.chat(
+                self.tokenizer,
+                prompt,
+                history=self.history[-self.history_len:] if self.history_len > 0 else [],
+                max_length=self.max_token,
+                temperature=self.temperature,
+            )
+            torch_gc()
+            if stop is not None:
+                response = enforce_stop_tokens(response, stop)
+            self.history = self.history + [[None, response]]
+            return response
 
     def chat(self,
              prompt: str) -> str:
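
Note on the API used above: the streaming branch switches from self.model.chat to self.model.stream_chat. In the ChatGLM-6B API, stream_chat is normally consumed as a generator of (response, history) pairs rather than unpacked once. The following is a minimal sketch of that usual consumption pattern, not part of this patch; it assumes a machine with a CUDA GPU and uses the public THUDM/chatglm-6b checkpoint purely as an illustration.

from transformers import AutoModel, AutoTokenizer

# Sketch: load ChatGLM-6B and iterate stream_chat; each iteration yields the
# response generated so far together with the updated history.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

history = []
for response, history in model.stream_chat(tokenizer, "Hello", history=history):
    print(response)  # partial response, growing with each step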