97 lines
2.7 KiB
Python
97 lines
2.7 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
|
|
|
|
from loguru import logger
|
|
|
|
import inspect
|
|
import os
|
|
import sys
|
|
|
|
|
|
def get_caller_name(depth=0):
    """
    Look up the module name of a frame on the caller's stack.

    Args:
        depth (int): number of additional frames to climb above the direct
            caller; 0 selects the direct caller of this function.
            Default value: 0.

    Returns:
        str: the ``__name__`` global of the selected frame's module.
    """
    # Following f_back links directly is cheaper than materialising
    # inspect.stack(). The first hop leaves this function's own frame; the
    # remaining hops climb `depth` further levels (negative depth climbs none).
    hops = max(depth, 0) + 1
    frame = inspect.currentframe()
    for _ in range(hops):
        frame = frame.f_back
    return frame.f_globals["__name__"]
|
|
|
|
|
|
class StreamToLoguru:
    """
    stream object that redirects writes to a logger instance.

    Text written from a module whose top-level package name is listed in
    ``caller_names`` is logged through loguru, one record per line; every
    other write is passed through unchanged to ``sys.__stdout__``.
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): top-level package names whose output is
                redirected to loguru. Default value: ("apex", "pycocotools").
        """
        self.level = level
        # Kept for backward compatibility; not read by this class.
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        """Route ``buf`` either to loguru or to the real stdout.

        Args:
            buf (str): text handed over by ``print`` / ``sys.stdout.write``.

        Returns:
            int: ``len(buf)``, as the text-stream ``write`` contract expects.
        """
        full_name = get_caller_name(depth=1)
        # Keep only the top-level package, e.g. "pycocotools.coco" -> "pycocotools".
        module_name = full_name.partition(".")[0]
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
                # use caller level log: depth=2 makes loguru attribute the record
                # to a frame above this write() call.
                # NOTE(review): confirm depth=2 is the intended frame.
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)
        return len(buf)

    def flush(self):
        """Flush the passthrough stream.

        Redirected lines go straight to loguru, but passthrough text is
        buffered inside ``sys.__stdout__``. Without flushing it,
        ``print(..., flush=True)`` would not flush anything.
        """
        if sys.__stdout__ is not None:
            sys.__stdout__.flush()

    def isatty(self):
        """Report whether the real stdout is a terminal.

        Some libraries query ``sys.stdout.isatty()`` and would otherwise
        fail with AttributeError once stdout is replaced by this object.
        """
        return sys.__stdout__ is not None and sys.__stdout__.isatty()

    def fileno(self):
        """Return the file descriptor of the real stdout.

        Needed by tools (e.g. debuggers) that require a real descriptor for
        ``sys.stdout``.
        """
        return sys.__stdout__.fileno()
|
|
|
|
|
|
def redirect_sys_output(log_level="INFO"):
    """
    Replace both ``sys.stdout`` and ``sys.stderr`` with one shared
    StreamToLoguru instance.

    Args:
        log_level (str): loguru level name used for redirected lines.
            Default value: "INFO".
    """
    shared_stream = StreamToLoguru(log_level)
    # Tuple assignment binds left to right: stderr first, then stdout.
    sys.stderr, sys.stdout = shared_stream, shared_stream
|
|
|
|
|
|
def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """setup logger for training and testing.

    Removes every previously registered loguru sink. On rank 0 it then adds
    a colored stderr sink at INFO level and a file sink at
    ``os.path.join(save_dir, filename)``. Finally, on every rank, it
    redirects ``sys.stdout``/``sys.stderr`` through loguru at INFO level.

    Args:
        save_dir (str): directory in which the log file is written.
        distributed_rank (int): device rank in a multi-gpu environment.
            Only rank 0 gets sinks; other ranks are left with none.
        filename (str): log file name. Default: "log.txt".
        mode (str): log file write mode. "o" overwrites: rank 0 deletes any
            existing file first. Any other value (default "a") leaves an
            existing file in place.

    Returns:
        The configured loguru ``logger``.
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
    )

    logger.remove()
    save_file = os.path.join(save_dir, filename)

    # only keep logger in rank0 process
    if distributed_rank == 0:
        if mode == "o":
            # Only rank 0 owns the log file, so only it may delete it. If every
            # rank deleted it, a late rank could unlink the file after rank 0
            # had already opened it for writing, losing the log.
            try:
                os.remove(save_file)
            except FileNotFoundError:
                pass
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # redirect stdout/stderr to loguru
    redirect_sys_output("INFO")
    return logger
|