add torch_gc to clear gpu cache in knowledge_based_chatglm.py

This commit is contained in:
littlepanda0716 2023-04-07 10:45:44 +08:00
parent 3cbc6aa77c
commit 5664d1ff62
1 changed file with 14 additions and 0 deletions


@@ -2,6 +2,19 @@ from langchain.llms.base import LLM
from typing import Optional, List
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoTokenizer, AutoModel
import torch
DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b",
@@ -15,6 +28,7 @@ model = (
    .cuda()
)
class ChatGLM(LLM):
    max_token: int = 10000
    temperature: float = 0.1
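
The second hunk (+1 line) suggests the new torch_gc() helper is invoked inside the ChatGLM class after generation finishes, so cached CUDA memory is released between calls. The exact call site is not visible in this excerpt; the following is a minimal sketch of that pattern, assuming a ChatGLM-style _call method that uses the module-level model and tokenizer shown above (the surrounding response handling is illustrative, not the file's exact code):

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Run one chat turn with the module-level ChatGLM model/tokenizer.
        response, _history = model.chat(
            tokenizer,
            prompt,
            history=[],
            max_length=self.max_token,
            temperature=self.temperature,
        )
        # Assumed placement: free cached GPU memory once generation is done,
        # which is what the torch_gc() added in this commit is for.
        torch_gc()
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        return response

Calling torch.cuda.empty_cache() and torch.cuda.ipc_collect() after each generation does not reduce the memory held by the model weights themselves; it only returns the allocator's unused cached blocks so other processes (or subsequent large generations) can use that GPU memory.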