fix: model_worker need global variable: args
This commit is contained in:
parent
c880412300
commit
1a7271e966
|
|
@ -54,6 +54,7 @@ def create_model_worker_app(
|
|||
):
|
||||
from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
|
||||
from fastchat.serve import model_worker
|
||||
import argparse
|
||||
from loguru import logger
|
||||
logger.add(os.path.join(LOG_PATH, "model_worker.log"), level="INFO")
|
||||
|
||||
|
|
@ -79,7 +80,26 @@ def create_model_worker_app(
|
|||
cpu_offloading,
|
||||
gptq_config,
|
||||
)
|
||||
parser = argparse.ArgumentParser()
|
||||
args = parser.parse_args()
|
||||
args.model_path=model_path
|
||||
args.model_names=model_names
|
||||
args.device=device
|
||||
args.load_8bit=load_8bit
|
||||
args.gptq_ckpt=gptq_ckpt
|
||||
args.gptq_wbits=gptq_wbits
|
||||
args.gpus=gpus
|
||||
args.num_gpus=num_gpus
|
||||
args.max_gpu_memory=max_gpu_memory
|
||||
args.cpu_offloading=cpu_offloading
|
||||
args.worker_address=worker_address
|
||||
args.controller_address=controller_address
|
||||
args.limit_model_concurrency=limit_model_concurrency
|
||||
args.stream_interval=stream_interval
|
||||
args.no_register=no_register
|
||||
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
sys.modules["fastchat.serve.model_worker"].args = args
|
||||
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
||||
sys.modules["fastchat.serve.model_worker"].logger = logger
|
||||
return app
|
||||
|
|
|
|||
Loading…
Reference in New Issue