Langchain-Chatchat/models/chatglm_llm.py


from abc import ABC
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
    max_token: int = 10000
    temperature: float = 0.01
    # relevance / nucleus-sampling threshold
    top_p = 0.4
    # number of candidate tokens considered per step
    top_k = 10
    checkPoint: LoaderCheckPoint = None
    # history = []
    history_len: int = 10
    streaming_key: str = "streaming"  #: :meta private:
    history_key: str = "history"  #: :meta private:
    prompt_key: str = "prompt"  #: :meta private:
    output_key: str = "answer_result_stream"  #: :meta private:
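    # These sampling parameters are forwarded directly to ChatGLM's
    # chat()/stream_chat() calls below: max_token is passed as max_length, and
    # history_len caps how many previous [prompt, response] turns are replayed.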
    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint
    @property
    def _chain_type(self) -> str:
        return "ChatGLMLLMChain"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint
    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return [self.prompt_key]

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]
    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Generator]:
        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
        return {self.output_key: generator}
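    # NOTE: generatorAnswer is not defined in this class; it is presumably
    # inherited from BaseAnswer (models.base) and drives _generate_answer below,
    # surfacing each AnswerResult through the generate_with_callback hook.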
    def _generate_answer(self,
                         inputs: Dict[str, Any],
                         run_manager: Optional[CallbackManagerForChainRun] = None,
                         generate_with_callback: AnswerResultStream = None) -> None:
        history = inputs[self.history_key]
        streaming = inputs[self.streaming_key]
        prompt = inputs[self.prompt_key]
        print(f"__call:{prompt}")
        # Create the StoppingCriteriaList with the stopping strings
        stopping_criteria_list = transformers.StoppingCriteriaList()
        # Listener queue registered as a stopping criterion: on each generation step
        # it syncs the torch.LongTensor / torch.FloatTensor outputs into AnswerResult
        listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
        stopping_criteria_list.append(listenerQueue)
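        # Streaming mode appends an empty placeholder turn to history and slices
        # it off again ([-history_len:-1]) when building the model context;
        # non-streaming mode simply uses the last history_len turns as-is.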
        if streaming:
            history += [[]]
            for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
                    self.checkPoint.tokenizer,
                    prompt,
                    history=history[-self.history_len:-1] if self.history_len > 0 else [],
                    max_length=self.max_token,
                    temperature=self.temperature,
                    top_p=self.top_p,
                    top_k=self.top_k,
                    stopping_criteria=stopping_criteria_list
            )):
                # self.checkPoint.clear_torch_cache()
                history[-1] = [prompt, stream_resp]
                answer_result = AnswerResult()
                answer_result.history = history
                answer_result.llm_output = {"answer": stream_resp}
                generate_with_callback(answer_result)
            self.checkPoint.clear_torch_cache()
        else:
            response, _ = self.checkPoint.model.chat(
                self.checkPoint.tokenizer,
                prompt,
                history=history[-self.history_len:] if self.history_len > 0 else [],
                max_length=self.max_token,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
                stopping_criteria=stopping_criteria_list
            )
            self.checkPoint.clear_torch_cache()
            history += [[prompt, response]]
            answer_result = AnswerResult()
            answer_result.history = history
            answer_result.llm_output = {"answer": response}
            generate_with_callback(answer_result)
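

# A minimal usage sketch, assuming a LoaderCheckPoint whose model and tokenizer
# have already been loaded elsewhere (the loading step is project-specific and
# omitted here). generatorAnswer, presumably provided by BaseAnswer, yields
# AnswerResult objects as produced in _generate_answer above.
if __name__ == "__main__":
    checkpoint = LoaderCheckPoint()  # hypothetical: real construction/loading args are project-specific
    chain = ChatGLMLLMChain(checkPoint=checkpoint)
    inputs = {"prompt": "Hello, who are you?", "history": [], "streaming": True}
    for answer_result in chain.generatorAnswer(inputs=inputs):
        print(answer_result.llm_output["answer"])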