Merge pull request #904 from bones-zhu/dev

1. Fix the typo in model_config.py; 2. Change the default LLM_MODEL; 3. Remove the unused imports from chatglm_llm.py
Zhi-guo Huang, 2023-07-21 15:14:33 +08:00 (committed via GitHub)
commit 7d37dc876e
2 changed files with 5 additions and 5 deletions

File: model_config.py

@@ -203,7 +203,7 @@ llm_model_dict = {
 }
 # LLM name
-LLM_MODEL = "fastchat-chatglm-6b-int4"
+LLM_MODEL = "fastchat-chatglm"
 # Load the model quantized to 8-bit
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
@@ -220,7 +220,7 @@ STREAMING = True
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
-PTUNING_DIR='./ptuing-v2'
+PTUNING_DIR='./ptuning-v2'
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
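Note: LLM_MODEL is used as a key into llm_model_dict (visible in the hunk header above), so the new default is only safe if a matching entry exists. A minimal sketch of that lookup pattern; the dict entry and its fields below are illustrative assumptions, not copied from the real model_config.py:

# Illustrative sketch only: the entry and fields below are assumed,
# not taken from the actual config in this repository.
llm_model_dict = {
    "fastchat-chatglm": {
        "name": "chatglm-6b",      # hypothetical field
        "local_model_path": None,  # hypothetical field
    },
}

LLM_MODEL = "fastchat-chatglm"

# Renaming the default is only safe if the new name is a real key;
# a mismatch surfaces immediately as a KeyError at startup.
model_cfg = llm_model_dict[LLM_MODEL]
print(model_cfg["name"])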

File: chatglm_llm.py

@@ -2,14 +2,14 @@ from abc import ABC
 from langchain.chains.base import Chain
 from typing import Any, Dict, List, Optional, Generator
 from langchain.callbacks.manager import CallbackManagerForChainRun
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+# from transformers.generation.logits_process import LogitsProcessor
+# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
-import torch
+# import torch
 import transformers
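The commented-out names (LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, torch) are no longer referenced in this module. A quick stdlib-only way to double-check that an import is dead before removing it for good; a rough sketch that ignores edge cases such as re-exports or string annotations:

import ast

def unused_imports(path: str) -> list[str]:
    """Return imported names that are never referenced in the file."""
    with open(path, encoding="utf-8") as f:
        tree = ast.parse(f.read())
    imported, used = set(), set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            # `import torch` binds the top-level name "torch"
            for alias in node.names:
                imported.add((alias.asname or alias.name).split(".")[0])
        elif isinstance(node, ast.ImportFrom):
            # `from x import y as z` binds "z" (or "y" without an alias)
            for alias in node.names:
                imported.add(alias.asname or alias.name)
        elif isinstance(node, ast.Name):
            used.add(node.id)
    return sorted(imported - used)

if __name__ == "__main__":
    # e.g. run against the file touched by this commit
    print(unused_imports("models/chatglm_llm.py"))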