diff --git a/startup.py b/startup.py
index 88840aa..e610cda 100644
--- a/startup.py
+++ b/startup.py
@@ -104,7 +104,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
         from configs.model_config import VLLM_MODEL_DICT
         if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
             import fastchat.serve.vllm_worker
-            from fastchat.serve.vllm_worker import VLLMWorker, app
+            from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
             from vllm import AsyncLLMEngine
             from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
 
@@ -137,7 +137,10 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             args.quantization = None
             args.max_log_len = None
             args.tokenizer_revision = None
-
+
+            # new parameter required by vllm 0.2.2
+            args.max_paddings = 256
+
             if args.model_path:
                 args.model = args.model_path
             if args.num_gpus > 1:
@@ -161,7 +164,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
                 conv_template = args.conv_template,
             )
             sys.modules["fastchat.serve.vllm_worker"].engine = engine
-            # sys.modules["fastchat.serve.vllm_worker"].worker = worker
+            sys.modules["fastchat.serve.vllm_worker"].worker = worker
             sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
 
         else:
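
For context on the last hunk (not part of the patch): fastchat's vllm_worker module
defines `worker` as a module-level global that its FastAPI route handlers resolve at
call time, so a VLLMWorker constructed outside that module has to be injected back in
via sys.modules; the previously commented-out line is what makes the handlers see the
worker at all. Likewise, args.max_paddings = 256 mirrors what appears to be vllm
0.2.2's own default, so the argparse Namespace handed to AsyncEngineArgs carries every
attribute the newer engine expects. Below is a minimal, self-contained sketch of the
injection pattern; the demo_worker module and DemoWorker class are hypothetical
stand-ins, not fastchat APIs:

    import sys
    import types

    # Stand-in for fastchat.serve.vllm_worker: a module whose handler reads
    # the module-global `worker` at call time, as fastchat's route handlers do.
    src = (
        "worker = None\n"
        "def model_names():\n"
        "    return worker.model_names\n"
    )
    mod = types.ModuleType("demo_worker")
    exec(src, mod.__dict__)
    sys.modules["demo_worker"] = mod

    class DemoWorker:
        """Hypothetical stand-in for VLLMWorker."""
        model_names = ["demo-model"]

    # The injection performed by the patch, applied to the stand-in module;
    # without it, model_names() would raise AttributeError on None.
    sys.modules["demo_worker"].worker = DemoWorker()
    print(sys.modules["demo_worker"].model_names())  # ['demo-model']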