Langchain-Chatchat/models/llama_llm.py

202 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from abc import ABC
from langchain.llms.base import LLM
import random
import torch
import transformers
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any,Union
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: Union[torch.LongTensor,list], scores: Union[torch.FloatTensor,list]) -> torch.FloatTensor:
# llama-cpp模型返回的是list,为兼容性考虑需要判断input_ids和scores的类型将list转换为torch.Tensor
input_ids = torch.tensor(input_ids) if isinstance(input_ids,list) else input_ids
scores = torch.tensor(scores) if isinstance(scores,list) else scores
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 5] = 5e4
return scores
class LLamaLLM(BaseAnswer, LLM, ABC):
checkPoint: LoaderCheckPoint = None
# history = []
history_len: int = 3
max_new_tokens: int = 500
num_beams: int = 1
temperature: float = 0.5
top_p: float = 0.4
top_k: int = 10
repetition_penalty: float = 1.2
encoder_repetition_penalty: int = 1
min_length: int = 0
logits_processor: LogitsProcessorList = None
stopping_criteria: Optional[StoppingCriteriaList] = None
eos_token_id: Optional[int] = [2]
state: object = {'max_new_tokens': 50,
'seed': 1,
'temperature': 0, 'top_p': 0.1,
'top_k': 40, 'typical_p': 1,
'repetition_penalty': 1.2,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'min_length': 0,
'penalty_alpha': 0,
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
'truncation_length': 2048, 'custom_stopping_strings': '',
'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
'pre_layer': 0, 'gpu_memory_0': 0}
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "LLamaLLM"
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
def encode(self, prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
input_ids = self.checkPoint.tokenizer.encode(str(prompt), return_tensors='pt',
add_special_tokens=add_special_tokens)
# This is a hack for making replies more creative.
if not add_bos_token and input_ids[0][0] == self.checkPoint.tokenizer.bos_token_id:
input_ids = input_ids[:, 1:]
# Llama adds this extra token when the first character is '\n', and this
# compromises the stopping criteria, so we just remove it
if type(self.checkPoint.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
input_ids = input_ids[:, 1:]
# Handling truncation
if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:]
return input_ids.cuda()
def decode(self, output_ids):
reply = self.checkPoint.tokenizer.decode(output_ids, skip_special_tokens=True)
return reply
# 将历史对话数组转换为文本格式
def history_to_text(self, query, history):
"""
历史对话软提示
这段代码首先定义了一个名为 history_to_text 的函数,用于将 self.history
数组转换为所需的文本格式。然后,我们将格式化后的历史文本
再用 self.encode 将其转换为向量表示。最后,将历史对话向量与当前输入的对话向量拼接在一起。
:return:
"""
formatted_history = ''
history = history[-self.history_len:] if self.history_len > 0 else []
if len(history) > 0:
for i, (old_query, response) in enumerate(history):
formatted_history += "### Human{}\n### Assistant{}\n".format(old_query, response)
formatted_history += "### Human{}\n### Assistant".format(query)
return formatted_history
def prepare_inputs_for_generation(self,
input_ids: torch.LongTensor):
"""
预生成注意力掩码和 输入序列中每个位置的索引的张量
# TODO 没有思路
:return:
"""
mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)
attention_mask = self.get_masks(input_ids, input_ids.device)
position_ids = self.get_position_ids(
input_ids,
device=input_ids.device,
mask_positions=mask_positions
)
return input_ids, position_ids, attention_mask
@property
def _history_len(self) -> int:
return self.history_len
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
print(f"__call:{prompt}")
if self.logits_processor is None:
self.logits_processor = LogitsProcessorList()
self.logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {
"max_new_tokens": self.max_new_tokens,
"num_beams": self.num_beams,
"top_p": self.top_p,
"do_sample": True,
"top_k": self.top_k,
"repetition_penalty": self.repetition_penalty,
"encoder_repetition_penalty": self.encoder_repetition_penalty,
"min_length": self.min_length,
"temperature": self.temperature,
"eos_token_id": self.checkPoint.tokenizer.eos_token_id,
"logits_processor": self.logits_processor}
# 向量转换
input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
# input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)
gen_kwargs.update({'inputs': input_ids})
# 注意力掩码
# gen_kwargs.update({'attention_mask': attention_mask})
# gen_kwargs.update({'position_ids': position_ids})
if self.stopping_criteria is None:
self.stopping_criteria = transformers.StoppingCriteriaList()
# 观测输出
gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
# llama-cpp模型的参数与transformers的参数字段有较大差异直接调用会返回不支持的字段错误
# 因此需要先判断模型是否是llama-cpp模型然后取gen_kwargs与模型generate方法字段的交集
# 仅将交集字段传给模型以保证兼容性
# todo llama-cpp模型在本框架下兼容性较差后续可以考虑重写一个llama_cpp_llm.py模块
if "llama_cpp" in self.checkPoint.model.__str__():
import inspect
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args)&set(gen_kwargs.keys())
common_kwargs = {key:gen_kwargs[key] for key in common_kwargs_keys}
#? llama-cpp模型的generate方法似乎只接受.cpu类型的输入响应很慢慢到哭泣
#?为什么会不支持GPU呢不应该啊
output_ids = torch.tensor([list(self.checkPoint.model.generate(input_id_i.cpu(),**common_kwargs)) for input_id_i in input_ids])
else:
output_ids = self.checkPoint.model.generate(**gen_kwargs)
new_tokens = len(output_ids[0]) - len(input_ids[0])
reply = self.decode(output_ids[0][-new_tokens:])
print(f"response:{reply}")
print(f"+++++++++++++++++++++++++++++++++++")
return reply
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
# TODO 需要实现chat对话模块和注意力模型目前_call为langchain的LLM拓展的api默认为无提示词模式如果需要操作注意力模型可以参考chat_glm的实现
softprompt = self.history_to_text(prompt,history=history)
response = self._call(prompt=softprompt, stop=['\n###'])
answer_result = AnswerResult()
answer_result.history = history + [[prompt, response]]
answer_result.llm_output = {"answer": response}
yield answer_result