import torch

# Device family and ordinal used to build the default cleanup target.
DEVICE = "cuda"
DEVICE_ID = "0"
# "cuda:0" when an ordinal is configured, otherwise the bare device name.
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc(device=None):
    """Release PyTorch's cached CUDA memory on ``device``.

    Args:
        device: Target device, given as anything ``torch.device`` accepts
            (e.g. ``"cuda:1"`` or a ``torch.device``). Defaults to the
            module-level ``CUDA_DEVICE``.

    Returns:
        None. The call is a no-op when CUDA is unavailable or when the
        requested device is not a CUDA device.
    """
    if not torch.cuda.is_available():
        return

    target = torch.device(CUDA_DEVICE if device is None else device)
    if target.type != "cuda":
        # torch.cuda.device() raises ValueError for non-CUDA devices, and
        # there is no CUDA cache to free for them anyway.
        return

    with torch.cuda.device(target):
        # Hand unused cached blocks back to the driver, then reclaim memory
        # held by CUDA IPC handles that are no longer referenced.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()