add torch_gc to clear gpu cache in knowledge_based_chatglm.py
parent 3cbc6aa77c
commit 5664d1ff62
@@ -2,6 +2,19 @@ from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
 from transformers import AutoTokenizer, AutoModel
+import torch
+
+DEVICE = "cuda"
+DEVICE_ID = "0"
+CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+
+
+def torch_gc():
+    if torch.cuda.is_available():
+        with torch.cuda.device(CUDA_DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
 
 tokenizer = AutoTokenizer.from_pretrained(
     "THUDM/chatglm-6b",
@@ -15,6 +28,7 @@ model = (
     .cuda()
 )
 
 
 class ChatGLM(LLM):
     max_token: int = 10000
     temperature: float = 0.1
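For reference, a minimal sketch of how the new helper is typically used: call torch_gc() right after a generation so the CUDA allocator cache and stale CUDA IPC handles are released. The _call method and the model.chat(...) invocation below are assumptions based on the ChatGLM wrapper visible in this diff, not lines from this commit.

# Hypothetical call site (assumption, not shown in the visible hunks): inside
# the ChatGLM wrapper's _call, run a generation, then free cached GPU memory.
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
    # model.chat follows the THUDM/chatglm-6b API and returns (response, history)
    response, _ = model.chat(
        tokenizer,
        prompt,
        history=[],
        max_length=self.max_token,
        temperature=self.temperature,
    )
    torch_gc()  # empties the CUDA allocator cache and collects CUDA IPC handles
    if stop is not None:
        response = enforce_stop_tokens(response, stop)
    return response

Note that torch.cuda.empty_cache() only returns unreferenced cached blocks to the driver; memory still held by live tensors is unaffected, which is why the call is placed after each response has been produced.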