diff --git a/configs/model_config.py b/configs/model_config.py
index e18f69b..0f25092 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -74,7 +74,7 @@ llm_model_dict = {
     "vicuna-13b-hf": {
         "name": "vicuna-13b-hf",
         "pretrained_model_name": "vicuna-13b-hf",
-        "local_model_path": "/media/checkpoint/vicuna-13b-hf",
+        "local_model_path": None,
         "provides": "LLamaLLM"
     },
diff --git a/models/llama_llm.py b/models/llama_llm.py
index 1b0f403..69fde56 100644
--- a/models/llama_llm.py
+++ b/models/llama_llm.py
@@ -98,9 +98,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         """
         formatted_history = ''
         history = history[-self.history_len:] if self.history_len > 0 else []
-        for i, (old_query, response) in enumerate(history):
-            formatted_history += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
-        formatted_history += "[Round {}]\n问:{}\n答:".format(len(history), query)
+        if len(history) > 0:
+            for i, (old_query, response) in enumerate(history):
+                formatted_history += "### Human:{}\n### Assistant:{}\n".format(old_query, response)
+        formatted_history += "### Human:{}\n### Assistant:".format(query)
         return formatted_history

     def prepare_inputs_for_generation(self,
@@ -140,12 +141,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
             "max_new_tokens": self.max_new_tokens,
             "num_beams": self.num_beams,
             "top_p": self.top_p,
+            "do_sample": True,
             "top_k": self.top_k,
             "repetition_penalty": self.repetition_penalty,
             "encoder_repetition_penalty": self.encoder_repetition_penalty,
             "min_length": self.min_length,
             "temperature": self.temperature,
-            "eos_token_id": self.eos_token_id,
+            "eos_token_id": self.checkPoint.tokenizer.eos_token_id,
             "logits_processor": self.logits_processor}

         # 向量转换
@@ -178,6 +180,6 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         response = self._call(prompt=softprompt, stop=['\n###'])
         answer_result = AnswerResult()
-        answer_result.history = history + [[None, response]]
+        answer_result.history = history + [[prompt, response]]
         answer_result.llm_output = {"answer": response}
         yield answer_result
diff --git a/models/moss_llm.py b/models/moss_llm.py
index c608edb..80a8687 100644
--- a/models/moss_llm.py
+++ b/models/moss_llm.py
@@ -75,8 +75,8 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
                 repetition_penalty=1.02,
                 num_return_sequences=1,
                 eos_token_id=106068,
-                pad_token_id=self.tokenizer.pad_token_id)
-            response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+                pad_token_id=self.checkPoint.tokenizer.pad_token_id)
+            response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
             self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
             answer_result = AnswerResult()
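
For context only, not part of the patch: a minimal sketch of what the reworked history formatting in LLamaLLM.generate_prompt produces after this change. The standalone format_prompt helper and the sample history below are hypothetical and simply mirror the patched logic.

def format_prompt(query, history, history_len=3):
    """Mirror of the patched formatting: '### Human:' / '### Assistant:' turns."""
    formatted_history = ''
    history = history[-history_len:] if history_len > 0 else []
    if len(history) > 0:
        for old_query, response in history:
            formatted_history += "### Human:{}\n### Assistant:{}\n".format(old_query, response)
    formatted_history += "### Human:{}\n### Assistant:".format(query)
    return formatted_history

# Example: one prior exchange plus the new query.
print(format_prompt("And tomorrow?", [["What's the weather today?", "Sunny."]]))
# ### Human:What's the weather today?
# ### Assistant:Sunny.
# ### Human:And tomorrow?
# ### Assistant: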