fix: correct model_worker's logger and semaphor

This commit is contained in:
liunux4odoo 2023-10-20 11:50:50 +08:00
parent 1b3bd4442c
commit 0b25d7b079
2 changed files with 14 additions and 10 deletions

View File

@ -7,7 +7,7 @@ import json
import sys import sys
from pydantic import BaseModel from pydantic import BaseModel
import fastchat import fastchat
import threading import asyncio
from typing import Dict, List from typing import Dict, List
@ -40,6 +40,7 @@ class ApiModelWorker(BaseModelWorker):
worker_addr=worker_addr, worker_addr=worker_addr,
**kwargs) **kwargs)
self.context_len = context_len self.context_len = context_len
self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
self.init_heart_beat() self.init_heart_beat()
def count_token(self, params): def count_token(self, params):

View File

@ -78,9 +78,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
""" """
import fastchat.constants import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import worker_id, logger
import argparse import argparse
logger.setLevel(log_level)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
args = parser.parse_args([]) args = parser.parse_args([])
@ -90,19 +88,22 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
# 在线模型API # 在线模型API
if worker_class := kwargs.get("worker_class"): if worker_class := kwargs.get("worker_class"):
from fastchat.serve.model_worker import app from fastchat.serve.base_model_worker import app
worker = worker_class(model_names=args.model_names, worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address, controller_addr=args.controller_address,
worker_addr=args.worker_address) worker_addr=args.worker_address)
sys.modules["fastchat.serve.model_worker"].worker = worker # sys.modules["fastchat.serve.base_model_worker"].worker = worker
sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
# 本地模型 # 本地模型
else: else:
from configs.model_config import VLLM_MODEL_DICT from configs.model_config import VLLM_MODEL_DICT
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm": if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
import fastchat.serve.vllm_worker import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker,app from fastchat.serve.vllm_worker import VLLMWorker, app
from vllm import AsyncLLMEngine from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加 args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加
args.tokenizer_mode = 'auto' args.tokenizer_mode = 'auto'
args.trust_remote_code= True args.trust_remote_code= True
@ -155,10 +156,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
conv_template = args.conv_template, conv_template = args.conv_template,
) )
sys.modules["fastchat.serve.vllm_worker"].engine = engine sys.modules["fastchat.serve.vllm_worker"].engine = engine
sys.modules["fastchat.serve.vllm_worker"].worker = worker # sys.modules["fastchat.serve.vllm_worker"].worker = worker
sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
else: else:
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
args.gpus = "0" # GPU的编号,如果有多个GPU可以设置为"0,1,2,3" args.gpus = "0" # GPU的编号,如果有多个GPU可以设置为"0,1,2,3"
args.max_gpu_memory = "22GiB" args.max_gpu_memory = "22GiB"
args.num_gpus = 1 # model worker的切分是model并行这里填写显卡的数量 args.num_gpus = 1 # model worker的切分是model并行这里填写显卡的数量
@ -221,8 +224,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
) )
sys.modules["fastchat.serve.model_worker"].args = args sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
# sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].worker = worker sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
MakeFastAPIOffline(app) MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({args.model_names[0]})" app.title = f"FastChat LLM Server ({args.model_names[0]})"