From 5664d1ff62301ba9f6137751c48fa62ec08bb359 Mon Sep 17 00:00:00 2001
From: littlepanda0716
Date: Fri, 7 Apr 2023 10:45:44 +0800
Subject: [PATCH] add torch_gc to clear gpu cache in knowledge_based_chatglm.py

---
 chatglm_llm.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/chatglm_llm.py b/chatglm_llm.py
index 3d5f4e7..3833e4f 100644
--- a/chatglm_llm.py
+++ b/chatglm_llm.py
@@ -2,6 +2,19 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel
+import torch
+
+DEVICE = "cuda"
+DEVICE_ID = "0"
+CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+
+
+def torch_gc():
+    if torch.cuda.is_available():
+        with torch.cuda.device(CUDA_DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
 
 tokenizer = AutoTokenizer.from_pretrained(
     "THUDM/chatglm-6b",
@@ -15,6 +28,7 @@ model = (
     .cuda()
 )
 
+
 class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.1
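
Usage note (not part of the patch itself): torch_gc() is intended to be called after a
generation step so the CUDA caching allocator releases unused cached blocks and IPC
handles. A minimal sketch of one possible call site, assuming chatglm_llm.py is
importable and the ChatGLM wrapper is invoked directly; the prompt text and the
calling code below are assumptions for illustration only:

    from chatglm_llm import ChatGLM, torch_gc

    llm = ChatGLM()
    # generate a response through the langchain LLM interface
    answer = llm("Summarize the indexed documents.")
    # empty the CUDA cache and collect IPC handles between requests
    torch_gc()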