Student0809's picture
Add files using upload-large-folder tool
7feac49 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
# One scalar record as produced by read_tensorboard_file: {'step': ..., 'value': ...}.
Item = Dict[str, float]

# Pale tint used for the raw curve when a smoothed curve is drawn on top of it.
TB_COLOR = '#FFE2D9'
# Saturated colour used for the smoothed curve and for any curve drawn without smoothing.
TB_COLOR_SMOOTH = '#FF7043'
def read_tensorboard_file(fpath: str) -> Dict[str, List[Item]]:
    """Load every scalar series stored in a single TensorBoard event file.

    Args:
        fpath: Path to the event file; it must be an existing regular file.

    Returns:
        A dict mapping each scalar tag to a list of ``{'step': ..., 'value': ...}``
        records, in the order ``EventAccumulator.Scalars`` yields them.

    Raises:
        FileNotFoundError: If ``fpath`` is not an existing regular file.
    """
    if not os.path.isfile(fpath):
        raise FileNotFoundError(f'fpath: {fpath}')
    # NOTE(review): the accumulator is built with its default size guidance, which
    # may subsample very long scalar series -- confirm this is acceptable for plots.
    accumulator = EventAccumulator(fpath)
    accumulator.Reload()  # read the events from disk before querying tags
    scalar_tags = accumulator.Tags()['scalars']
    return {
        tag: [{'step': event.step, 'value': event.value} for event in accumulator.Scalars(tag)]
        for tag in scalar_tags
    }
def tensorboard_smoothing(values: List[float], smooth: float = 0.9) -> List[float]:
    """Smooth a series with a debiased exponential moving average (TensorBoard style).

    For the i-th output this computes ``sum(smooth**(i-k) * v_k) / sum(smooth**(i-k))``
    over the finite inputs seen so far. Normalising by the sum of weights removes
    the bias towards zero a plain EMA started at 0 would show on the first points.
    ``smooth=0`` returns the finite inputs unchanged; ``smooth=1`` gives the running mean.

    Non-finite inputs (NaN, +inf, -inf) are copied to the output as-is and do not
    enter the running average. A single bad point, such as a diverged loss,
    therefore no longer turns every later smoothed value into NaN/inf.

    Args:
        values: The raw series.
        smooth: Decay weight applied to older points.

    Returns:
        A list with the same length as ``values``.
    """
    weighted_sum = 0.0
    norm_factor = 0.0
    res: List[float] = []
    for value in values:
        if not math.isfinite(value):
            # Pass bad points through without poisoning the accumulator state.
            res.append(value)
            continue
        weighted_sum = weighted_sum * smooth + value  # exponential decay of older points
        norm_factor = norm_factor * smooth + 1  # sum of the same decay weights
        res.append(weighted_sum / norm_factor)
    return res
def plot_images(images_dir: str,
tb_dir: str,
smooth_key: Optional[List[str]] = None,
smooth_val: float = 0.9,
figsize: Tuple[int, int] = (8, 5),
dpi: int = 100) -> None:
"""Using tensorboard's data content to plot images"""
smooth_key = smooth_key or []
os.makedirs(images_dir, exist_ok=True)
fname = [fname for fname in os.listdir(tb_dir) if os.path.isfile(os.path.join(tb_dir, fname))][0]
tb_path = os.path.join(tb_dir, fname)
data = read_tensorboard_file(tb_path)
for k in data.keys():
_data = data[k]
steps = [d['step'] for d in _data]
values = [d['value'] for d in _data]
if len(values) == 0:
continue
_, ax = plt.subplots(1, 1, squeeze=True, figsize=figsize, dpi=dpi)
ax.set_title(k)
if len(values) == 1:
ax.scatter(steps, values, color=TB_COLOR_SMOOTH)
elif k in smooth_key:
ax.plot(steps, values, color=TB_COLOR)
values_s = tensorboard_smoothing(values, smooth_val)
ax.plot(steps, values_s, color=TB_COLOR_SMOOTH)
else:
ax.plot(steps, values, color=TB_COLOR_SMOOTH)
fpath = os.path.join(images_dir, k.replace('/', '_').replace('.', '_'))
plt.savefig(fpath, dpi=dpi, bbox_inches='tight')
plt.close()