diff --git a/startup.py b/startup.py
index 07630d9..5ef05ce 100644
--- a/startup.py
+++ b/startup.py
@@ -29,10 +29,12 @@ from configs import VERSION
 
 def create_controller_app(
         dispatch_method: str,
+        log_level: str = "INFO",
 ) -> FastAPI:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.controller import app, Controller
+    from fastchat.serve.controller import app, Controller, logger
+    logger.setLevel(log_level)
 
     controller = Controller(dispatch_method)
     sys.modules["fastchat.serve.controller"].controller = controller
@@ -42,13 +44,14 @@ def create_controller_app(
     return app
 
 
-def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
+def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
+    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
     import argparse
     import threading
     import fastchat.serve.model_worker
+    logger.setLevel(log_level)
 
     # workaround to make program exit with Ctrl+c
     # it should be deleted after pr is merged by fastchat
@@ -137,10 +140,14 @@ def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]
 def create_openai_api_app(
         controller_address: str,
         api_keys: List = [],
+        log_level: str = "INFO",
 ) -> FastAPI:
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
+    from fastchat.utils import build_logger
+    logger = build_logger("openai_api", "openai_api.log")
+    logger.setLevel(log_level)
 
     app.add_middleware(
         CORSMiddleware,
@@ -150,6 +157,7 @@
         allow_headers=["*"],
     )
 
+    sys.modules["fastchat.serve.openai_api_server"].logger = logger
     app_settings.controller_address = controller_address
     app_settings.api_keys = api_keys
 
@@ -159,6 +167,9 @@
 
 
 def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
+    if q is None or not isinstance(run_seq, int):
+        return
+
     if run_seq == 1:
         @app.on_event("startup")
         async def on_startup():
@@ -177,15 +188,22 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
             q.put(run_seq)
 
 
-def run_controller(q: Queue, run_seq: int = 1):
+def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
     import uvicorn
+    import sys
 
-    app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
+    app = create_controller_app(
+        dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
+        log_level=log_level,
+    )
     _set_app_seq(app, q, run_seq)
 
     host = FSCHAT_CONTROLLER["host"]
     port = FSCHAT_CONTROLLER["port"]
 
-    uvicorn.run(app, host=host, port=port)
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
+    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
 
 
@@ -193,8 +211,10 @@ def run_model_worker(
         model_name: str = LLM_MODEL,
         controller_address: str = "",
         q: Queue = None,
         run_seq: int = 2,
+        log_level: str = "INFO",
 ):
     import uvicorn
+    import sys
     kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
@@ -205,21 +225,28 @@ def run_model_worker(
     kwargs["controller_address"] = controller_address or fschat_controller_address()
     kwargs["worker_address"] = fschat_model_worker_address()
 
-    app = create_model_worker_app(**kwargs)
+    app = create_model_worker_app(log_level=log_level, **kwargs)
     _set_app_seq(app, q, run_seq)
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
 
-    uvicorn.run(app, host=host, port=port)
+    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
 
 
-def run_openai_api(q: Queue, run_seq: int = 3):
+def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
     import uvicorn
+    import sys
 
     controller_addr = fschat_controller_address()
-    app = create_openai_api_app(controller_addr)  # todo: not support keys yet.
+    app = create_openai_api_app(controller_addr, log_level=log_level)  # TODO: API keys not supported yet.
     _set_app_seq(app, q, run_seq)
 
     host = FSCHAT_OPENAI_API["host"]
     port = FSCHAT_OPENAI_API["port"]
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
     uvicorn.run(app, host=host, port=port)
 
 
@@ -239,13 +266,15 @@ def run_api_server(q: Queue, run_seq: int = 4):
 
 def run_webui(q: Queue, run_seq: int = 5):
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]
-    while True:
-        no = q.get()
-        if no != run_seq - 1:
-            q.put(no)
-        else:
-            break
-    q.put(run_seq)
+
+    if q is not None and isinstance(run_seq, int):
+        while True:
+            no = q.get()
+            if no != run_seq - 1:
+                q.put(no)
+            else:
+                break
+        q.put(run_seq)
 
     p = subprocess.Popen(["streamlit", "run", "webui.py", "--server.address", host, "--server.port", str(port)])
@@ -315,11 +344,18 @@
         help="run webui.py server",
         dest="webui",
     )
+    parser.add_argument(
+        "-q",
+        "--quiet",
+        action="store_true",
+        help="reduce fastchat service log output",
+        dest="quiet",
+    )
     args = parser.parse_args()
     return args, parser
 
 
-def dump_server_info(after_start=False):
+def dump_server_info(after_start=False, args=None):
     import platform
     import langchain
     import fastchat
@@ -355,6 +391,7 @@ if __name__ == "__main__":
     mp.set_start_method("spawn")
     queue = Queue()
     args, parser = parse_args()
+
     if args.all_webui:
         args.openai_api = True
         args.model_worker = True
@@ -373,19 +410,23 @@
         args.api = False
         args.webui = False
 
-    dump_server_info()
+    dump_server_info(args=args)
 
     if len(sys.argv) > 1:
         logger.info(f"Starting services:")
         logger.info(f"To view llm_api logs, go to {LOG_PATH}")
 
     processes = {}
+    if args.quiet:
+        log_level = "ERROR"
+    else:
+        log_level = "INFO"
 
     if args.openai_api:
         process = Process(
             target=run_controller,
             name=f"controller({os.getpid()})",
-            args=(queue, len(processes) + 1),
+            args=(queue, len(processes) + 1, log_level),
             daemon=True,
         )
         process.start()
@@ -406,7 +447,7 @@
         process = Process(
             target=run_model_worker,
             name=f"model_worker({os.getpid()})",
-            args=(args.model_name, args.controller_address, queue, len(processes) + 1),
+            args=(args.model_name, args.controller_address, queue, len(processes) + 1, log_level),
             daemon=True,
         )
         process.start()
@@ -441,7 +482,7 @@
             no = queue.get()
             if no == len(processes):
                 time.sleep(0.5)
-                dump_server_info(True)
+                dump_server_info(after_start=True, args=args)
                 break
             else:
                 queue.put(no)