startup.py增加参数-q | quiet,可以过滤fastchat的controller/model_worker不必要的日志输出 (#1333)

* startup.py增加参数`-q | quiet`,可以过滤fastchat的controller/model_worker不必要的日志输出
This commit is contained in:
liunux4odoo 2023-08-31 22:55:07 +08:00 committed by GitHub
parent b1201a5f23
commit 72b9da2649
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 63 additions and 22 deletions

View File

@ -29,10 +29,12 @@ from configs import VERSION
def create_controller_app(
dispatch_method: str,
log_level: str = "INFO",
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.controller import app, Controller
from fastchat.serve.controller import app, Controller, logger
logger.setLevel(log_level)
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
@ -42,13 +44,14 @@ def create_controller_app(
return app
def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
import argparse
import threading
import fastchat.serve.model_worker
logger.setLevel(log_level)
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
@ -137,10 +140,14 @@ def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]
def create_openai_api_app(
controller_address: str,
api_keys: List = [],
log_level: str = "INFO",
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
from fastchat.utils import build_logger
logger = build_logger("openai_api", "openai_api.log")
logger.setLevel(log_level)
app.add_middleware(
CORSMiddleware,
@ -150,6 +157,7 @@ def create_openai_api_app(
allow_headers=["*"],
)
sys.modules["fastchat.serve.openai_api_server"].logger = logger
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
@ -159,6 +167,9 @@ def create_openai_api_app(
def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
if q is None or not isinstance(run_seq, int):
return
if run_seq == 1:
@app.on_event("startup")
async def on_startup():
@ -177,15 +188,22 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
q.put(run_seq)
def run_controller(q: Queue, run_seq: int = 1):
def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
import uvicorn
import sys
app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
app = create_controller_app(
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
log_level=log_level,
)
_set_app_seq(app, q, run_seq)
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
uvicorn.run(app, host=host, port=port)
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_model_worker(
@ -193,8 +211,10 @@ def run_model_worker(
controller_address: str = "",
q: Queue = None,
run_seq: int = 2,
log_level: str ="INFO",
):
import uvicorn
import sys
kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host")
@ -205,21 +225,28 @@ def run_model_worker(
kwargs["controller_address"] = controller_address or fschat_controller_address()
kwargs["worker_address"] = fschat_model_worker_address()
app = create_model_worker_app(**kwargs)
app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_seq(app, q, run_seq)
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port)
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_openai_api(q: Queue, run_seq: int = 3):
def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
import uvicorn
import sys
controller_addr = fschat_controller_address()
app = create_openai_api_app(controller_addr) # todo: not support keys yet.
app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet.
_set_app_seq(app, q, run_seq)
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port)
@ -239,13 +266,15 @@ def run_api_server(q: Queue, run_seq: int = 4):
def run_webui(q: Queue, run_seq: int = 5):
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
while True:
no = q.get()
if no != run_seq - 1:
q.put(no)
else:
break
q.put(run_seq)
if q is not None and isinstance(run_seq, int):
while True:
no = q.get()
if no != run_seq - 1:
q.put(no)
else:
break
q.put(run_seq)
p = subprocess.Popen(["streamlit", "run", "webui.py",
"--server.address", host,
"--server.port", str(port)])
@ -315,11 +344,18 @@ def parse_args() -> argparse.ArgumentParser:
help="run webui.py server",
dest="webui",
)
parser.add_argument(
"-q",
"--quiet",
action="store_true",
help="减少fastchat服务log信息",
dest="quiet",
)
args = parser.parse_args()
return args, parser
def dump_server_info(after_start=False):
def dump_server_info(after_start=False, args=None):
import platform
import langchain
import fastchat
@ -355,6 +391,7 @@ if __name__ == "__main__":
mp.set_start_method("spawn")
queue = Queue()
args, parser = parse_args()
if args.all_webui:
args.openai_api = True
args.model_worker = True
@ -373,19 +410,23 @@ if __name__ == "__main__":
args.api = False
args.webui = False
dump_server_info()
dump_server_info(args=args)
if len(sys.argv) > 1:
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
processes = {}
if args.quiet:
log_level = "ERROR"
else:
log_level = "INFO"
if args.openai_api:
process = Process(
target=run_controller,
name=f"controller({os.getpid()})",
args=(queue, len(processes) + 1),
args=(queue, len(processes) + 1, log_level),
daemon=True,
)
process.start()
@ -406,7 +447,7 @@ if __name__ == "__main__":
process = Process(
target=run_model_worker,
name=f"model_worker({os.getpid()})",
args=(args.model_name, args.controller_address, queue, len(processes) + 1),
args=(args.model_name, args.controller_address, queue, len(processes) + 1, log_level),
daemon=True,
)
process.start()
@ -441,7 +482,7 @@ if __name__ == "__main__":
no = queue.get()
if no == len(processes):
time.sleep(0.5)
dump_server_info(True)
dump_server_info(after_start=True, args=args)
break
else:
queue.put(no)