support chatglm2cpp
This commit is contained in:
parent c5bc21781c
commit 62ba7679fe
.gitignore
@@ -174,4 +174,8 @@ embedding/*
 pyrightconfig.json
 loader/tmp_files
+<<<<<<< HEAD
 flagged/*
+=======
+flagged/*
+>>>>>>> 19147c3 (Update .gitignore)
configs/model_config.py
@@ -63,6 +63,13 @@ llm_model_dict = {
         "local_model_path": None,
         "provides": "ChatGLMLLMChain"
     },
+    # Note: chatglm2-cpp has only been tested on macOS; other systems are not supported yet
+    "chatglm2-cpp": {
+        "name": "chatglm2-cpp",
+        "pretrained_model_name": "cylee0909/chatglm2cpp",
+        "local_model_path": None,
+        "provides": "ChatGLMCppLLMChain"
+    },
     "chatglm2-6b-int4": {
         "name": "chatglm2-6b-int4",
         "pretrained_model_name": "THUDM/chatglm2-6b-int4",
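For reference, entries in llm_model_dict are looked up by key and the class named in "provides" is instantiated from the models package. The sketch below is a hedged illustration of that dispatch, not the project's actual loader code; the build_llm_chain helper is hypothetical.

# Hypothetical sketch of how a llm_model_dict entry is consumed: resolve the
# "provides" class name against the models package and instantiate it.
import models
from configs.model_config import llm_model_dict

def build_llm_chain(model_key: str, checkpoint=None):
    entry = llm_model_dict[model_key]                # e.g. "chatglm2-cpp"
    chain_cls = getattr(models, entry["provides"])   # -> ChatGLMCppLLMChain
    return chain_cls(checkPoint=checkpoint)

chain = build_llm_chain("chatglm2-cpp")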
models/__init__.py
@@ -1,4 +1,5 @@
 from .chatglm_llm import ChatGLMLLMChain
 from .llama_llm import LLamaLLMChain
+from .chatglmcpp_llm import ChatGLMCppLLMChain
 from .fastchat_openai_llm import FastChatOpenAILLMChain
 from .moss_llm import MOSSLLMChain
models/chatglmcpp_llm.py (new file)
@@ -0,0 +1,143 @@
+from abc import ABC
+from typing import Any, Dict, Generator, List, Optional, Union
+
+import torch
+import transformers
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.chains.base import Chain
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import (LogitsProcessorList,
+                                            StoppingCriteriaList)
+
+from models.base import (AnswerResult,
+                         AnswerResultStream, BaseAnswer)
+from models.loader import LoaderCheckPoint
+
+
+class ChatGLMCppLLMChain(BaseAnswer, Chain, ABC):
+    checkPoint: LoaderCheckPoint = None
+    streaming_key: str = "streaming"  #: :meta private:
+    history_key: str = "history"  #: :meta private:
+    prompt_key: str = "prompt"  #: :meta private:
+    output_key: str = "answer_result_stream"  #: :meta private:
+
+    max_length = 2048
+    max_context_length = 512
+    do_sample = True
+    top_k = 0
+    top_p = 0.7
+    temperature = 0.95
+    num_threads = 0
+
+    def __init__(self, checkPoint: LoaderCheckPoint = None):
+        super().__init__()
+        self.checkPoint = checkPoint
+
+    @property
+    def _chain_type(self) -> str:
+        return "ChatglmCppLLMChain"
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Will be whatever keys the prompt expects.
+
+        :meta private:
+        """
+        return [self.prompt_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Will always return text key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
+    @property
+    def _check_point(self) -> LoaderCheckPoint:
+        return self.checkPoint
+
+    def encode(self, prompt, truncation_length=None):
+        input_ids = self.checkPoint.tokenizer.encode(str(prompt))
+        return input_ids
+
+    def decode(self, output_ids):
+        reply = self.checkPoint.tokenizer.decode(output_ids)
+        return reply
+
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Generator]:
+        generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
+        return {self.output_key: generator}
+
+    def _generate_answer(self,
+                         inputs: Dict[str, Any],
+                         run_manager: Optional[CallbackManagerForChainRun] = None,
+                         generate_with_callback: AnswerResultStream = None) -> None:
+
+        history = inputs[self.history_key]
+        streaming = inputs[self.streaming_key]
+        prompt = inputs[self.prompt_key]
+        print(f"__call:{prompt}")
+
+        if prompt == "clear":
+            history = []
+
+        local_history = []
+
+        if not history:
+            history = []
+
+        for k, v in history:
+            if k:
+                local_history.append(k)
+            local_history.append(v)
+
+        local_history.append(prompt)
+
+        if streaming:
+            history += [[]]
+            pieces = []
+            print("++++++++++++++Stream++++++++++++++++++++")
+            for piece in self.checkPoint.model.stream_chat(
+                    local_history,
+                    max_length=self.max_length,
+                    max_context_length=self.max_context_length,
+                    do_sample=self.temperature > 0,
+                    top_k=self.top_k,
+                    top_p=self.top_p,
+                    temperature=self.temperature,
+            ):
+                pieces.append(piece)
+                reply = ''.join(pieces)
+                print(piece, end='')
+
+                answer_result = AnswerResult()
+                history[-1] = [prompt, reply]
+                answer_result.history = history
+                answer_result.llm_output = {"answer": reply}
+                generate_with_callback(answer_result)
+            print("")
+        else:
+            reply = self.checkPoint.model.chat(
+                local_history,
+                max_length=self.max_length,
+                max_context_length=self.max_context_length,
+                do_sample=self.temperature > 0,
+                top_k=self.top_k,
+                top_p=self.top_p,
+                temperature=self.temperature,
+            )
+
+            print(f"response:{reply}")
+            print("+++++++++++++++++++++++++++++++++++")
+
+            answer_result = AnswerResult()
+            history.append([prompt, reply])
+            answer_result.history = history
+            answer_result.llm_output = {"answer": reply}
+            generate_with_callback(answer_result)
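For context, a hedged sketch of how this chain is typically driven once a LoaderCheckPoint already holds the chatglm-cpp model and tokenizer; the checkpoint wiring is assumed, not shown in this commit. The history is a list of [question, answer] pairs, which the chain flattens into the utterance list the backend expects.

# Illustrative only: `checkpoint` is assumed to have been populated by the
# loader changes below (model + tokenizer for "chatglm2-cpp").
from models import ChatGLMCppLLMChain
from models.loader import LoaderCheckPoint

checkpoint: LoaderCheckPoint = ...  # assumed: already loaded elsewhere

chain = ChatGLMCppLLMChain(checkPoint=checkpoint)

def collect(answer_result):
    # Called once per streamed piece; "answer" holds the accumulated reply.
    print(answer_result.llm_output["answer"])

chain._generate_answer(
    inputs={"prompt": "Hello", "history": [], "streaming": True},
    generate_with_callback=collect,
)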
models/loader/loader.py
@@ -11,7 +11,6 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                           AutoTokenizer, LlamaTokenizer)
 from configs.model_config import LLM_DEVICE
 
-
 class LoaderCheckPoint:
     """
     Load a custom model CheckPoint
@@ -68,6 +67,8 @@ class LoaderCheckPoint:
         self.load_in_8bit = params.get('load_in_8bit', False)
         self.bf16 = params.get('bf16', False)
 
+        self.is_chatgmlcpp = "chatglm2-cpp" == self.model_name
+
     def _load_model_config(self):
 
         if self.model_path:
@@ -119,7 +120,7 @@ class LoaderCheckPoint:
         # If loading succeeds but inference raises RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`,
         # GPU memory is still insufficient; the only options are --load-in-8bit or setting the default model to `chatglm-6b-int8`
         if not any([self.llm_device.lower() == "cpu",
-                    self.load_in_8bit, self.is_llamacpp]):
+                    self.load_in_8bit, self.is_llamacpp, self.is_chatgmlcpp]):
 
             if torch.cuda.is_available() and self.llm_device.lower().startswith("cuda"):
                 # Decide whether to deploy across multiple GPUs based on how many are available
@@ -183,11 +184,34 @@ class LoaderCheckPoint:
                     .to(self.llm_device)
                 )
+
+        elif self.is_chatgmlcpp:
+            try:
+                import chatglm_cpp
+            except ImportError as exc:
+                import platform
+                if platform.system() == "Darwin":
+                    raise ValueError(
+                        "Could not import depend python package "
+                        "Please install it with `pip install chatglm-cpp`."
+                    ) from exc
+                else:
+                    raise SystemError(
+                        f"chatglm-cpp not support {platform.system()}."
+                    ) from exc
+
+            model = (
+                LoaderClass.from_pretrained(
+                    checkpoint,
+                    config=self.model_config,
+                    trust_remote_code=True)
+            )
+            # model = chatglm_cpp.Pipeline(f'{self.model_path}/{self.model_name}.bin')
+            tokenizer = getattr(model, "tokenizer")
+            return model, tokenizer
+
         elif self.is_llamacpp:
             try:
                 from llama_cpp import Llama
             except ImportError as exc:
                 raise ValueError(
                     "Could not import depend python package "
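The commented-out chatglm_cpp.Pipeline line above hints at loading the converted GGML weights directly with the binding. A hedged sketch follows, assuming the chatglm-cpp 0.2.x Python API in which Pipeline takes a converted .bin path and chat accepts a flat list of utterances; the model path is a placeholder, and the parameter names simply mirror what ChatGLMCppLLMChain passes through.

# Illustrative only: "chatglm2-ggml.bin" is a placeholder for weights converted
# with chatglm.cpp's conversion script.
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("./chatglm2-ggml.bin")
reply = pipeline.chat(
    ["Hello"],               # flat history: [q1, a1, q2, a2, ..., new prompt]
    max_length=2048,
    max_context_length=512,
    do_sample=True,
    top_k=0,
    top_p=0.7,
    temperature=0.95,
)
print(reply)

In this commit the model is instead loaded via LoaderClass.from_pretrained("cylee0909/chatglm2cpp", trust_remote_code=True), whose remote code presumably wraps such a pipeline and exposes the chat/stream_chat methods the chain calls.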
@@ -473,5 +497,5 @@ class LoaderCheckPoint:
             except Exception as e:
                 print("Failed to load PrefixEncoder model parameters")
         # llama-cpp models (at least vicuna-13b) are their own eval form and have no eval() method
-        if not self.is_llamacpp:
+        if not self.is_llamacpp and not self.is_chatgmlcpp:
             self.model = self.model.eval()