|
|
|
|
|
import datetime as dt |
|
|
import fnmatch |
|
|
import glob |
|
|
import importlib |
|
|
import os |
|
|
import random |
|
|
import re |
|
|
import shutil |
|
|
import socket |
|
|
import subprocess |
|
|
import sys |
|
|
import time |
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from transformers import HfArgumentParser, enable_full_determinism, set_seed |
|
|
from transformers.utils import strtobool |
|
|
|
|
|
from .env import is_dist, is_dist_ta |
|
|
from .logger import get_logger |
|
|
from .np_utils import stat_array |
|
|
|
|
|
# Module-level logger shared by every helper in this file.
logger = get_logger()
|
|
|
|
|
|
|
|
def check_json_format(obj: Any, token_safe: bool = True) -> Any:
    """Recursively convert `obj` into a JSON-serializable structure.

    Args:
        obj: Arbitrary object. Scalars pass through; bytes become a
            placeholder; torch dtypes/devices become short strings;
            sequences/mappings are converted recursively; everything else
            falls back to ``repr``.
        token_safe: When True, values whose key/attribute name contains
            ``'_token'`` are masked (set to None in mappings, temporarily
            nulled before ``repr`` for arbitrary objects) so secrets do not
            leak into logs/JSON.

    Returns:
        A structure built only from None/int/float/str/complex/list/dict/str.
    """
    if obj is None or isinstance(obj, (int, float, str, complex)):
        return obj
    if isinstance(obj, bytes):
        return '<<<bytes>>>'
    if isinstance(obj, (torch.dtype, torch.device)):
        obj = str(obj)
        # 'torch.float32' -> 'float32'; torch.device strs have no prefix.
        return obj[len('torch.'):] if obj.startswith('torch.') else obj

    if isinstance(obj, Sequence):
        # Tuples are deliberately converted to lists (JSON has no tuple).
        return [check_json_format(x, token_safe) for x in obj]
    if isinstance(obj, Mapping):
        res = {}
        for k, v in obj.items():
            if token_safe and isinstance(k, str) and '_token' in k and isinstance(v, str):
                res[k] = None  # mask secret values, keep the key visible
            else:
                res[k] = check_json_format(v, token_safe)
        return res

    # Fallback: represent the object via repr(). Some objects (numpy scalars,
    # slotted classes) have no __dict__ — guard instead of raising.
    obj_dict = getattr(obj, '__dict__', None)
    if token_safe and obj_dict is not None:
        unsafe_items = {k: v for k, v in obj_dict.items() if '_token' in k}
        for k in unsafe_items:
            setattr(obj, k, None)
        try:
            res = repr(obj)
        finally:
            # Always restore the original attribute values, even if repr raises.
            for k, v in unsafe_items.items():
                setattr(obj, k, v)
        return res
    return repr(obj)
|
|
|
|
|
|
|
|
def _get_version(work_dir: str) -> int: |
|
|
if os.path.isdir(work_dir): |
|
|
fnames = os.listdir(work_dir) |
|
|
else: |
|
|
fnames = [] |
|
|
v_list = [-1] |
|
|
for fname in fnames: |
|
|
m = re.match(r'v(\d+)', fname) |
|
|
if m is None: |
|
|
continue |
|
|
v = m.group(1) |
|
|
v_list.append(int(v)) |
|
|
return max(v_list) + 1 |
|
|
|
|
|
|
|
|
def format_time(seconds):
    """Render a duration in seconds as a compact human-readable string.

    Produces ``'Dd Hh Mm Ss'``, dropping leading units that are zero
    (e.g. ``'1m 5s'``, ``'42s'``).
    """
    days = int(seconds // 86400)
    hours = int((seconds % 86400) // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)

    if days > 0:
        return f'{days}d {hours}h {minutes}m {secs}s'
    if hours > 0:
        return f'{hours}h {minutes}m {secs}s'
    if minutes > 0:
        return f'{minutes}m {secs}s'
    return f'{secs}s'
|
|
|
|
|
|
|
|
def deep_getattr(obj, attr: str, default=None):
    """Resolve a dotted attribute/key path against `obj`.

    Each component is looked up with ``dict.get`` for dicts and ``getattr``
    otherwise, substituting `default` when missing. Traversal stops early
    (returning None) if an intermediate value is None.
    """
    current = obj
    for part in attr.split('.'):
        if current is None:
            break
        if isinstance(current, dict):
            current = current.get(part, default)
        else:
            current = getattr(current, part, default)
    return current
|
|
|
|
|
|
|
|
def seed_everything(seed: Optional[int] = None, full_determinism: bool = False, *, verbose: bool = True) -> int:
    """Seed all RNGs (via transformers helpers) and return the seed used.

    A random int32 seed is drawn when `seed` is None. `full_determinism`
    additionally enables deterministic algorithms via
    ``enable_full_determinism``.
    """
    if seed is None:
        seed = random.randint(0, np.iinfo(np.int32).max)

    seed_fn = enable_full_determinism if full_determinism else set_seed
    seed_fn(seed)
    if verbose:
        logger.info(f'Global seed set to {seed}')
    return seed
|
|
|
|
|
|
|
|
def add_version_to_work_dir(work_dir: str) -> str:
    """Append a versioned, timestamped sub-folder (``v<N>-<timestamp>``) to
    `work_dir`; under distributed training the sub-folder name is broadcast
    from rank 0 so every rank uses the same directory."""
    timestamp = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
    sub_folder = f'v{_get_version(work_dir)}-{timestamp}'
    if (dist.is_initialized() and is_dist()) or is_dist_ta():
        broadcast_buf = [sub_folder]
        dist.broadcast_object_list(broadcast_buf)
        sub_folder = broadcast_buf[0]

    return os.path.join(work_dir, sub_folder)
|
|
|
|
|
|
|
|
# Generic type variable used by parse_args / test_time / get_env_args below.
_T = TypeVar('_T')
|
|
|
|
|
|
|
|
def parse_args(class_type: Type[_T], argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
    """Parse `argv` (defaults to ``sys.argv[1:]``) into a `class_type` dataclass.

    If the first argument is a ``.json`` path, the dataclass is loaded from
    that file and the rest of `argv` is returned as remaining args; otherwise
    normal CLI parsing is used with unknown strings returned as-is.
    """
    parser = HfArgumentParser([class_type])
    argv = sys.argv[1:] if argv is None else argv
    if argv and argv[0].endswith('.json'):
        config_path = os.path.abspath(os.path.expanduser(argv[0]))
        parsed, = parser.parse_json_file(config_path)
        return parsed, argv[1:]
    parsed, remaining = parser.parse_args_into_dataclasses(argv, return_remaining_strings=True)
    return parsed, remaining
|
|
|
|
|
|
|
|
def lower_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
    """Binary search for the smallest x in [lo, hi] with cond(x) true.

    Assumes cond is monotone (False...False True...True). Returns `hi`
    when cond is False over the whole range.
    """
    left, right = lo, hi
    while left < right:
        mid = (left + right) // 2
        if cond(mid):
            right = mid
        else:
            left = mid + 1
    return left
|
|
|
|
|
|
|
|
def upper_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
    """Binary search for the largest x in [lo, hi] with cond(x) true.

    Assumes cond is monotone (True...True False...False). Returns a value
    below `lo` only if cond is False everywhere and the caller's range
    semantics allow it.
    """
    left, right = lo, hi
    while left < right:
        mid = (left + right + 1) // 2
        if cond(mid):
            left = mid
        else:
            right = mid - 1
    return left
|
|
|
|
|
|
|
|
def test_time(func: Callable[[], _T],
              number: int = 1,
              warmup: int = 0,
              timer: Optional[Callable[[], float]] = None) -> _T:
    """Time `func` over `number` runs (after `warmup` untimed runs), log the
    timing statistics, and return the last result of `func`."""
    clock = time.perf_counter if timer is None else timer

    res = None
    for _ in range(warmup):
        res = func()

    durations = []
    for _ in range(number):
        start = clock()
        res = func()
        durations.append(clock() - start)

    _, stat_str = stat_array(np.array(durations))
    logger.info(f'time[number={number}]: {stat_str}')
    return res
|
|
|
|
|
|
|
|
def read_multi_line(addi_prompt: str = '') -> str:
    """Read lines from stdin until a line ends with '#'.

    The first line is prompted with ``<<<{addi_prompt} ``; subsequent lines
    have no prompt. The terminating '#' is stripped from the returned text.
    """
    lines = []
    prompt = f'<<<{addi_prompt} '
    while True:
        line = input(prompt) + '\n'
        prompt = ''
        if line.endswith('#\n'):
            lines.append(line[:-2])  # drop the trailing '#\n'
            break
        lines.append(line)
    return ''.join(lines)
|
|
|
|
|
|
|
|
def subprocess_run(command: List[str], env: Optional[Dict[str, str]] = None, stdout=None, stderr=None):
    """Run `command` (argument list, shell=False) and return the
    CompletedProcess; raises CalledProcessError on non-zero exit."""
    result = subprocess.run(command, env=env, stdout=stdout, stderr=stderr)
    result.check_returncode()
    return result
|
|
|
|
|
|
|
|
def get_env_args(args_name: str, type_func: Callable[[str], _T], default_value: Optional[_T]) -> Optional[_T]:
    """Read hyperparameter `args_name` from the env var `ARGS_NAME`.

    Falls back to `default_value` when the variable is unset; otherwise the
    raw string is converted with `type_func` (bools go through strtobool
    first). Logs (once) which source was used.
    """
    env_key = args_name.upper()
    raw = os.getenv(env_key)
    if raw is None:
        value = default_value
        log_info = (f'Setting {args_name}: {default_value}. '
                    f'You can adjust this hyperparameter through the environment variable: `{env_key}`.')
    else:
        if type_func is bool:
            # strtobool accepts 'true'/'1'/'yes'... and returns 0/1.
            raw = strtobool(raw)
        value = type_func(raw)
        log_info = f'Using environment variable `{env_key}`, Setting {args_name}: {value}.'
    logger.info_once(log_info)
    return value
|
|
|
|
|
|
|
|
def find_free_port(start_port: Optional[int] = None, retry: int = 100) -> int:
    """Find a TCP port that can be bound locally.

    Args:
        start_port: First port to try; ports ``[start_port, start_port+retry)``
            are probed in order. When None, port 0 is tried, letting the OS
            pick a free ephemeral port on the first attempt.
        retry: Number of consecutive ports to probe.

    Returns:
        A port number that was successfully bound (the socket is closed
        again before returning, so a race with other processes is possible).

    Raises:
        OSError: If none of the probed ports could be bound. (The original
            code silently returned the last, possibly busy, port here.)
    """
    if start_port is None:
        start_port = 0
    for port in range(start_port, start_port + retry):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(('', port))
                # For port 0 the OS assigns a concrete port; report it.
                return sock.getsockname()[1]
            except OSError:
                continue
    raise OSError(f'No free port found in range [{start_port}, {start_port + retry}).')
|
|
|
|
|
|
|
|
def copy_files_by_pattern(source_dir, dest_dir, patterns):
    """Copy files matching glob-style `patterns` from `source_dir` to `dest_dir`.

    `patterns` may be a single pattern string or a list of them. A pattern
    containing a path separator (e.g. ``'sub*/x*.py'``) matches files inside
    matching sub-directories and preserves the relative directory layout
    under `dest_dir`; a plain pattern (e.g. ``'*.json'``) matches files
    directly under `source_dir` and copies them flat into `dest_dir`.
    Existing destination files are never overwritten.
    """
    if not os.path.exists(dest_dir):
        os.makedirs(dest_dir)

    # Normalize a single pattern into a list.
    if isinstance(patterns, str):
        patterns = [patterns]

    for pattern in patterns:
        pattern_parts = pattern.split(os.path.sep)
        if len(pattern_parts) > 1:
            # Multi-component pattern: last part matches the file name,
            # the rest matches the relative sub-directory path.
            subdir_pattern = os.path.sep.join(pattern_parts[:-1])
            file_pattern = pattern_parts[-1]

            for root, dirs, files in os.walk(source_dir):
                rel_path = os.path.relpath(root, source_dir)
                # Skip source_dir itself ('.') and any sub-dir whose relative
                # path does not match the directory part of the pattern.
                # (NOTE(review): the `rel_path != '.'` term is redundant given
                # the preceding `rel_path == '.'` check.)
                if rel_path == '.' or (rel_path != '.' and not fnmatch.fnmatch(rel_path, subdir_pattern)):
                    continue

                for file in files:
                    if fnmatch.fnmatch(file, file_pattern):
                        file_path = os.path.join(root, file)
                        # Mirror the matched sub-directory under dest_dir.
                        target_dir = os.path.join(dest_dir, rel_path)
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        dest_file = os.path.join(target_dir, file)

                        # Never overwrite an existing destination file.
                        if not os.path.exists(dest_file):
                            shutil.copy2(file_path, dest_file)
        else:
            # Flat pattern: glob directly under source_dir, copy flat.
            search_path = os.path.join(source_dir, pattern)
            matched_files = glob.glob(search_path)

            for file_path in matched_files:
                if os.path.isfile(file_path):
                    file_name = os.path.basename(file_path)
                    destination = os.path.join(dest_dir, file_name)
                    if not os.path.exists(destination):
                        shutil.copy2(file_path, destination)
|
|
|
|
|
|
|
|
def split_list(ori_list, num_shards):
    """Split `ori_list` into `num_shards` contiguous, near-equal shards.

    Boundaries come from ``np.linspace`` truncated to ints, so shard sizes
    differ by at most one; empty input yields `num_shards` empty shards.
    """
    boundaries = np.linspace(0, len(ori_list), num_shards + 1)
    return [
        ori_list[int(boundaries[i]):int(boundaries[i + 1])]
        for i in range(num_shards)
    ]
|
|
|
|
|
|
|
|
def patch_getattr(obj_cls, item_name: str):
    """Patch `obj_cls.__getattr__` to delegate missing attributes to
    ``getattr(self, item_name)``.

    After patching, an attribute lookup that fails normally is retried on the
    wrapped object stored in the `item_name` attribute (if present); otherwise
    the original AttributeError is re-raised. Idempotent: a `_patch` marker
    on the class prevents double-patching.
    """
    if hasattr(obj_cls, '_patch'):
        return

    def __new_getattr__(self, key: str):
        try:
            # Anchor super() on the class being patched, NOT self.__class__:
            # with self.__class__, an instance of a subclass would resolve
            # straight back to this method and recurse infinitely.
            return super(obj_cls, self).__getattr__(key)
        except AttributeError:
            if item_name in dir(self):
                item = getattr(self, item_name)
                return getattr(item, key)
            raise

    obj_cls.__getattr__ = __new_getattr__
    obj_cls._patch = True
|
|
|
|
|
|
|
|
def import_external_file(file_path: str):
    """Import a Python source file by path and return the module object.

    The file's directory is prepended to ``sys.path`` (a lasting side
    effect), then the module is imported by its stem name.
    """
    resolved = os.path.abspath(os.path.expanduser(file_path))
    py_dir, py_file = os.path.split(resolved)
    assert os.path.isdir(py_dir), f'py_dir: {py_dir}'
    sys.path.insert(0, py_dir)
    module_name = py_file.split('.', 1)[0]
    return importlib.import_module(module_name)
|
|
|