fix: model_worker need global variable: args

This commit is contained in:
liunux4odoo 2023-07-29 23:22:25 +08:00
parent c880412300
commit 1a7271e966
1 changed files with 20 additions and 0 deletions

View File

@ -54,6 +54,7 @@ def create_model_worker_app(
):
from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
from fastchat.serve import model_worker
import argparse
from loguru import logger
logger.add(os.path.join(LOG_PATH, "model_worker.log"), level="INFO")
@ -79,7 +80,26 @@ def create_model_worker_app(
cpu_offloading,
gptq_config,
)
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.model_path=model_path
args.model_names=model_names
args.device=device
args.load_8bit=load_8bit
args.gptq_ckpt=gptq_ckpt
args.gptq_wbits=gptq_wbits
args.gpus=gpus
args.num_gpus=num_gpus
args.max_gpu_memory=max_gpu_memory
args.cpu_offloading=cpu_offloading
args.worker_address=worker_address
args.controller_address=controller_address
args.limit_model_concurrency=limit_model_concurrency
args.stream_interval=stream_interval
args.no_register=no_register
sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
sys.modules["fastchat.serve.model_worker"].logger = logger
return app