support chatglm2cpp
This commit is contained in:
parent
c5bc21781c
commit
62ba7679fe
|
|
@ -174,4 +174,8 @@ embedding/*
|
|||
|
||||
pyrightconfig.json
|
||||
loader/tmp_files
|
||||
<<<<<<< HEAD
|
||||
flagged/*
|
||||
=======
|
||||
flagged/*
|
||||
>>>>>>> 19147c3 (Update .gitignore)
|
||||
|
|
|
|||
|
|
@ -63,6 +63,13 @@ llm_model_dict = {
|
|||
"local_model_path": None,
|
||||
"provides": "ChatGLMLLMChain"
|
||||
},
|
||||
# 注:chatglm2-cpp已在mac上测试通过,其他系统暂不支持
|
||||
"chatglm2-cpp": {
|
||||
"name": "chatglm2-cpp",
|
||||
"pretrained_model_name": "cylee0909/chatglm2cpp",
|
||||
"local_model_path": None,
|
||||
"provides": "ChatGLMCppLLMChain"
|
||||
},
|
||||
"chatglm2-6b-int4": {
|
||||
"name": "chatglm2-6b-int4",
|
||||
"pretrained_model_name": "THUDM/chatglm2-6b-int4",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from .chatglm_llm import ChatGLMLLMChain
|
||||
from .llama_llm import LLamaLLMChain
|
||||
from .chatglmcpp_llm import ChatGLMCppLLMChain
|
||||
from .fastchat_openai_llm import FastChatOpenAILLMChain
|
||||
from .moss_llm import MOSSLLMChain
|
||||
|
|
|
|||
|
|
@ -0,0 +1,143 @@
|
|||
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.generation.utils import (LogitsProcessorList,
|
||||
StoppingCriteriaList)
|
||||
|
||||
from models.base import (AnswerResult,
|
||||
AnswerResultStream, BaseAnswer)
|
||||
from models.loader import LoaderCheckPoint
|
||||
|
||||
|
||||
class ChatGLMCppLLMChain(BaseAnswer, Chain, ABC):
|
||||
checkPoint: LoaderCheckPoint = None
|
||||
streaming_key: str = "streaming" #: :meta private:
|
||||
history_key: str = "history" #: :meta private:
|
||||
prompt_key: str = "prompt" #: :meta private:
|
||||
output_key: str = "answer_result_stream" #: :meta private:
|
||||
|
||||
max_length = 2048
|
||||
max_context_length = 512
|
||||
do_sample = True
|
||||
top_k = 0
|
||||
top_p = 0.7
|
||||
temperature = 0.95
|
||||
num_threads = 0
|
||||
|
||||
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
||||
super().__init__()
|
||||
self.checkPoint = checkPoint
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "ChatglmCppLLMChain"
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be whatever keys the prompt expects.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.prompt_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Will always return text key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
@property
|
||||
def _check_point(self) -> LoaderCheckPoint:
|
||||
return self.checkPoint
|
||||
|
||||
def encode(self, prompt, truncation_length=None):
|
||||
input_ids = self.checkPoint.tokenizer.encode(str(prompt))
|
||||
return input_ids
|
||||
|
||||
def decode(self, output_ids):
|
||||
reply = self.checkPoint.tokenizer.decode(output_ids)
|
||||
return reply
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Generator]:
|
||||
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
|
||||
return {self.output_key: generator}
|
||||
|
||||
def _generate_answer(self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
generate_with_callback: AnswerResultStream = None) -> None:
|
||||
|
||||
history = inputs[self.history_key]
|
||||
streaming = inputs[self.streaming_key]
|
||||
prompt = inputs[self.prompt_key]
|
||||
print(f"__call:{prompt}")
|
||||
|
||||
if prompt == "clear":
|
||||
history=[]
|
||||
|
||||
local_history = []
|
||||
|
||||
if not history:
|
||||
history =[]
|
||||
|
||||
for k,v in history:
|
||||
if k:
|
||||
local_history.append(k)
|
||||
local_history.append(v)
|
||||
|
||||
local_history.append(prompt)
|
||||
|
||||
if streaming:
|
||||
history += [[]]
|
||||
pieces = []
|
||||
print(f"++++++++++++++Stream++++++++++++++++++++")
|
||||
for piece in self.checkPoint.model.stream_chat(
|
||||
local_history,
|
||||
max_length=self.max_length,
|
||||
max_context_length=self.max_context_length,
|
||||
do_sample=self.temperature > 0,
|
||||
top_k=self.top_k,
|
||||
top_p=self.top_p,
|
||||
temperature=self.temperature,
|
||||
):
|
||||
pieces.append(piece)
|
||||
reply = ''.join(pieces)
|
||||
print(f"{piece}",end='')
|
||||
|
||||
answer_result = AnswerResult()
|
||||
history[-1] = [prompt, reply]
|
||||
answer_result.history = history
|
||||
answer_result.llm_output = {"answer": reply}
|
||||
generate_with_callback(answer_result)
|
||||
print("")
|
||||
else :
|
||||
reply = self.checkPoint.model.chat(
|
||||
local_history,
|
||||
max_length=self.max_length,
|
||||
max_context_length=self.max_context_length,
|
||||
do_sample=self.temperature > 0,
|
||||
top_k=self.top_k,
|
||||
top_p=self.top_p,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
|
||||
print(f"response:{reply}")
|
||||
print(f"+++++++++++++++++++++++++++++++++++")
|
||||
|
||||
answer_result = AnswerResult()
|
||||
history.append([prompt, reply])
|
||||
answer_result.history = history
|
||||
answer_result.llm_output = {"answer": reply}
|
||||
generate_with_callback(answer_result)
|
||||
|
|
@ -11,7 +11,6 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
|||
AutoTokenizer, LlamaTokenizer)
|
||||
from configs.model_config import LLM_DEVICE
|
||||
|
||||
|
||||
class LoaderCheckPoint:
|
||||
"""
|
||||
加载自定义 model CheckPoint
|
||||
|
|
@ -68,6 +67,8 @@ class LoaderCheckPoint:
|
|||
self.load_in_8bit = params.get('load_in_8bit', False)
|
||||
self.bf16 = params.get('bf16', False)
|
||||
|
||||
self.is_chatgmlcpp = "chatglm2-cpp" == self.model_name
|
||||
|
||||
def _load_model_config(self):
|
||||
|
||||
if self.model_path:
|
||||
|
|
@ -119,7 +120,7 @@ class LoaderCheckPoint:
|
|||
# 如果加载没问题,但在推理时报错RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`
|
||||
# 那还是因为显存不够,此时只能考虑--load-in-8bit,或者配置默认模型为`chatglm-6b-int8`
|
||||
if not any([self.llm_device.lower() == "cpu",
|
||||
self.load_in_8bit, self.is_llamacpp]):
|
||||
self.load_in_8bit, self.is_llamacpp, self.is_chatgmlcpp]):
|
||||
|
||||
if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"):
|
||||
# 根据当前设备GPU数量决定是否进行多卡部署
|
||||
|
|
@ -183,11 +184,34 @@ class LoaderCheckPoint:
|
|||
.to(self.llm_device)
|
||||
)
|
||||
|
||||
elif self.is_chatgmlcpp :
|
||||
try:
|
||||
import chatglm_cpp
|
||||
except ImportError as exc:
|
||||
import platform
|
||||
if platform.system() == "Darwin":
|
||||
raise ValueError(
|
||||
"Could not import depend python package "
|
||||
"Please install it with `pip install chatglm-cpp`."
|
||||
) from exc
|
||||
else :
|
||||
raise SystemError(
|
||||
f"chatglm-cpp not support {platform.system()}."
|
||||
) from exc
|
||||
|
||||
model = (
|
||||
LoaderClass.from_pretrained(
|
||||
checkpoint,
|
||||
config=self.model_config,
|
||||
trust_remote_code=True)
|
||||
)
|
||||
# model = chatglm_cpp.Pipeline(f'{self.model_path}/{self.model_name}.bin')
|
||||
tokenizer = getattr(model, "tokenizer")
|
||||
return model, tokenizer
|
||||
|
||||
elif self.is_llamacpp:
|
||||
|
||||
try:
|
||||
from llama_cpp import Llama
|
||||
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
"Could not import depend python package "
|
||||
|
|
@ -473,5 +497,5 @@ class LoaderCheckPoint:
|
|||
except Exception as e:
|
||||
print("加载PrefixEncoder模型参数失败")
|
||||
# llama-cpp模型(至少vicuna-13b)的eval方法就是自身,其没有eval方法
|
||||
if not self.is_llamacpp:
|
||||
if not self.is_llamacpp and not self.is_chatgmlcpp:
|
||||
self.model = self.model.eval()
|
||||
|
|
|
|||
Loading…
Reference in New Issue