fix: correct model_worker's logger and semaphore
This commit is contained in:
parent 1b3bd4442c
commit 0b25d7b079

@@ -7,7 +7,7 @@ import json
 import sys
 from pydantic import BaseModel
 import fastchat
-import threading
+import asyncio
 from typing import Dict, List

@@ -40,6 +40,7 @@ class ApiModelWorker(BaseModelWorker):
                          worker_addr=worker_addr,
                          **kwargs)
         self.context_len = context_len
+        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
         self.init_heart_beat()

     def count_token(self, params):
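
Note on the semaphore half of the fix: the worker handles requests on an asyncio event loop, so the concurrency guard has to be an asyncio.Semaphore rather than a threading one (a threading.Semaphore would block the loop while waiting, hence the import swap above). Creating it eagerly in __init__ means every handler finds a semaphore already bound to limit_worker_concurrency. A minimal, self-contained sketch of the pattern — TinyWorker is hypothetical, not code from this repo:

import asyncio

class TinyWorker:
    def __init__(self, limit_worker_concurrency: int = 2):
        # Same pattern as the diff: create the semaphore once, in __init__.
        self.semaphore = asyncio.Semaphore(limit_worker_concurrency)

    async def handle(self, request_id: int) -> str:
        # At most limit_worker_concurrency requests enter this block at a time.
        async with self.semaphore:
            await asyncio.sleep(0.1)  # stand-in for a slow model API call
            return f"done {request_id}"

async def main():
    worker = TinyWorker(limit_worker_concurrency=2)
    results = await asyncio.gather(*(worker.handle(i) for i in range(5)))
    print(results)  # five results, produced at most two at a time

asyncio.run(main())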

startup.py (21 changed lines)

@@ -78,9 +78,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     """
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.model_worker import worker_id, logger
     import argparse
-    logger.setLevel(log_level)

     parser = argparse.ArgumentParser()
     args = parser.parse_args([])

@@ -90,19 +88,22 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     # online model API
     if worker_class := kwargs.get("worker_class"):
-        from fastchat.serve.model_worker import app
+        from fastchat.serve.base_model_worker import app
+
         worker = worker_class(model_names=args.model_names,
                               controller_addr=args.controller_address,
                               worker_addr=args.worker_address)
-        sys.modules["fastchat.serve.model_worker"].worker = worker
+        # sys.modules["fastchat.serve.base_model_worker"].worker = worker
+        sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
     # local model
     else:
         from configs.model_config import VLLM_MODEL_DICT
         if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
             import fastchat.serve.vllm_worker
-            from fastchat.serve.vllm_worker import VLLMWorker,app
+            from fastchat.serve.vllm_worker import VLLMWorker, app
             from vllm import AsyncLLMEngine
             from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs

             args.tokenizer = args.model_path  # if the tokenizer differs from model_path, set it here
             args.tokenizer_mode = 'auto'
             args.trust_remote_code= True
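
A note on the sys.modules[...] assignments in this hunk and the ones below: fastchat's request handlers read names such as worker and logger as module-level globals, so startup.py reaches into the already-imported module object and sets those attributes from outside. A self-contained sketch of the idiom — fake_worker_module and PatchedWorker are made up for illustration:

import sys
import types

# Stand-in for a module like fastchat.serve.model_worker (made up here).
mod = types.ModuleType("fake_worker_module")
sys.modules["fake_worker_module"] = mod

class PatchedWorker:
    def __repr__(self):
        return "<patched worker>"

# The same move as startup.py: inject a module-level global from outside.
sys.modules["fake_worker_module"].worker = PatchedWorker()

import fake_worker_module  # resolved straight from sys.modules
print(fake_worker_module.worker)  # -> <patched worker>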
@@ -155,10 +156,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
                 conv_template = args.conv_template,
             )
             sys.modules["fastchat.serve.vllm_worker"].engine = engine
-            sys.modules["fastchat.serve.vllm_worker"].worker = worker
+            # sys.modules["fastchat.serve.vllm_worker"].worker = worker
+            sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)

         else:
-            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
+            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
             args.gpus = "0"  # GPU id; with multiple GPUs this can be set to "0,1,2,3"
             args.max_gpu_memory = "22GiB"
             args.num_gpus = 1  # model parallelism splits the model across cards; set this to the number of GPUs
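
The recurring logger.setLevel(log_level) additions are the logger half of the fix: instead of configuring a single logger up front (removed in the @@ -78,9 hunk), each branch now sets the level on the logger belonging to the module that will actually serve requests (base_model_worker, vllm_worker, or model_worker). Python loggers are distinct objects per name, so setting one does not touch the others — a quick sketch with made-up logger names:

import logging

# Made-up names; fastchat's worker modules each hold their own logger.
model_logger = logging.getLogger("demo.model_worker")
vllm_logger = logging.getLogger("demo.vllm_worker")

model_logger.setLevel(logging.DEBUG)
print(model_logger.level, vllm_logger.level)  # 10 0 -> levels are per-logger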
@@ -221,8 +224,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             )
             sys.modules["fastchat.serve.model_worker"].args = args
             sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
-
-            sys.modules["fastchat.serve.model_worker"].worker = worker
+            # sys.modules["fastchat.serve.model_worker"].worker = worker
+            sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)

     MakeFastAPIOffline(app)
     app.title = f"FastChat LLM Server ({args.model_names[0]})"