fix: correct model_worker's logger and semaphor
This commit is contained in:
parent
1b3bd4442c
commit
0b25d7b079
|
|
@ -7,7 +7,7 @@ import json
|
||||||
import sys
|
import sys
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import fastchat
|
import fastchat
|
||||||
import threading
|
import asyncio
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -40,6 +40,7 @@ class ApiModelWorker(BaseModelWorker):
|
||||||
worker_addr=worker_addr,
|
worker_addr=worker_addr,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
self.context_len = context_len
|
self.context_len = context_len
|
||||||
|
self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
|
||||||
self.init_heart_beat()
|
self.init_heart_beat()
|
||||||
|
|
||||||
def count_token(self, params):
|
def count_token(self, params):
|
||||||
|
|
|
||||||
21
startup.py
21
startup.py
|
|
@ -78,9 +78,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||||
"""
|
"""
|
||||||
import fastchat.constants
|
import fastchat.constants
|
||||||
fastchat.constants.LOGDIR = LOG_PATH
|
fastchat.constants.LOGDIR = LOG_PATH
|
||||||
from fastchat.serve.model_worker import worker_id, logger
|
|
||||||
import argparse
|
import argparse
|
||||||
logger.setLevel(log_level)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
args = parser.parse_args([])
|
args = parser.parse_args([])
|
||||||
|
|
@ -90,19 +88,22 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||||
|
|
||||||
# 在线模型API
|
# 在线模型API
|
||||||
if worker_class := kwargs.get("worker_class"):
|
if worker_class := kwargs.get("worker_class"):
|
||||||
from fastchat.serve.model_worker import app
|
from fastchat.serve.base_model_worker import app
|
||||||
|
|
||||||
worker = worker_class(model_names=args.model_names,
|
worker = worker_class(model_names=args.model_names,
|
||||||
controller_addr=args.controller_address,
|
controller_addr=args.controller_address,
|
||||||
worker_addr=args.worker_address)
|
worker_addr=args.worker_address)
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
# sys.modules["fastchat.serve.base_model_worker"].worker = worker
|
||||||
|
sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
|
||||||
# 本地模型
|
# 本地模型
|
||||||
else:
|
else:
|
||||||
from configs.model_config import VLLM_MODEL_DICT
|
from configs.model_config import VLLM_MODEL_DICT
|
||||||
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
|
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
|
||||||
import fastchat.serve.vllm_worker
|
import fastchat.serve.vllm_worker
|
||||||
from fastchat.serve.vllm_worker import VLLMWorker,app
|
from fastchat.serve.vllm_worker import VLLMWorker, app
|
||||||
from vllm import AsyncLLMEngine
|
from vllm import AsyncLLMEngine
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
|
||||||
|
|
||||||
args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加
|
args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加
|
||||||
args.tokenizer_mode = 'auto'
|
args.tokenizer_mode = 'auto'
|
||||||
args.trust_remote_code= True
|
args.trust_remote_code= True
|
||||||
|
|
@ -155,10 +156,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||||
conv_template = args.conv_template,
|
conv_template = args.conv_template,
|
||||||
)
|
)
|
||||||
sys.modules["fastchat.serve.vllm_worker"].engine = engine
|
sys.modules["fastchat.serve.vllm_worker"].engine = engine
|
||||||
sys.modules["fastchat.serve.vllm_worker"].worker = worker
|
# sys.modules["fastchat.serve.vllm_worker"].worker = worker
|
||||||
|
sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
|
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
|
||||||
|
|
||||||
args.gpus = "0" # GPU的编号,如果有多个GPU,可以设置为"0,1,2,3"
|
args.gpus = "0" # GPU的编号,如果有多个GPU,可以设置为"0,1,2,3"
|
||||||
args.max_gpu_memory = "22GiB"
|
args.max_gpu_memory = "22GiB"
|
||||||
args.num_gpus = 1 # model worker的切分是model并行,这里填写显卡的数量
|
args.num_gpus = 1 # model worker的切分是model并行,这里填写显卡的数量
|
||||||
|
|
@ -221,8 +224,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||||
)
|
)
|
||||||
sys.modules["fastchat.serve.model_worker"].args = args
|
sys.modules["fastchat.serve.model_worker"].args = args
|
||||||
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
||||||
|
# sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
|
||||||
|
|
||||||
MakeFastAPIOffline(app)
|
MakeFastAPIOffline(app)
|
||||||
app.title = f"FastChat LLM Server ({args.model_names[0]})"
|
app.title = f"FastChat LLM Server ({args.model_names[0]})"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue