diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index f280680..2b39bd6 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -7,7 +7,7 @@
 import json
 import sys
 from pydantic import BaseModel
 import fastchat
-import threading
+import asyncio
 from typing import Dict, List
 
@@ -40,6 +40,7 @@ class ApiModelWorker(BaseModelWorker):
                          worker_addr=worker_addr,
                          **kwargs)
         self.context_len = context_len
+        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
         self.init_heart_beat()
 
     def count_token(self, params):
diff --git a/startup.py b/startup.py
index 878e7ec..a9504ee 100644
--- a/startup.py
+++ b/startup.py
@@ -78,9 +78,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     """
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.model_worker import worker_id, logger
     import argparse
-    logger.setLevel(log_level)
 
     parser = argparse.ArgumentParser()
     args = parser.parse_args([])
@@ -90,19 +88,22 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
 
     # Online model API
    if worker_class := kwargs.get("worker_class"):
-        from fastchat.serve.model_worker import app
+        from fastchat.serve.base_model_worker import app
+
         worker = worker_class(model_names=args.model_names,
                               controller_addr=args.controller_address,
                               worker_addr=args.worker_address)
-        sys.modules["fastchat.serve.model_worker"].worker = worker
+        # sys.modules["fastchat.serve.base_model_worker"].worker = worker
+        sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
     # Local model
     else:
         from configs.model_config import VLLM_MODEL_DICT
         if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
             import fastchat.serve.vllm_worker
-            from fastchat.serve.vllm_worker import VLLMWorker,app
+            from fastchat.serve.vllm_worker import VLLMWorker, app
             from vllm import AsyncLLMEngine
             from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
+
             args.tokenizer = args.model_path # set here if the tokenizer differs from model_path
             args.tokenizer_mode = 'auto'
             args.trust_remote_code= True
@@ -155,10 +156,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
                         conv_template = args.conv_template,
                         )
             sys.modules["fastchat.serve.vllm_worker"].engine = engine
-            sys.modules["fastchat.serve.vllm_worker"].worker = worker
+            # sys.modules["fastchat.serve.vllm_worker"].worker = worker
+            sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
 
         else:
-            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
+            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
+
             args.gpus = "0" # GPU index; with multiple GPUs this can be set to "0,1,2,3"
             args.max_gpu_memory = "22GiB"
             args.num_gpus = 1 # model workers are split via model parallelism; set the number of GPUs here
@@ -221,8 +224,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             )
             sys.modules["fastchat.serve.model_worker"].args = args
             sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
-
-            sys.modules["fastchat.serve.model_worker"].worker = worker
+            # sys.modules["fastchat.serve.model_worker"].worker = worker
+            sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
 
     MakeFastAPIOffline(app)
     app.title = f"FastChat LLM Server ({args.model_names[0]})"
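
Note on the ApiModelWorker change: the new asyncio.Semaphore is sized by limit_worker_concurrency, an attribute set by FastChat's BaseModelWorker. The diff does not show where the semaphore is acquired, so the snippet below is only a minimal sketch of how such a per-worker semaphore bounds concurrent requests; the class and method names are hypothetical and are not part of this repository or FastChat.

    import asyncio

    class DemoWorker:
        # Hypothetical stand-in for a worker; only the semaphore pattern matters here.
        def __init__(self, limit_worker_concurrency: int = 5):
            self.limit_worker_concurrency = limit_worker_concurrency
            # Mirrors the line added to ApiModelWorker.__init__ in the diff above.
            self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)

        async def handle_generate(self, prompt: str) -> str:
            # At most `limit_worker_concurrency` callers run inside this block at
            # once; the rest wait here until a slot is released on exit.
            async with self.semaphore:
                await asyncio.sleep(0.1)  # stand-in for the real upstream API call
                return f"echo: {prompt}"

    async def main():
        worker = DemoWorker(limit_worker_concurrency=2)
        results = await asyncio.gather(*(worker.handle_generate(f"q{i}") for i in range(5)))
        print(results)

    if __name__ == "__main__":
        asyncio.run(main())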