Merge branch 'dev' of github.com:imClumsyPanda/langchain-ChatGLM into dev
Commit 6ec1e56e03
@@ -203,7 +203,7 @@ llm_model_dict = {
 }

 # LLM 名称
-LLM_MODEL = "fastchat-chatglm-6b-int4"
+LLM_MODEL = "fastchat-chatglm"
 # 量化加载8bit 模型
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
@@ -220,7 +220,7 @@ STREAMING = True

 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
-PTUNING_DIR='./ptuing-v2'
+PTUNING_DIR='./ptuning-v2'
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

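The hunk above fixes the `./ptuing-v2` typo in PTUNING_DIR and keeps the CUDA → MPS → CPU fallback for LLM_DEVICE. A minimal, self-contained sketch of how those two settings behave at runtime (the check_ptuning_dir helper is hypothetical, not part of the repo; assumes PyTorch is installed):

import os

import torch

# Mirrors the config values above (PTUNING_DIR after the typo fix).
USE_PTUNING_V2 = False
PTUNING_DIR = './ptuning-v2'

# Same fallback chain as LLM_DEVICE: CUDA if present, else Apple MPS, else CPU.
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"


def check_ptuning_dir(path: str) -> None:
    """Hypothetical helper: warn early if the p-tuning-v2 checkpoint directory is missing."""
    if USE_PTUNING_V2 and not os.path.isdir(path):
        print(f"p-tuning-v2 enabled but checkpoint dir not found: {os.path.abspath(path)}")


check_ptuning_dir(PTUNING_DIR)
print(f"LLM device: {LLM_DEVICE}")
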
@@ -2,14 +2,14 @@ from abc import ABC
 from langchain.chains.base import Chain
 from typing import Any, Dict, List, Optional, Generator
 from langchain.callbacks.manager import CallbackManagerForChainRun
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+# from transformers.generation.logits_process import LogitsProcessor
+# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
-import torch
+# import torch
 import transformers

@@ -44,7 +44,7 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
 parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
 parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
-parser.add_argument('--use-ptuning-v2',type=str,default=USE_PTUNING_V2,help="whether use ptuning-v2 checkpoint")
+parser.add_argument('--use-ptuning-v2',action='store_true',help="whether use ptuning-v2 checkpoint")
 parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")
 # Accelerate/transformers
 parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
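The only functional change in this hunk is `--use-ptuning-v2`: it was a string-typed option defaulting to USE_PTUNING_V2 and becomes a plain boolean flag. With type=str, any non-empty value is truthy in a later `if args.use_ptuning_v2:` check, so even `--use-ptuning-v2 False` would enable the feature; action='store_true' makes the flag's presence mean True and its absence mean False. A small standalone illustration (not the repo's actual parser):

import argparse

# Old behaviour: a string-typed option. The literal text "False" is a non-empty
# string, so a later `if args.use_ptuning_v2:` check still passes.
old = argparse.ArgumentParser()
old.add_argument('--use-ptuning-v2', type=str, default=False)
print(bool(old.parse_args(['--use-ptuning-v2', 'False']).use_ptuning_v2))  # True

# New behaviour: a real boolean flag. Present -> True, absent -> False.
new = argparse.ArgumentParser()
new.add_argument('--use-ptuning-v2', action='store_true')
print(new.parse_args(['--use-ptuning-v2']).use_ptuning_v2)  # True
print(new.parse_args([]).use_ptuning_v2)                    # False
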
@@ -441,7 +441,7 @@ class LoaderCheckPoint:

         if self.use_ptuning_v2:
             try:
-                prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
+                prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
                 prefix_encoder_config = json.loads(prefix_encoder_file.read())
                 prefix_encoder_file.close()
                 self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
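Here the loader starts resolving the checkpoint directory with os.path.abspath, so a relative ./ptuning-v2 is interpreted against the current working directory explicitly and any error message carries the full path. The same read of pre_seq_len can also be written with a context manager instead of the explicit open/close pair; a sketch assuming the usual ChatGLM p-tuning-v2 checkpoint layout (config.json next to pytorch_model.bin):

import json
import os
from pathlib import Path


def read_pre_seq_len(ptuning_dir: str) -> int:
    """Sketch: read pre_seq_len from a p-tuning-v2 checkpoint's config.json."""
    config_path = Path(os.path.abspath(ptuning_dir)) / 'config.json'
    # The context manager closes the file even if JSON parsing raises.
    with open(config_path, 'r') as f:
        prefix_encoder_config = json.load(f)
    return prefix_encoder_config['pre_seq_len']
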
@@ -457,13 +457,14 @@ class LoaderCheckPoint:

         if self.use_ptuning_v2:
             try:
-                prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
+                prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
                 new_prefix_state_dict = {}
                 for k, v in prefix_state_dict.items():
                     if k.startswith("transformer.prefix_encoder."):
                         new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
                 self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
                 self.model.transformer.prefix_encoder.float()
+                print("加载ptuning检查点成功!")
             except Exception as e:
                 print(e)
                 print("加载PrefixEncoder模型参数失败")
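Besides switching to an absolute path, this hunk adds a success message after the PrefixEncoder weights are applied ("加载ptuning检查点成功!", roughly "p-tuning checkpoint loaded successfully"); the existing failure branch prints "加载PrefixEncoder模型参数失败" ("failed to load PrefixEncoder model parameters"). The core of the block, stripping the "transformer.prefix_encoder." prefix from checkpoint keys so they match the submodule's own parameter names, can be sketched in isolation (assumes a ChatGLM-style checkpoint; not a drop-in replacement for the loader):

import os

import torch


def load_prefix_encoder_state(ptuning_dir: str) -> dict:
    """Sketch: load p-tuning-v2 weights and strip the 'transformer.prefix_encoder.' key prefix."""
    ckpt = os.path.join(os.path.abspath(ptuning_dir), 'pytorch_model.bin')
    prefix_state_dict = torch.load(ckpt, map_location='cpu')
    prefix = "transformer.prefix_encoder."
    # Keep only PrefixEncoder tensors and rename keys to match the submodule's state_dict.
    return {k[len(prefix):]: v for k, v in prefix_state_dict.items() if k.startswith(prefix)}


# Usage against a loaded ChatGLM model (`model` is hypothetical here):
# model.transformer.prefix_encoder.load_state_dict(load_prefix_encoder_state('./ptuning-v2'))
# model.transformer.prefix_encoder.float()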