llama_llm.py prompt modification
parent f1cfd6d688
commit 0abd2d9992
@@ -74,7 +74,7 @@ llm_model_dict = {
     "vicuna-13b-hf": {
         "name": "vicuna-13b-hf",
         "pretrained_model_name": "vicuna-13b-hf",
-        "local_model_path": "/media/checkpoint/vicuna-13b-hf",
+        "local_model_path": None,
         "provides": "LLamaLLM"
     },
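
With local_model_path set to None, the loader can no longer rely on a machine-specific checkpoint directory and has to resolve the model by its pretrained_model_name instead (typically from the Hugging Face hub). A minimal sketch of that fallback, under the assumption that the loader treats None as "no local copy":

# Hedged sketch: assumed fallback when "local_model_path" is None.
config = {
    "pretrained_model_name": "vicuna-13b-hf",
    "local_model_path": None,
}
model_path = config["local_model_path"] or config["pretrained_model_name"]
print(model_path)  # "vicuna-13b-hf", left for the hub/loader to resolve
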
@@ -98,9 +98,10 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         """
         formatted_history = ''
         history = history[-self.history_len:] if self.history_len > 0 else []
-        for i, (old_query, response) in enumerate(history):
-            formatted_history += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
-        formatted_history += "[Round {}]\n问:{}\n答:".format(len(history), query)
+        if len(history) > 0:
+            for i, (old_query, response) in enumerate(history):
+                formatted_history += "### Human:{}\n### Assistant:{}\n".format(old_query, response)
+        formatted_history += "### Human:{}\n### Assistant:".format(query)
         return formatted_history

     def prepare_inputs_for_generation(self,
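
The hunk above switches the history template from the ChatGLM-style "[Round i]\n问:…\n答:…" format to the Vicuna-style "### Human:/### Assistant:" format. A standalone sketch of the new assembly logic (the function name and the history_len default are illustrative; in the source this lives on LLamaLLM and reads self.history_len):

def build_prompt(query, history, history_len=3):
    # Keep only the most recent turns, as the original code does.
    history = history[-history_len:] if history_len > 0 else []
    prompt = ''
    for old_query, response in history:
        prompt += "### Human:{}\n### Assistant:{}\n".format(old_query, response)
    prompt += "### Human:{}\n### Assistant:".format(query)
    return prompt

print(build_prompt("And why?", [("Hi", "Hello!")]))
# ### Human:Hi
# ### Assistant:Hello!
# ### Human:And why?
# ### Assistant:

This template lines up with the stop=['\n###'] marker passed to _call in a later hunk, so generation halts when the model starts a new "### Human" turn.
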
@@ -140,12 +141,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
             "max_new_tokens": self.max_new_tokens,
             "num_beams": self.num_beams,
             "top_p": self.top_p,
+            "do_sample": True,
             "top_k": self.top_k,
             "repetition_penalty": self.repetition_penalty,
             "encoder_repetition_penalty": self.encoder_repetition_penalty,
             "min_length": self.min_length,
             "temperature": self.temperature,
-            "eos_token_id": self.eos_token_id,
+            "eos_token_id": self.checkPoint.tokenizer.eos_token_id,
             "logits_processor": self.logits_processor}

         # 向量转换
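
Adding "do_sample": True matters because transformers' generate() decodes greedily by default and ignores temperature, top_p, and top_k unless sampling is enabled; reading eos_token_id from the loaded tokenizer likewise avoids a hardcoded id that may not match the checkpoint. A minimal, self-contained illustration using gpt2 as a stand-in model (the real code passes self.checkPoint.tokenizer and its own kwargs):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # stand-in for the vicuna checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("### Human:Hi\n### Assistant:", return_tensors="pt")
outputs = model.generate(**inputs,
                         max_new_tokens=32,
                         do_sample=True,        # enables the sampling knobs below
                         top_p=0.9,
                         top_k=50,
                         temperature=0.7,
                         eos_token_id=tokenizer.eos_token_id)  # from the tokenizer, not hardcoded
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
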
@@ -178,6 +180,6 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
         response = self._call(prompt=softprompt, stop=['\n###'])

         answer_result = AnswerResult()
-        answer_result.history = history + [[None, response]]
+        answer_result.history = history + [[prompt, response]]
         answer_result.llm_output = {"answer": response}
         yield answer_result
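
Recording [[prompt, response]] instead of [[None, response]] keeps the user's turn in the returned history, so follow-up calls can format it back into the prompt. The effect on the data shape (values here are made up):

history = [["Hi", "Hello!"]]
prompt, response = "And why?", "Because."
print(history + [[None, response]])    # before: [['Hi', 'Hello!'], [None, 'Because.']]
print(history + [[prompt, response]])  # after:  [['Hi', 'Hello!'], ['And why?', 'Because.']]
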
@@ -75,8 +75,8 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
                 repetition_penalty=1.02,
                 num_return_sequences=1,
                 eos_token_id=106068,
-                pad_token_id=self.tokenizer.pad_token_id)
-            response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+                pad_token_id=self.checkPoint.tokenizer.pad_token_id)
+            response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
             self.checkPoint.clear_torch_cache()
             history += [[prompt, response]]
             answer_result = AnswerResult()
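
Both MOSSLLM fixes route tokenizer access through the shared checkpoint object rather than a self.tokenizer attribute the wrapper may never have set. A sketch of the pattern this assumes (class and attribute names are illustrative, not the project's exact API):

class LoaderCheckPoint:
    """Single owner of the loaded model and tokenizer."""
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

class LLMWrapper:
    def __init__(self, checkPoint):
        self.checkPoint = checkPoint  # all tokenizer/model access goes through here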