Langchain-Chatchat/models/chatglmcpp_llm.py

from abc import ABC
from typing import Any, Dict, Generator, List, Optional, Union
import torch
import transformers
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import (LogitsProcessorList,
                                           StoppingCriteriaList)

from models.base import (AnswerResult,
                         AnswerResultStream,
                         BaseAnswer)
from models.loader import LoaderCheckPoint


class ChatGLMCppLLMChain(BaseAnswer, Chain, ABC):
    checkPoint: LoaderCheckPoint = None
    streaming_key: str = "streaming"  #: :meta private:
    history_key: str = "history"  #: :meta private:
    prompt_key: str = "prompt"  #: :meta private:
    output_key: str = "answer_result_stream"  #: :meta private:
    # Default generation parameters forwarded to the chatglm.cpp model.
    max_length = 2048
    max_context_length = 512
    do_sample = True
    top_k = 0
    top_p = 0.7
    temperature = 0.95
    num_threads = 0

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _chain_type(self) -> str:
        return "ChatglmCppLLMChain"

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return [self.prompt_key]

    @property
    def output_keys(self) -> List[str]:
        """Will always return the answer result stream key.

        :meta private:
        """
        return [self.output_key]

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    def encode(self, prompt, truncation_length=None):
        # Delegate to the underlying tokenizer; truncation_length is accepted
        # for interface compatibility but not applied here.
        input_ids = self.checkPoint.tokenizer.encode(str(prompt))
        return input_ids

    def decode(self, output_ids):
        reply = self.checkPoint.tokenizer.decode(output_ids)
        return reply

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Generator]:
        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
        return {self.output_key: generator}

    def _generate_answer(self,
                         inputs: Dict[str, Any],
                         run_manager: Optional[CallbackManagerForChainRun] = None,
                         generate_with_callback: AnswerResultStream = None) -> None:
        history = inputs[self.history_key]
        streaming = inputs[self.streaming_key]
        prompt = inputs[self.prompt_key]
        print(f"__call:{prompt}")
        if prompt == "clear":
            history = []
        local_history = []
        if not history:
            history = []
        # Flatten the [question, answer] pairs plus the new prompt into the flat,
        # alternating message list expected by the chatglm.cpp chat interface.
        for k, v in history:
            if k:
                local_history.append(k)
                local_history.append(v)
        local_history.append(prompt)
        if streaming:
            # Reserve a slot in the history that is filled in as tokens arrive,
            # so every callback sees the partial reply accumulated so far.
            history += [[]]
            pieces = []
            print("++++++++++++++Stream++++++++++++++++++++")
            for piece in self.checkPoint.model.stream_chat(
                    local_history,
                    max_length=self.max_length,
                    max_context_length=self.max_context_length,
                    do_sample=self.temperature > 0,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    temperature=self.temperature,
            ):
                pieces.append(piece)
                reply = ''.join(pieces)
                print(piece, end='')
                answer_result = AnswerResult()
                history[-1] = [prompt, reply]
                answer_result.history = history
                answer_result.llm_output = {"answer": reply}
                generate_with_callback(answer_result)
            print("")
        else:
            reply = self.checkPoint.model.chat(
                local_history,
                max_length=self.max_length,
                max_context_length=self.max_context_length,
                do_sample=self.temperature > 0,
                top_k=self.top_k,
                top_p=self.top_p,
                temperature=self.temperature,
            )
            print(f"response:{reply}")
            print("+++++++++++++++++++++++++++++++++++")
            answer_result = AnswerResult()
            history.append([prompt, reply])
            answer_result.history = history
            answer_result.llm_output = {"answer": reply}
            generate_with_callback(answer_result)
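

# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes a
# LoaderCheckPoint has already loaded a chatglm.cpp model and its tokenizer, and
# that BaseAnswer.generatorAnswer (defined in models.base, not shown here) wraps
# _generate_answer into a generator of AnswerResult objects, as the
# "answer_result_stream" output key suggests. `loaded_checkpoint` below is a
# placeholder name, not an API from this repository.
#
#     chain = ChatGLMCppLLMChain(checkPoint=loaded_checkpoint)
#     result = chain({"prompt": "Hello", "history": [], "streaming": True})
#     for answer_result in result["answer_result_stream"]:
#         print(answer_result.llm_output["answer"])  # partial reply so far
#         history = answer_result.history            # [[prompt, reply], ...]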