File size: 4,194 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib.util
import logging
import os
from contextlib import contextmanager
from types import MethodType
from typing import Optional
from modelscope.utils.logger import get_logger as get_ms_logger
# Avoid circular reference
def _is_local_master():
local_rank = int(os.getenv('LOCAL_RANK', -1))
return local_rank in {-1, 0}
# Names of the loggers that get_logger() has already configured; checked
# there so handlers are only attached on the first call.
init_loggers = dict()

# De-duplication keys already emitted through info_once / warning_once.
info_set, warning_set = set(), set()

# Layout applied to every handler created in this module.
# The earlier, more verbose layout was:
#   '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logger_format = logging.Formatter(fmt='[%(levelname)s:%(name)s] %(message)s')
def info_once(self, msg, *args, **kwargs):
    """Log ``msg`` at INFO level the first time its key is seen.

    Bound onto the logger by ``get_logger`` so it can be called as
    ``logger.info_once(...)``.

    Args:
        msg: Message (or %-style format string) to log.
        *args: Format arguments, forwarded to ``self.info``.
        **kwargs: Extra keyword arguments forwarded to ``self.info``
            (e.g. ``exc_info``, ``extra``). The special key ``hash_id``
            is consumed here: when given and truthy it is used as the
            de-duplication key instead of ``msg``.

    Note:
        The key is ``msg`` itself, not the formatted text, so two calls with
        the same format string but different ``args`` count as duplicates
        unless distinct ``hash_id`` values are supplied.
    """
    # Pop rather than get: Logger.info() does not accept a `hash_id` keyword.
    hash_id = kwargs.pop('hash_id', None) or msg
    if hash_id in info_set:
        return
    info_set.add(hash_id)
    # Forward args/kwargs so %-formatting and options such as exc_info work.
    self.info(msg, *args, **kwargs)
def warning_once(self, msg, *args, **kwargs):
    """Log ``msg`` at WARNING level the first time its key is seen.

    Bound onto the logger by ``get_logger`` so it can be called as
    ``logger.warning_once(...)``.

    Args:
        msg: Message (or %-style format string) to log.
        *args: Format arguments, forwarded to ``self.warning``.
        **kwargs: Extra keyword arguments forwarded to ``self.warning``
            (e.g. ``exc_info``, ``extra``). The special key ``hash_id``
            is consumed here: when given and truthy it is used as the
            de-duplication key instead of ``msg``.

    Note:
        The key is ``msg`` itself, not the formatted text, so two calls with
        the same format string but different ``args`` count as duplicates
        unless distinct ``hash_id`` values are supplied.
    """
    # Pop rather than get: Logger.warning() does not accept `hash_id`.
    hash_id = kwargs.pop('hash_id', None) or msg
    if hash_id in warning_set:
        return
    warning_set.add(hash_id)
    # Forward args/kwargs so %-formatting and options such as exc_info work.
    self.warning(msg, *args, **kwargs)
def get_logger(log_file: Optional[str] = None, log_level: Optional[int] = None, file_mode: str = 'w'):
    """Return the package logger, configuring it on the first call.

    The logger is named after the top-level package of this module. The
    first call attaches the handlers; later calls only try to add a file
    handler and then return the same logger.

    Args:
        log_file: Log filename. When given, a file handler is added (on the
            local master process only).
        log_level: Logging level. When None it is read from the
            ``LOG_LEVEL`` environment variable (default ``'INFO'``);
            a name that is not an attribute of ``logging`` falls back to
            ``logging.INFO``.
        file_mode: Mode used to open ``log_file`` (defaults to 'w').

    Returns:
        logging.Logger: The configured logger, with ``info_once`` and
        ``warning_once`` bound as methods.
    """
    if log_level is None:
        level_name = os.getenv('LOG_LEVEL', 'INFO').upper()
        log_level = getattr(logging, level_name, logging.INFO)

    name = __name__.split('.')[0]
    pkg_logger = logging.getLogger(name)
    pkg_logger.propagate = False

    # Already configured: at most add the file handler, nothing else.
    if name in init_loggers:
        add_file_handler_if_needed(pkg_logger, log_file, file_mode, log_level)
        return pkg_logger

    # Avoid duplicate console output. Since 1.8.0, PyTorch DDP attaches a
    # StreamHandler <stderr> (NOTSET) to the root logger, which makes
    # messages from rank>0 processes show up on the console. Any plain
    # StreamHandler on the root logger is therefore limited to ERROR.
    # The exact-type test leaves StreamHandler subclasses untouched.
    for root_handler in pkg_logger.root.handlers:
        if type(root_handler) is logging.StreamHandler:
            root_handler.setLevel(logging.ERROR)

    new_handlers = [logging.StreamHandler()]
    on_master = _is_local_master()
    if on_master and log_file is not None:
        new_handlers.append(logging.FileHandler(log_file, file_mode))

    for new_handler in new_handlers:
        new_handler.setFormatter(logger_format)
        new_handler.setLevel(log_level)
        pkg_logger.addHandler(new_handler)

    # Non-master processes only report errors.
    pkg_logger.setLevel(log_level if on_master else logging.ERROR)

    init_loggers[name] = True
    pkg_logger.info_once = MethodType(info_once, pkg_logger)
    pkg_logger.warning_once = MethodType(warning_once, pkg_logger)
    return pkg_logger
# Import-time setup: build the package logger and align the ModelScope
# logger's format and level with it.
logger = get_logger()
ms_logger = get_ms_logger()

# Apply the shared layout to each logger's first handler. Guarded so that a
# logger without handlers does not raise IndexError during import.
if logger.handlers:
    logger.handlers[0].setFormatter(logger_format)
if ms_logger.handlers:
    ms_logger.handlers[0].setFormatter(logger_format)

log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
if _is_local_master():
    try:
        ms_logger.setLevel(log_level)
    except ValueError:
        # setLevel() rejects unknown level names. get_logger() falls back to
        # INFO for such a LOG_LEVEL value, so do the same here instead of
        # failing the import.
        ms_logger.setLevel(logging.INFO)
else:
    # Non-master processes only report errors.
    ms_logger.setLevel(logging.ERROR)
@contextmanager
def ms_logger_ignore_error():
    """Temporarily raise the level of the ``get_ms_logger()`` logger.

    While the ``with`` body runs, that logger's level is
    ``logging.CRITICAL``. The level it had on entry is restored on exit,
    including when the body raises.
    """
    target = get_ms_logger()
    saved_level = target.level
    target.setLevel(logging.CRITICAL)
    try:
        yield
    finally:
        target.setLevel(saved_level)
def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
    """Attach a file handler to ``logger`` unless it already has one.

    Args:
        logger: Logger to extend.
        log_file: Path of the log file; None means there is nothing to add.
        file_mode: Mode used to open ``log_file``.
        log_level: Level set on the new handler.

    The handler is added only when no ``logging.FileHandler`` is attached
    yet (whatever file that handler writes to) and this process is worker 0.
    """
    # Cheap guard first: with no file requested there is nothing to do, so
    # skip the handler scan and the torch lookup on repeated get_logger()
    # calls.
    if log_file is None:
        return

    if any(isinstance(existing, logging.FileHandler) for existing in logger.handlers):
        return

    # NOTE(review): get_logger()'s first-call path uses _is_local_master()
    # with no torch probe; the two rules differ only when torch is not
    # installed and LOCAL_RANK > 0. Kept as-is to preserve behavior.
    if importlib.util.find_spec('torch') is not None:
        is_worker0 = int(os.getenv('LOCAL_RANK', -1)) in {-1, 0}
    else:
        is_worker0 = True
    if not is_worker0:
        return

    file_handler = logging.FileHandler(log_file, file_mode)
    file_handler.setFormatter(logger_format)
    file_handler.setLevel(log_level)
    logger.addHandler(file_handler)
|