Add an implementation for accessing Baichuan-13B-Chat (#1005)
parent 22c6192561
commit 62047c880e
@@ -241,6 +241,17 @@ class LocalDocQA:
        else:
            prompt = query

        # Code branch for the Baichuan integration:
        if LLM_MODEL == "Baichuan-13B-Chat":
            for answer_result in self.llm_model_chain._generate_answer(prompt=prompt, history=chat_history,
                                                                        streaming=streaming):
                resp = answer_result.llm_output["answer"]
                history = answer_result.history
                response = {"query": query,
                            "result": resp,
                            "source_documents": related_docs_with_score}
                yield response, history
        else:  # original logic branch:
            answer_result_stream_result = self.llm_model_chain(
                {"prompt": prompt, "history": chat_history, "streaming": streaming})
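With the branch above, get_knowledge_based_answer yields (response, history) pairs as the Baichuan model streams its answer. A minimal consumer sketch, assuming the signature get_knowledge_based_answer(query, vs_path, chat_history, streaming) used elsewhere in this project; local_doc_qa and vs_path are placeholders, not names introduced by this commit.

# Hedged usage sketch: iterate the generator the same way the web/CLI layers do.
for response, history in local_doc_qa.get_knowledge_based_answer(
        query="What does this project do?",
        vs_path=vs_path,
        chat_history=[],
        streaming=True):
    print(response["result"])                                 # (partial) answer text
    print(len(response["source_documents"]), "source docs")   # retrieved context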
@@ -156,6 +156,12 @@ llm_model_dict = {
        "local_model_path": None,
        "provides": "MOSSLLMChain"
    },
    "Baichuan-13B-Chat": {
        "name": "Baichuan-13B-Chat",
        "pretrained_model_name": "baichuan-inc/Baichuan-13B-Chat",
        "local_model_path": None,
        "provides": "BaichuanLLMChain"
    },
    # For llama-cpp model compatibility issues, see https://github.com/abetlen/llama-cpp-python/issues/204
    "ggml-vicuna-13b-1.1-q5": {
        "name": "ggml-vicuna-13b-1.1-q5",
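The entry above is how the rest of the project locates the checkpoint and picks the chain class. A hedged sketch of the lookup; the field names come from the entry itself, but exactly how the shared loader consumes "provides" is assumed here rather than shown in this diff.

# Sketch: resolving the registry entry added above. LLM_MODEL must match the
# dict key exactly, which is also the string the LLM_MODEL checks in
# local_doc_qa.py and loader.py compare against.
from configs.model_config import llm_model_dict, LLM_MODEL

entry = llm_model_dict[LLM_MODEL]
print(entry["pretrained_model_name"])  # "baichuan-inc/Baichuan-13B-Chat"
print(entry["provides"])               # "BaichuanLLMChain", exported from models/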
@@ -244,6 +250,9 @@ USE_LORA = True if LORA_NAME else False
# LLM streaming response
STREAMING = True

# Set the full path of the Baichuan LoRA adapter here; leave empty to skip LoRA loading
LORA_MODEL_PATH_BAICHUAN = ""

# Use p-tuning-v2 PrefixEncoder
USE_PTUNING_V2 = False
PTUNING_DIR='./ptuning-v2'
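An illustrative value for the new setting; the path below is purely hypothetical (the empty default above keeps the LoRA branch in the loader disabled).

# Example only: point this at a local PEFT/LoRA adapter directory trained for
# Baichuan-13B-Chat. Placeholder path, not part of the commit.
LORA_MODEL_PATH_BAICHUAN = "/data/lora/baichuan-13b-chat-adapter"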
@@ -3,3 +3,5 @@ from .llama_llm import LLamaLLMChain
from .chatglmcpp_llm import ChatGLMCppLLMChain
from .fastchat_openai_llm import FastChatOpenAILLMChain
from .moss_llm import MOSSLLMChain
from .baichuan_llm import BaichuanLLMChain
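Exporting BaichuanLLMChain from the models package is what allows the "provides" string in the config to be resolved to a class. A hedged sketch of that resolution; the getattr-on-package pattern is an assumption about the shared loader code, which is not part of this diff.

# Sketch (assumed mechanism): turn the "provides" string from llm_model_dict
# into the class exported by models/__init__.py.
import models

provides_class = getattr(models, "BaichuanLLMChain")
chain_cls_name = provides_class.__name__  # "BaichuanLLMChain"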
@@ -0,0 +1,71 @@
from abc import ABC
from langchain.llms.base import LLM
from typing import Optional, List
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
                         AnswerResult)


class BaichuanLLMChain(BaseAnswer, LLM, ABC):
    max_token: int = 10000
    temperature: float = 0.01
    top_p = 0.9
    checkPoint: LoaderCheckPoint = None
    # history = []
    history_len: int = 10

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "BaichuanLLMChain"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        print(f"__call:{prompt}")
        # Baichuan's chat() expects a list of role/content messages rather than a bare string.
        response = self.checkPoint.model.chat(
            self.checkPoint.tokenizer,
            [{"role": "user", "content": prompt}],
            # history=[],
            # max_length=self.max_token,
            # temperature=self.temperature
        )
        print(f"response:{response}")
        print(f"+++++++++++++++++++++++++++++++++++")
        return response

    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False):
        messages = []
        messages.append({"role": "user", "content": prompt})
        if streaming:
            # With stream=True, Baichuan's chat() returns a generator of
            # progressively longer partial responses.
            for inum, stream_resp in enumerate(self.checkPoint.model.chat(
                    self.checkPoint.tokenizer,
                    messages,
                    stream=True
            )):
                self.checkPoint.clear_torch_cache()
                answer_result = AnswerResult()
                answer_result.llm_output = {"answer": stream_resp}
                yield answer_result
        else:
            response = self.checkPoint.model.chat(
                self.checkPoint.tokenizer,
                messages
            )
            self.checkPoint.clear_torch_cache()
            answer_result = AnswerResult()
            answer_result.llm_output = {"answer": response}
            yield answer_result
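A hedged usage sketch for the new chain. loader_checkpoint stands in for a LoaderCheckPoint that has already loaded the Baichuan-13B-Chat model and tokenizer (the loader hunk below handles that); it is a placeholder, not a name from this commit.

# Sketch: driving BaichuanLLMChain directly. Each yielded AnswerResult carries
# the (partial) answer text under llm_output["answer"].
chain = BaichuanLLMChain(checkPoint=loader_checkpoint)  # loader_checkpoint: placeholder
for answer_result in chain._generate_answer(prompt="Introduce yourself.",
                                             history=[],
                                             streaming=True):
    print(answer_result.llm_output["answer"])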
@@ -9,7 +9,9 @@ import torch
import transformers
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
                          AutoTokenizer, LlamaTokenizer)
from configs.model_config import LLM_DEVICE
from configs.model_config import LLM_DEVICE, LLM_MODEL, LORA_MODEL_PATH_BAICHUAN
from peft import PeftModel
from transformers.generation.utils import GenerationConfig


class LoaderCheckPoint:
    """
@@ -126,6 +128,26 @@ class LoaderCheckPoint:
            # Decide whether to deploy across multiple GPUs based on how many are available
            num_gpus = torch.cuda.device_count()
            if num_gpus < 2 and self.device_map is None:
                # if LORA_MODEL_PATH_BAICHUAN is not None:
                if LORA_MODEL_PATH_BAICHUAN:
                    if LLM_MODEL == "Baichuan-13B-Chat":
                        model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16,
                                                                     device_map="auto", trust_remote_code=True, )
                        model.generation_config = GenerationConfig.from_pretrained(checkpoint)
                        from configs.model_config import LLM_DEVICE, LORA_MODEL_PATH_BAICHUAN
                        # if LORA_MODEL_PATH_BAICHUAN is not None:
                        if LORA_MODEL_PATH_BAICHUAN:
                            print("loading lora:{path}".format(path=LORA_MODEL_PATH_BAICHUAN))
                            model = PeftModel.from_pretrained(
                                model,
                                LORA_MODEL_PATH_BAICHUAN,
                                torch_dtype=torch.float16,
                                device_map={"": LLM_DEVICE}
                            )
                        tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False,
                                                                  trust_remote_code=True)
                        model.half().cuda()
                else:
                    model = (
                        LoaderClass.from_pretrained(checkpoint,
                                                    config=self.model_config,
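For readers who want the Baichuan path above in isolation, a hedged standalone sketch of the same loading sequence. checkpoint_dir and lora_path are placeholders; in the project the real values come from llm_model_dict and LORA_MODEL_PATH_BAICHUAN in configs/model_config.py.

# Sketch: the Baichuan-13B-Chat loading sequence used above, outside LoaderCheckPoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
from peft import PeftModel


def load_baichuan(checkpoint_dir: str, lora_path: str = ""):
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir, torch_dtype=torch.float16,
        device_map="auto", trust_remote_code=True)
    # Baichuan ships a generation_config.json with its chat sampling defaults.
    model.generation_config = GenerationConfig.from_pretrained(checkpoint_dir)
    if lora_path:
        # Optionally stack a PEFT/LoRA adapter on top of the base weights.
        model = PeftModel.from_pretrained(model, lora_path, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, use_fast=False,
                                              trust_remote_code=True)
    return model, tokenizer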