add torch_gc to clear gpu cache in knowledge_based_chatglm.py
commit 5664d1ff62
parent 3cbc6aa77c
@@ -2,6 +2,19 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel
+import torch
+
+DEVICE = "cuda"
+DEVICE_ID = "0"
+CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+
+
+def torch_gc():
+    if torch.cuda.is_available():
+        with torch.cuda.device(CUDA_DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
 
 tokenizer = AutoTokenizer.from_pretrained(
     "THUDM/chatglm-6b",
@@ -15,6 +28,7 @@ model = (
     .cuda()
 )
 
+
 class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.1
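For context, the torch_gc() helper introduced here is meant to be called after each generation round so that cached CUDA allocations are released instead of accumulating across requests. The sketch below shows one way it could be wired in; it assumes the tokenizer, model, and torch_gc defined in the diff above, and the prompts and loop are illustrative only, not part of this commit. ChatGLM-6B's remote modeling code exposes a chat() method on the loaded model, which is used here.

# Hypothetical usage of the torch_gc() helper added in this commit.
# Assumes `tokenizer`, `model`, and `torch_gc` are defined as in the diff above.
history = []
for query in ["Hello", "What does torch_gc do?"]:
    # Generate a reply with ChatGLM-6B, then release cached GPU memory.
    response, history = model.chat(tokenizer, query, history=history)
    print(response)
    torch_gc()  # empty_cache() + ipc_collect() on the configured CUDA device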