Merge branch 'dev' of github.com:imClumsyPanda/langchain-ChatGLM into dev

hzg0601 2023-07-21 15:29:53 +08:00
commit 6ec1e56e03
4 changed files with 9 additions and 8 deletions

View File

@@ -203,7 +203,7 @@ llm_model_dict = {
 }
 # LLM name
-LLM_MODEL = "fastchat-chatglm-6b-int4"
+LLM_MODEL = "fastchat-chatglm"
 # Load the model in 8-bit quantization
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
@@ -220,7 +220,7 @@ STREAMING = True
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
-PTUNING_DIR='./ptuing-v2'
+PTUNING_DIR='./ptuning-v2'
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

View File

@@ -2,14 +2,14 @@ from abc import ABC
 from langchain.chains.base import Chain
 from typing import Any, Dict, List, Optional, Generator
 from langchain.callbacks.manager import CallbackManagerForChainRun
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+# from transformers.generation.logits_process import LogitsProcessor
+# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
-import torch
+# import torch
 import transformers
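
Note: this hunk comments out module-level imports (LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, torch) that this chain no longer uses, so importing the module stops pulling in the heavyweight torch/transformers generation machinery. If some code path later needs one of these symbols again, a deferred import is a common alternative; a sketch under that assumption, with a hypothetical helper name:

    def _make_logits_processor_list():
        # Imported only when this code path actually runs, keeping module
        # import cheap for users of the HTTP-only fastchat backend.
        from transformers.generation.utils import LogitsProcessorList
        return LogitsProcessorList()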

View File

@@ -44,7 +44,7 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
 parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
 parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
-parser.add_argument('--use-ptuning-v2', type=str, default=USE_PTUNING_V2, help="whether use ptuning-v2 checkpoint")
+parser.add_argument('--use-ptuning-v2', action='store_true', help="whether use ptuning-v2 checkpoint")
 parser.add_argument("--ptuning-dir", type=str, default=PTUNING_DIR, help="the dir of ptuning-v2 checkpoint")
 # Accelerate/transformers
 parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
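
Note on the --use-ptuning-v2 fix: with the old type=str definition, any value given on the command line arrives as a non-empty string, and even "False" is truthy in Python, so the flag could effectively never be turned off; action='store_true' gives proper boolean semantics. A self-contained sketch of both behaviors, independent of the repo:

    import argparse

    old = argparse.ArgumentParser()
    old.add_argument('--use-ptuning-v2', type=str, default=False)
    # The string "False" is truthy, so the old flag silently enables the feature:
    assert bool(old.parse_args(['--use-ptuning-v2', 'False']).use_ptuning_v2) is True

    new = argparse.ArgumentParser()
    new.add_argument('--use-ptuning-v2', action='store_true')
    assert new.parse_args([]).use_ptuning_v2 is False                   # flag absent
    assert new.parse_args(['--use-ptuning-v2']).use_ptuning_v2 is True  # flag present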

View File

@@ -441,7 +441,7 @@ class LoaderCheckPoint:
         if self.use_ptuning_v2:
             try:
-                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
+                prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
                 prefix_encoder_config = json.loads(prefix_encoder_file.read())
                 prefix_encoder_file.close()
                 self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
@@ -457,13 +457,14 @@
         if self.use_ptuning_v2:
             try:
-                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
+                prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
                 new_prefix_state_dict = {}
                 for k, v in prefix_state_dict.items():
                     if k.startswith("transformer.prefix_encoder."):
                         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
+                print("Loaded the p-tuning checkpoint successfully")
             except Exception as e:
                 print(e)
                 print("Failed to load the PrefixEncoder model parameters")