update chatglm_llm.py
parent 8db6cd5119
commit 57d4d0a29e
@@ -1,13 +1,9 @@
 import json
 from langchain.llms.base import LLM
-from typing import Optional, List
-from langchain.llms.utils import enforce_stop_tokens
+from typing import List, Dict, Optional
 from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
 from configs.model_config import *
-from langchain.callbacks.base import CallbackManager
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from typing import Dict, Tuple, Union, Optional
 from utils import torch_gc
 
 DEVICE_ = LLM_DEVICE
@@ -53,7 +49,6 @@ class ChatGLM(LLM):
     tokenizer: object = None
     model: object = None
     history_len: int = 10
-    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
     def __init__(self):
         super().__init__()
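
For reference, the affected regions of chatglm_llm.py after this commit read as follows. This is a snapshot assembled only from the hunks above; everything marked as omitted is simply not shown in this diff.

# Top of chatglm_llm.py (new lines 1-9): the duplicate typing import, enforce_stop_tokens,
# and the LangChain callback imports are no longer needed.
import json
from langchain.llms.base import LLM
from typing import List, Dict, Optional
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch
from configs.model_config import *
from utils import torch_gc

DEVICE_ = LLM_DEVICE

# Around new lines 49-54: the callback_manager class attribute has been removed.
class ChatGLM(LLM):
    # ... code between the class header and this hunk is unchanged and omitted here
    tokenizer: object = None
    model: object = None
    history_len: int = 10

    def __init__(self):
        super().__init__()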