From 75cf9f9b4edd8a1412c4af365320a964155c3daa Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Thu, 18 May 2023 23:19:23 +0800
Subject: [PATCH] ptuning-v2 configuration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/model_config.py | 2 --
 models/loader/loader.py | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/configs/model_config.py b/configs/model_config.py
index 3a9b469..c2b0db3 100644
--- a/configs/model_config.py
+++ b/configs/model_config.py
@@ -88,8 +88,6 @@ USE_PTUNING_V2 = False
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
-# MOSS load in 8bit
-LOAD_IN_8BIT = True
 
 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store")
 
diff --git a/models/loader/loader.py b/models/loader/loader.py
index c50f7d1..201c580 100644
--- a/models/loader/loader.py
+++ b/models/loader/loader.py
@@ -65,7 +65,7 @@ class LoaderCheckPoint:
         self.tokenizer = None
         self.model_dir = params.get('model_dir', '')
         self.lora_dir = params.get('lora_dir', '')
-        self.ptuning_dir = params.get('ptuning_dir', '')
+        self.ptuning_dir = params.get('ptuning_dir', 'ptuning-v2')
         self.cpu = params.get('cpu', False)
         self.gpu_memory = params.get('gpu_memory', None)
         self.cpu_memory = params.get('cpu_memory', None)
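
For context, the new 'ptuning-v2' default for ptuning_dir names the directory from which a P-Tuning v2 prefix-encoder checkpoint would be picked up when USE_PTUNING_V2 is enabled. The snippet below is a minimal sketch of how such a checkpoint is commonly applied to a ChatGLM-style model; the helper name load_ptuning_v2_prefix, the pytorch_model.bin layout, and the transformer.prefix_encoder attribute are assumptions for illustration, not code taken from models/loader/loader.py.

    # Minimal sketch, not the project's implementation: assumes a ChatGLM-style
    # model whose prefix encoder lives at model.transformer.prefix_encoder and a
    # checkpoint directory containing pytorch_model.bin with
    # "transformer.prefix_encoder.*" tensors, as produced by P-Tuning v2 training.
    import os
    import torch

    def load_ptuning_v2_prefix(model, ptuning_dir="ptuning-v2"):
        """Load P-Tuning v2 prefix-encoder weights from ptuning_dir, if present."""
        weights_path = os.path.join(ptuning_dir, "pytorch_model.bin")
        if not os.path.isfile(weights_path):
            # No checkpoint in the configured directory; return the model unchanged.
            return model

        prefix_state_dict = torch.load(weights_path, map_location="cpu")
        # Keep only the prefix-encoder tensors and strip their module prefix so the
        # keys match the parameter names inside model.transformer.prefix_encoder.
        prefix_only = {
            k[len("transformer.prefix_encoder."):]: v
            for k, v in prefix_state_dict.items()
            if k.startswith("transformer.prefix_encoder.")
        }
        model.transformer.prefix_encoder.load_state_dict(prefix_only)
        model.transformer.prefix_encoder.float()
        return model

With the patch applied, callers that do not pass ptuning_dir explicitly would look for such a checkpoint under ./ptuning-v2 instead of an empty path.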