support chatglm2cpp

lichongyang 2023-07-16 17:24:46 +08:00
parent c5bc21781c
commit 62ba7679fe
5 changed files with 184 additions and 5 deletions

.gitignore

@@ -174,4 +174,5 @@ embedding/*
pyrightconfig.json
loader/tmp_files
flagged/*

@@ -63,6 +63,13 @@ llm_model_dict = {
        "local_model_path": None,
        "provides": "ChatGLMLLMChain"
    },
    # Note: chatglm2-cpp has only been tested on macOS; other systems are not supported yet
    "chatglm2-cpp": {
        "name": "chatglm2-cpp",
        "pretrained_model_name": "cylee0909/chatglm2cpp",
        "local_model_path": None,
        "provides": "ChatGLMCppLLMChain"
    },
    "chatglm2-6b-int4": {
        "name": "chatglm2-6b-int4",
        "pretrained_model_name": "THUDM/chatglm2-6b-int4",

@@ -1,4 +1,5 @@
from .chatglm_llm import ChatGLMLLMChain
from .llama_llm import LLamaLLMChain
from .chatglmcpp_llm import ChatGLMCppLLMChain
from .fastchat_openai_llm import FastChatOpenAILLMChain
from .moss_llm import MOSSLLMChain

models/chatglmcpp_llm.py (new file)

@@ -0,0 +1,143 @@
from abc import ABC
from typing import Any, Dict, Generator, List, Optional, Union

import torch
import transformers
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import (LogitsProcessorList,
                                           StoppingCriteriaList)

from models.base import (AnswerResult,
                         AnswerResultStream, BaseAnswer)
from models.loader import LoaderCheckPoint


class ChatGLMCppLLMChain(BaseAnswer, Chain, ABC):
    checkPoint: LoaderCheckPoint = None
    streaming_key: str = "streaming"  #: :meta private:
    history_key: str = "history"  #: :meta private:
    prompt_key: str = "prompt"  #: :meta private:
    output_key: str = "answer_result_stream"  #: :meta private:
    # Generation parameters forwarded to chatglm-cpp
    max_length = 2048
    max_context_length = 512
    do_sample = True
    top_k = 0
    top_p = 0.7
    temperature = 0.95
    num_threads = 0

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _chain_type(self) -> str:
        return "ChatglmCppLLMChain"

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return [self.prompt_key]

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    def encode(self, prompt, truncation_length=None):
        input_ids = self.checkPoint.tokenizer.encode(str(prompt))
        return input_ids

    def decode(self, output_ids):
        reply = self.checkPoint.tokenizer.decode(output_ids)
        return reply

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Generator]:
        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
        return {self.output_key: generator}

    def _generate_answer(self,
                         inputs: Dict[str, Any],
                         run_manager: Optional[CallbackManagerForChainRun] = None,
                         generate_with_callback: AnswerResultStream = None) -> None:
        history = inputs[self.history_key]
        streaming = inputs[self.streaming_key]
        prompt = inputs[self.prompt_key]
        print(f"__call:{prompt}")
        if prompt == "clear":
            history = []
        # Flatten the [question, answer] pairs into the flat message list
        # chatglm-cpp expects, then append the current prompt.
        local_history = []
        if not history:
            history = []
        for k, v in history:
            if k:
                local_history.append(k)
                local_history.append(v)
        local_history.append(prompt)
        if streaming:
            # Stream token pieces, emitting the accumulated reply after each piece.
            history += [[]]
            pieces = []
            print("++++++++++++++Stream++++++++++++++++++++")
            for piece in self.checkPoint.model.stream_chat(
                    local_history,
                    max_length=self.max_length,
                    max_context_length=self.max_context_length,
                    do_sample=self.temperature > 0,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    temperature=self.temperature,
            ):
                pieces.append(piece)
                reply = ''.join(pieces)
                print(f"{piece}", end='')
                answer_result = AnswerResult()
                history[-1] = [prompt, reply]
                answer_result.history = history
                answer_result.llm_output = {"answer": reply}
                generate_with_callback(answer_result)
            print("")
        else:
            reply = self.checkPoint.model.chat(
                local_history,
                max_length=self.max_length,
                max_context_length=self.max_context_length,
                do_sample=self.temperature > 0,
                top_k=self.top_k,
                top_p=self.top_p,
                temperature=self.temperature,
            )
            print(f"response:{reply}")
            print("+++++++++++++++++++++++++++++++++++")
            answer_result = AnswerResult()
            history.append([prompt, reply])
            answer_result.history = history
            answer_result.llm_output = {"answer": reply}
            generate_with_callback(answer_result)
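A usage sketch, assuming the new chain is driven the same way as the repo's other *LLMChain classes; loader_checkpoint is a hypothetical, already-loaded LoaderCheckPoint for the "chatglm2-cpp" entry, and generatorAnswer (inherited from BaseAnswer, not shown here) is assumed to yield AnswerResult objects.

# Usage sketch (assumption): names marked hypothetical are not part of this commit.
from models.chatglmcpp_llm import ChatGLMCppLLMChain

chain = ChatGLMCppLLMChain(checkPoint=loader_checkpoint)   # loader_checkpoint: hypothetical LoaderCheckPoint
inputs = {"prompt": "你好", "history": [], "streaming": True}
for answer_result in chain(inputs)["answer_result_stream"]:
    print(answer_result.llm_output["answer"])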


@@ -11,7 +11,6 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                          AutoTokenizer, LlamaTokenizer)
from configs.model_config import LLM_DEVICE

class LoaderCheckPoint:
    """
    Load a custom model checkpoint
@@ -68,6 +67,8 @@ class LoaderCheckPoint:
        self.load_in_8bit = params.get('load_in_8bit', False)
        self.bf16 = params.get('bf16', False)
        self.is_chatgmlcpp = "chatglm2-cpp" == self.model_name

    def _load_model_config(self):
        if self.model_path:
@@ -119,7 +120,7 @@ class LoaderCheckPoint:
        # If the model loads but inference raises RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`,
        # it is still a lack of GPU memory; the only options are --load-in-8bit, or setting the default model to `chatglm-6b-int8`
        if not any([self.llm_device.lower() == "cpu",
                    self.load_in_8bit, self.is_llamacpp]):
                    self.load_in_8bit, self.is_llamacpp, self.is_chatgmlcpp]):
            if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"):
                # Decide whether to use multi-GPU deployment based on the number of GPUs on the current device
@@ -183,11 +184,34 @@ class LoaderCheckPoint:
                    .to(self.llm_device)
                )

        elif self.is_chatgmlcpp:
            try:
                import chatglm_cpp
            except ImportError as exc:
                import platform
                if platform.system() == "Darwin":
                    raise ValueError(
                        "Could not import the required Python package. "
                        "Please install it with `pip install chatglm-cpp`."
                    ) from exc
                else:
                    raise SystemError(
                        f"chatglm-cpp does not support {platform.system()}."
                    ) from exc
            model = (
                LoaderClass.from_pretrained(
                    checkpoint,
                    config=self.model_config,
                    trust_remote_code=True)
            )
            # model = chatglm_cpp.Pipeline(f'{self.model_path}/{self.model_name}.bin')
            tokenizer = getattr(model, "tokenizer")
            return model, tokenizer

        elif self.is_llamacpp:
            try:
                from llama_cpp import Llama
            except ImportError as exc:
                raise ValueError(
                    "Could not import depend python package "
@@ -473,5 +497,5 @@ class LoaderCheckPoint:
            except Exception as e:
                print("加载PrefixEncoder模型参数失败")
        # llama-cpp models (at least vicuna-13b) are their own eval; they do not provide an eval() method
        if not self.is_llamacpp:
        if not self.is_llamacpp and not self.is_chatgmlcpp:
            self.model = self.model.eval()
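The commented-out chatglm_cpp.Pipeline(...) line above hints at an alternative loading path that skips transformers entirely. A minimal sketch of that alternative, assuming a converted ggml .bin file already exists locally and that Pipeline.chat accepts the same keyword arguments the new chain passes to model.chat; the file path below is hypothetical.

# Alternative loading sketch (assumption): use chatglm_cpp.Pipeline directly
# instead of LoaderClass.from_pretrained. The .bin path is hypothetical.
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("models/chatglm2-cpp/chatglm2-cpp.bin")
reply = pipeline.chat(
    ["你好"],                      # flat history: [q1, a1, ..., current question]
    max_length=2048,
    max_context_length=512,
    do_sample=True,
    top_k=0,
    top_p=0.7,
    temperature=0.95,
)
print(reply)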