diff --git a/configs/model_config.py b/configs/model_config.py
index ea24c36..4a51c5c 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -203,7 +203,7 @@ llm_model_dict = {
 }
 
 # LLM name
-LLM_MODEL = "fastchat-chatglm-6b-int4"
+LLM_MODEL = "fastchat-chatglm"
 # Load the model quantized to 8-bit
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
@@ -220,7 +220,7 @@ STREAMING = True
 
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
-PTUNING_DIR='./ptuing-v2'
+PTUNING_DIR='./ptuning-v2'
 
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index 81878ce..0d19ee6 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -2,14 +2,14 @@ from abc import ABC
 from langchain.chains.base import Chain
 from typing import Any, Dict, List, Optional, Generator
 from langchain.callbacks.manager import CallbackManagerForChainRun
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
+# from transformers.generation.logits_process import LogitsProcessor
+# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from models.loader import LoaderCheckPoint
 from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
-import torch
+# import torch
 import transformers
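
Note: the LLM_DEVICE assignment kept as context in the first file picks the runtime device by falling back from CUDA to Apple's MPS backend to CPU. It can be exercised standalone; a minimal sketch, assuming only that torch is installed (the rest of model_config.py is not required):

    import torch

    # Prefer CUDA, then Apple's MPS backend, then CPU -- mirrors the
    # LLM_DEVICE assignment in configs/model_config.py.
    LLM_DEVICE = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )

    print(f"LLM will run on: {LLM_DEVICE}")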