Langchain-Chatchat-0.3.1/libs/chatchat-server/chatchat/server/agent/container.py

74 lines
2.3 KiB
Python
Raw Normal View History

2024-12-27 17:25:03 +08:00
import logging
from chatchat.server.utils import get_tool_config
from chatchat.utils import build_logger
# Module-level logger created by chatchat.utils.build_logger; ModelContainer
# uses it to report failures while loading optional models.
logger = build_logger()
class ModelContainer:
    """Holder for the models shared by the agent tools.

    On construction the optional vision (``vqa_processor``) and audio
    (``aqa_processor``) question-answering models are loaded when their tool
    configuration has a truthy ``"use"`` entry. Loading is best-effort: any
    failure is logged and the affected attributes keep whatever value they had
    at the point of failure (``None`` if nothing was loaded yet).

    Attributes:
        model: Initialised to ``None``; not assigned anywhere else in this
            class (presumably set by callers -- verify against usage).
        metadata: Initialised to ``None``; same caveat as ``model``.
        vision_tokenizer / vision_model: Tokenizer and causal LM loaded from
            the ``vqa_processor`` tool configuration, or ``None``.
        audio_tokenizer / audio_model: Tokenizer and causal LM loaded from
            the ``aqa_processor`` tool configuration, or ``None``.
    """

    def __init__(self):
        self.model = None
        self.metadata = None

        self.vision_model = None
        self.vision_tokenizer = None
        self.audio_tokenizer = None
        self.audio_model = None

        # The vision processor uses LlamaTokenizer, the audio one AutoTokenizer.
        self._load_processor(
            get_tool_config("vqa_processor"),
            tokenizer_attr="vision_tokenizer",
            model_attr="vision_model",
            use_llama_tokenizer=True,
        )
        # Fixed: this previously re-read "vqa_processor", so the audio model
        # was enabled by, and loaded from, the vision model's settings.
        self._load_processor(
            get_tool_config("aqa_processor"),
            tokenizer_attr="audio_tokenizer",
            model_attr="audio_model",
            use_llama_tokenizer=False,
        )

    def _load_processor(
        self, config, *, tokenizer_attr, model_attr, use_llama_tokenizer
    ):
        """Load one tokenizer/model pair onto ``self`` if ``config`` enables it.

        Args:
            config: Mapping for one tool; read keys are ``"use"``,
                ``"tokenizer_path"``, ``"model_path"`` and ``"device"``.
                A missing/empty mapping or a missing ``"use"`` key is treated
                as disabled.
            tokenizer_attr: Name of the attribute that receives the tokenizer.
            model_attr: Name of the attribute that receives the model.
            use_llama_tokenizer: ``True`` to load the tokenizer with
                ``LlamaTokenizer``, ``False`` to use ``AutoTokenizer``.

        Errors are never raised to the caller; they are logged with traceback.
        """
        if not (config and config.get("use")):
            return

        # Deliberately broad: model loading is optional, so an import error,
        # a bad path or a device failure must not prevent construction.
        try:
            # Imported lazily so torch/transformers are only required when a
            # processor is actually enabled.
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                LlamaTokenizer,
            )

            tokenizer_cls = LlamaTokenizer if use_llama_tokenizer else AutoTokenizer
            # NOTE(review): trust_remote_code=True lets the checkpoint's own
            # Python code run during loading; only point the paths at trusted
            # model repositories.
            # The tokenizer is assigned before the model is loaded, so a model
            # failure still leaves the tokenizer set (as in the original code).
            setattr(
                self,
                tokenizer_attr,
                tokenizer_cls.from_pretrained(
                    config["tokenizer_path"], trust_remote_code=True
                ),
            )
            model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=config["model_path"],
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            setattr(self, model_attr, model.to(config["device"]).eval())
        except Exception as e:
            logger.error(e, exc_info=True)
# Shared module-level instance. It is constructed when this module is imported,
# so ModelContainer.__init__ (including any enabled model loading) runs then.
container = ModelContainer()