diff --git a/server/llm_api.py b/server/llm_api.py index a67076e..9cb4c62 100644 --- a/server/llm_api.py +++ b/server/llm_api.py @@ -54,6 +54,7 @@ def create_model_worker_app( ): from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id from fastchat.serve import model_worker + import argparse from loguru import logger logger.add(os.path.join(LOG_PATH, "model_worker.log"), level="INFO") @@ -79,7 +80,26 @@ def create_model_worker_app( cpu_offloading, gptq_config, ) + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.model_path=model_path + args.model_names=model_names + args.device=device + args.load_8bit=load_8bit + args.gptq_ckpt=gptq_ckpt + args.gptq_wbits=gptq_wbits + args.gpus=gpus + args.num_gpus=num_gpus + args.max_gpu_memory=max_gpu_memory + args.cpu_offloading=cpu_offloading + args.worker_address=worker_address + args.controller_address=controller_address + args.limit_model_concurrency=limit_model_concurrency + args.stream_interval=stream_interval + args.no_register=no_register + sys.modules["fastchat.serve.model_worker"].worker = worker + sys.modules["fastchat.serve.model_worker"].args = args sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config sys.modules["fastchat.serve.model_worker"].logger = logger return app