ezpz.history
- See ezpz/history.py
history.py
Contains implementation of History object for tracking / aggregating metrics.
History
A class to track and log metrics during training or evaluation.
Source code in src/ezpz/history.py
class History:
"""
A class to track and log metrics during training or evaluation.
"""
def __init__(
    self,
    keys: Optional[list[str]] = None,
    *,
    report_dir: Optional[PathLike] = None,
    report_enabled: bool = True,
    jsonl_path: Optional[PathLike] = None,
    jsonl_overwrite: bool = False,
    distributed_history: bool = AUTO_USE_DISTRIBUTED_HISTORY,
    project_name: Optional[str] = None,
    backends: Optional[str | Sequence[str]] = None,
    config: Optional[dict[str, Any]] = None,
    outdir: Optional[PathLike] = None,
    tracker: Optional[Tracker] = None,
) -> None:
    """
    Initialize the History object.

    Args:
        keys (Optional[list[str]]): List of keys to initialize the history with.
            If None, initializes with an empty list.
        report_dir (Optional[PathLike]): Directory for markdown reports. Defaults
            to ``OUTPUTS_DIR/history``.
        report_enabled (bool): Toggle automatic markdown generation.
        jsonl_path (Optional[PathLike]): Destination for JSONL metric logging.
            When omitted, a ``<run_id>.jsonl`` file is placed in the report
            directory (or ``OUTPUTS_DIR`` when reports are disabled).
        jsonl_overwrite (bool): Whether to truncate an existing JSONL log.
        distributed_history (bool): Enable distributed history tracking.
            Forced to ``False`` when ``EZPZ_NO_DISTRIBUTED_HISTORY`` or
            ``EZPZ_LOCAL_HISTORY`` is set to a non-empty value, or when the
            world size is ``<= 1``.
        project_name (Optional[str]): Project name for tracker backends (e.g. wandb).
        backends (Optional[str | Sequence[str]]): Comma-separated string or list
            of tracker backend names (e.g. ``"wandb,csv"``).
        config (Optional[dict[str, Any]]): Run-level config (hyperparameters) to
            log via the tracker.
        outdir (Optional[PathLike]): Output directory for file-based tracker
            backends (e.g. CSV).
        tracker (Optional[Tracker]): Inject a pre-built Tracker instance directly.
    """
    self.keys = [] if keys is None else keys
    self._groups: dict[str, dict[str, list[Any]]] = {}
    self._flat_cache: dict[str, list[Any]] | None = None
    if (
        os.environ.get("EZPZ_NO_DISTRIBUTED_HISTORY", None)
        or os.environ.get("EZPZ_LOCAL_HISTORY", False)
        or ezpz.distributed.get_world_size() <= 1
    ):
        logger.info(
            "Not using distributed metrics! Will only be tracked from a single rank..."
        )
        distributed_history = False
    self.distributed_history = distributed_history
    logger.info(
        f"Using {self.__class__.__name__} with distributed_history={self.distributed_history}"
    )
    self._rank = get_rank()
    # The run id is a UTC timestamp; it names both the per-run report
    # directory and the default JSONL file.
    now = datetime.now(timezone.utc)
    self._run_id = now.strftime("%Y%m%d-%H%M%S")
    self.report_enabled = report_enabled
    base_report_root = (
        Path(report_dir)
        if report_dir is not None
        else Path(OUTPUTS_DIR).joinpath("history")
    )
    self._report_root = Path(base_report_root).expanduser().resolve()
    self._report_dir = self._report_root.joinpath(self._run_id)
    self._report_path: Optional[Path] = None
    self._asset_dir: Optional[Path] = None
    self._report_filename = "report.md"
    self._report_initialized = False
    # Remember whether the caller chose the JSONL location so that
    # _configure_report_destination() only relocates a default path.
    self._jsonl_path_explicit = jsonl_path is not None
    if jsonl_path is None:
        default_jsonl_dir = (
            self._report_dir if report_enabled else Path(OUTPUTS_DIR)
        )
        self._jsonl_path = (
            Path(default_jsonl_dir)
            .expanduser()
            .resolve()
            .joinpath(f"{self._run_id}.jsonl")
        )
    else:
        self._jsonl_path = Path(jsonl_path).expanduser().resolve()
    if jsonl_overwrite and self._jsonl_path.exists():
        try:
            self._jsonl_path.unlink()
        except OSError:
            logger.warning(
                "Unable to remove existing JSONL log at %s",
                self._jsonl_path,
            )
    self._jsonl_enabled = True
    # Serializes JSONL writes against the finalize() move so a
    # background update() can't race the cross-FS shutil.move
    # and either (a) write to a half-moved file or (b) lose the
    # in-flight record altogether.
    self._jsonl_lock = threading.Lock()
    self._dist = torch.distributed
    self._environment_written = False
    self._metric_summary_written = False

    # -- Tracker integration --
    # ``tracker_got_config`` records whether ``config`` was actually handed
    # to ``setup_tracker``. It is set in the one branch that passes it,
    # rather than re-derived from arguments / environment afterwards, so
    # the two decisions cannot drift apart (previously a set
    # ``EZPZ_TRACKER_BACKEND`` variable suppressed the forwarding even
    # though no branch had passed ``config`` along).
    tracker_got_config = False
    if tracker is not None:
        self._tracker: Tracker = tracker
    elif (
        project_name is not None
        or backends is not None
        or os.environ.get("EZPZ_TRACKER_BACKENDS")
    ):
        self._tracker = setup_tracker(
            project_name=project_name,
            backends=backends,
            config=config,
            outdir=str(outdir) if outdir is not None else None,
        )
        tracker_got_config = True
    elif wandb is not None and getattr(wandb, "run", None) is not None:
        # Backward compat: auto-detect existing wandb.run
        warnings.warn(
            "History detected an active wandb.run but no 'backends' "
            "argument was provided. Automatically using "
            "backends='wandb'. In a future version, pass "
            "backends='wandb' explicitly.",
            DeprecationWarning,
            stacklevel=2,
        )
        self._tracker = setup_tracker(backends="wandb")
    else:
        self._tracker = NullTracker()

    # Forward config to the tracker when the backends didn't receive it
    # in their constructors. This covers:
    #   - Injected tracker= (backends never saw config)
    #   - Auto-detect wandb.run path (setup_tracker called without config)
    #   - NullTracker fallback
    # The setup_tracker(config=...) path already handles config internally,
    # so we skip it there to avoid duplicates.
    if config is not None and not tracker_got_config:
        self._tracker.log_config(config)
@property
def history(self) -> dict[str, list[Any]]:
"""Flattened view of all metric groups.
Returns a dict keyed by full metric names (e.g. ``"train/loss"``,
``"loss"``), preserving backward compatibility with code that reads
``history.history["key"]``.
"""
if self._flat_cache is not None:
return self._flat_cache
flat: dict[str, list[Any]] = {}
for prefix, metrics in self._groups.items():
for key, values in metrics.items():
full_key = f"{prefix}/{key}" if prefix else key
flat[full_key] = values
return flat
@history.setter
def history(self, value: dict[str, list[Any]]) -> None:
"""Allow direct assignment for backward compat (e.g. reset)."""
self._groups.clear()
self._flat_cache = None
for full_key, values in value.items():
prefix, _, short_key = full_key.partition("/")
if not _:
prefix, short_key = "", full_key
group = self._groups.setdefault(prefix, {})
group[short_key] = values
@property
def data(self) -> dict[str, list[Any]]:
"""Alias for :attr:`history` (backward compat)."""
return self.history
@data.setter
def data(self, value: dict[str, list[Any]]) -> None:
self.history = value
@property
def groups(self) -> dict[str, dict[str, list[Any]]]:
"""Access the prefix-grouped metrics directly."""
return self._groups
@staticmethod
def _split_prefix(
metrics: dict[str, Any],
) -> tuple[str, dict[str, Any]]:
"""Extract prefix from metric keys and return (prefix, stripped_dict).
If any key contains ``"/"``, the part before the first ``"/"`` is
the prefix. All keys are then stripped of that prefix. If no key
contains ``"/"``, the prefix is ``""`` and keys are unchanged.
"""
prefix = ""
for key in metrics:
if "/" in key:
prefix = key.split("/", 1)[0]
break
if not prefix:
return "", dict(metrics)
stripped: dict[str, Any] = {}
for key, val in metrics.items():
if key.startswith(f"{prefix}/"):
stripped[key[len(prefix) + 1 :]] = val
else:
stripped[key] = val
return prefix, stripped
def _invalidate_flat_cache(self) -> None:
"""Mark the flattened history cache as stale."""
self._flat_cache = None
@property
def tracker(self) -> Tracker:
    """Expose the Tracker held by this History.

    Handy for reaching backend-specific features, e.g.::

        wb = history.tracker.get_backend("wandb")
        if wb is not None:
            wb.watch(model, log="all")
    """
    active_tracker = self._tracker
    return active_tracker
# ------------------------------------------------------------------ #
# Internal helpers
# ------------------------------------------------------------------ #
@staticmethod
def _utc_iso() -> str:
"""Return the current UTC timestamp in ISO-8601 format with trailing Z."""
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
def _configure_report_destination(self, base_dir: Path) -> None:
"""Configure the report directory to live under *base_dir*."""
base_dir = base_dir.expanduser().resolve()
self._report_root = base_dir
self._report_dir = base_dir
self._report_path = base_dir.joinpath(self._report_filename)
self._asset_dir = base_dir.joinpath("assets")
self._report_initialized = False
if self._jsonl_enabled and not self._jsonl_path_explicit:
self._jsonl_path = base_dir.joinpath(f"{self._run_id}.jsonl")
self._environment_written = False
self._metric_summary_written = False
def _ensure_report_file(self) -> Optional[Path]:
"""Ensure the markdown report and asset directories exist."""
if not self.report_enabled:
return None
if not self._report_initialized:
self._report_dir.mkdir(parents=True, exist_ok=True)
self._asset_dir = self._report_dir.joinpath("assets")
self._asset_dir.mkdir(parents=True, exist_ok=True)
self._report_path = self._report_dir.joinpath(
self._report_filename
)
header = (
f"# History Report ({self._run_id})\n\n"
f"_Generated at {self._utc_iso()}_\n\n"
)
self._report_path.write_text(header, encoding="utf-8")
self._report_initialized = True
return self._report_path
def _prepare_report_asset(self, source: Path) -> Optional[Path]:
"""Copy plot artifacts into the report asset directory."""
report_file = self._ensure_report_file()
if report_file is None:
return None
assert self._asset_dir is not None
source = source.resolve()
try:
if source.is_relative_to(self._asset_dir):
return source
except AttributeError: # Python < 3.9 fallback (not expected)
pass
destination = self._asset_dir.joinpath(source.name)
if destination != source:
try:
shutil.copy2(source, destination)
except OSError:
logger.warning(
"Unable to copy asset %s into report directory.", source
)
return source
return destination
def _write_plot_report(
self,
key: Optional[str],
asset_path: Path,
*,
kind: str,
metadata: Optional[dict[str, Any]] = None,
) -> None:
"""Append a markdown section describing the generated plot."""
report_file = self._ensure_report_file()
if report_file is None:
return
asset_path = asset_path.resolve()
if not asset_path.exists():
return
asset_path = self._prepare_report_asset(asset_path) or asset_path
try:
rel_path = asset_path.relative_to(report_file.parent)
except ValueError:
rel_path = asset_path
title = key or asset_path.stem
timestamp = self._utc_iso()
lines = [
f"## {title}",
"",
f"_Kind_: `{kind}` ",
f"_Generated_: {timestamp}",
"",
]
if asset_path.suffix.lower() in {".txt", ".log"}:
try:
text = asset_path.read_text(encoding="utf-8")
except OSError:
text = ""
snippet = "\n".join(text.splitlines()[:40]).rstrip("\n")
lines.extend(["```", snippet, "```", ""])
else:
lines.append(f"})")
lines.append("")
if metadata:
for meta_key, meta_val in metadata.items():
lines.append(f"- **{meta_key}**: {meta_val}")
lines.append("")
with report_file.open("a", encoding="utf-8") as handle:
handle.write("\n".join(lines))
if not lines[-1].endswith("\n"):
handle.write("\n")
def _wandb_log_matplotlib_asset(
self,
key: Optional[str],
asset_path: Optional[Path],
*,
kind: str = "matplotlib",
commit: bool = False,
) -> None:
if self._rank != 0 or asset_path is None:
return
asset_path = Path(asset_path)
if not asset_path.exists():
return
if asset_path.suffix.lower() not in {".png", ".jpg", ".jpeg"}:
return
title = key or asset_path.stem
self._tracker.log_image(
f"plots/{title}",
str(asset_path),
caption=f"{kind}:{title}",
)
def _write_environment_section(
self, env_info: Optional[dict[str, Any]]
) -> None:
"""Write environment details into the report."""
if (
not self.report_enabled
or env_info is None
or self._environment_written
):
return
report_file = self._ensure_report_file()
if report_file is None:
return
lines: list[str] = ["## Environment", ""]
for section, details in env_info.items():
if isinstance(details, dict):
lines.extend((f"### {section}", ""))
lines.extend(
f"- **{key}**: {value}" for key, value in details.items()
)
lines.append("")
else:
lines.append(f"- **{section}**: {details}")
with report_file.open("a", encoding="utf-8") as handle:
handle.write("\n".join(lines))
if not lines[-1].endswith("\n"):
handle.write("\n")
self._environment_written = True
def _default_environment_info(self) -> dict[str, dict[str, str]]:
"""Return a minimal environment summary."""
python_info = {
"Version": (
f"{sys.version_info.major}."
f"{sys.version_info.minor}."
f"{sys.version_info.micro}"
),
"Implementation": sys.implementation.name,
"Executable": sys.executable,
"Platform": platform.platform(),
}
try:
torch_version = torch.__version__
except Exception: # pragma: no cover - torch should be importable
torch_version = "unknown"
torch_info = {
"Version": torch_version,
}
path_info = {
"Working directory": str(Path.cwd()),
}
env_vars: dict[str, str] = {}
for key in (
"MASTER_ADDR",
"MASTER_PORT",
"WORLD_SIZE",
"RANK",
"LOCAL_RANK",
):
value = os.environ.get(key)
if value is not None:
env_vars[key] = value
summary: dict[str, dict[str, str]] = {
"Paths": path_info,
"Python": python_info,
"Torch": torch_info,
}
if env_vars:
summary["Environment Variables"] = env_vars
try:
dist_info = ezpz.get_dist_info()
summary["Distributed"] = {
str(k): str(v) for k, v in dist_info.items()
}
except Exception:
pass
_wandb_run = self._tracker.wandb_run
if _wandb_run is not None:
wb_info: dict[str, str] = {
"Run Name": _wandb_run.name,
"Project": _wandb_run.project,
"URL": _wandb_run.url,
}
wb_info.update(
{str(k): str(v) for k, v in _wandb_run.config.items()}
)
summary["Weights & Biases"] = wb_info
_mlflow_be = self._tracker.get_backend("mlflow")
if _mlflow_be is not None and getattr(_mlflow_be, "_active", False):
ml_info: dict[str, str] = {
"Run ID": getattr(_mlflow_be, "_run_id", ""),
"Experiment ID": getattr(_mlflow_be, "_experiment_id", ""),
"Tracking URI": getattr(_mlflow_be, "_tracking_uri", ""),
}
run_url = getattr(_mlflow_be, "run_url", None)
if run_url is not None:
ml_info["URL"] = run_url
summary["MLflow"] = ml_info
return summary
def _collect_metric_groups(
self, dataset: xr.Dataset
) -> dict[str, dict[str, float]]:
"""Return metric statistics grouped by base metric name."""
assert dataset is not None and hasattr(dataset, "data_vars")
groups: dict[str, dict[str, float]] = {}
for name in sorted(dataset.data_vars):
arr = dataset[name]
if arr.size == 0:
continue
try:
latest = arr.isel({arr.dims[0]: -1})
except Exception:
latest = arr
data = np.asarray(latest)
if data.size == 0:
continue
value = float(data.mean()) if data.ndim > 0 else float(data.item())
base, _, suffix = name.partition("_")
if suffix in {"mean", "max", "min", "std"}:
groups.setdefault(base, {})[suffix] = value
else:
groups.setdefault(name, {})["latest"] = value
return groups
def _write_metric_summary(self, dataset: xr.Dataset) -> None:
"""Append a metric overview table grouped by metric."""
if not self.report_enabled or self._metric_summary_written:
return
groups = self._collect_metric_groups(dataset)
if not groups:
return
report_file = self._ensure_report_file()
if report_file is None:
return
with report_file.open("a", encoding="utf-8") as handle:
handle.write("## Metric Overview\n\n")
for metric_name, stats in groups.items():
handle.write(f"### {metric_name}\n\n")
rows: list[tuple[str, str]] = []
for label in ("latest", "mean", "max", "min", "std"):
if label in stats:
value = stats[label]
rows.append((label.capitalize(), f"{value:.6f}"))
if rows:
header = ("Statistic", "Value")
col1 = max(len(header[0]), *(len(r[0]) for r in rows))
col2 = max(len(header[1]), *(len(r[1]) for r in rows))
handle.write(
f"| {header[0]:<{col1}} | {header[1]:>{col2}} |\n"
)
handle.write(
f"|:{'-' * (col1 - 1)} | {'-' * (col2 - 1)}:|\n"
)
for stat_label, stat_value in rows:
handle.write(
f"| {stat_label:<{col1}} | {stat_value:>{col2}} |\n"
)
handle.write("\n")
self._metric_summary_written = True
def _series_from_dataarray(self, data: xr.DataArray) -> np.ndarray:
"""Convert an xarray DataArray into a 1-D numerical series."""
arr = np.asarray(data.values)
if arr.ndim == 1:
return arr
if arr.ndim == 0:
return np.array([float(arr)])
axes = tuple(range(arr.ndim - 1))
return arr.mean(axis=axes)
def _group_metric_variables(
self, dataset: xr.Dataset
) -> dict[str, dict[str, xr.DataArray]]:
"""Group metric variables by base name and associated aggregates."""
groups: dict[str, dict[str, xr.DataArray]] = {}
for name, data_array in dataset.data_vars.items():
base, sep, suffix = name.rpartition("_")
if sep and base and suffix in {"mean", "max", "min", "std"}:
groups.setdefault(base, {})[suffix] = data_array
else:
groups.setdefault(name, {})["raw"] = data_array
return groups
def _plot_metric_group(
    self,
    name: str,
    metric_vars: dict[str, xr.DataArray],
    *,
    warmup: Optional[float | int] = 0.0,
    title: Optional[str] = None,
    outdir: Optional[Path] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
    verbose: bool = False,
) -> Optional[Path]:
    """Render a single matplotlib figure combining metric aggregates.

    Args:
        name: Metric name; used for labels and (with ``/`` replaced by
            ``_``) for the output file names.
        metric_vars: Mapping of ``"raw"`` / ``"mean"`` / ``"min"`` /
            ``"max"`` / ``"std"`` to the corresponding data arrays.
        warmup: Accepted for interface compatibility; not used here.
        title: Optional axes title.
        outdir: Directory to save into. Defaults to ``<report_dir>/mplot``
            when reporting is enabled; when neither is available the figure
            is drawn but not saved.
        subplots_kwargs: Forwarded to ``plt.subplots``.
        plot_kwargs: Only its ``"color"`` entry is read.
        verbose: Log the save location.

    Returns:
        Path of the saved PNG, or ``None`` when there was nothing to plot
        or nothing was saved.

    The raw series and the mean are drawn as lines; ``mean +/- std`` (or,
    failing that, ``min..max``) is drawn as a shaded band. If none of those
    could be drawn (e.g. only ``std`` or only ``min`` is present), every
    available aggregate is drawn as a plain line instead.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    subplots_kwargs = (
        {} if subplots_kwargs is None else dict(subplots_kwargs)
    )
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    raw_da = metric_vars.get("raw")
    mean_da = metric_vars.get("mean")
    std_da = metric_vars.get("std")
    min_da = metric_vars.get("min")
    max_da = metric_vars.get("max")
    # The x-axis length comes from the first available series. The
    # preferred sources go first; any remaining aggregate follows so a
    # group holding e.g. only ``std`` still reaches the fallback below
    # instead of returning early.
    series_candidates = [
        raw_da,
        mean_da,
        min_da,
        max_da,
        *metric_vars.values(),
    ]
    base_series = next(
        (
            self._series_from_dataarray(candidate)
            for candidate in series_candidates
            if candidate is not None
        ),
        None,
    )
    if base_series is None or len(base_series) == 0:
        return None
    x = np.arange(base_series.shape[-1])
    fig, ax = plt.subplots(**subplots_kwargs)
    primary_asset: Optional[Path] = None
    try:
        color = plot_kwargs.get("color")
        drew_any = False
        if raw_da is not None:
            raw_series = self._series_from_dataarray(raw_da)
            ax.plot(
                x,
                raw_series,
                label=name,
                color=color,
                alpha=0.35,
                linewidth=1.25,
            )
            drew_any = True
        mean_series = None
        if mean_da is not None:
            mean_series = self._series_from_dataarray(mean_da)
            ax.plot(
                x,
                mean_series,
                label=f"{name} mean",
                color=color,
                linewidth=2.0,
            )
            drew_any = True
        if mean_series is not None and std_da is not None:
            std_series = self._series_from_dataarray(std_da)
            upper = mean_series + std_series
            lower = mean_series - std_series
            ax.fill_between(
                x,
                lower,
                upper,
                color=color,
                alpha=0.2,
                label=f"{name} ± std",
            )
            drew_any = True
        elif min_da is not None and max_da is not None:
            min_series = self._series_from_dataarray(min_da)
            max_series = self._series_from_dataarray(max_da)
            ax.fill_between(
                x,
                min_series,
                max_series,
                color=color,
                alpha=0.15,
                label=f"{name} range",
            )
            drew_any = True
        if not drew_any:
            # fall back to plotting whichever aggregate is available
            for label, array in metric_vars.items():
                if array is None:
                    continue
                series = self._series_from_dataarray(array)
                ax.plot(
                    np.arange(series.shape[-1]),
                    series,
                    label=f"{name} {label}",
                    linewidth=1.75,
                )
        ax.set_xlabel("step")
        ax.set_ylabel(name)
        if title is not None:
            ax.set_title(title)
        sns.despine(ax=ax, top=True, right=True)
        ax.legend(loc="best", frameon=False)
        if outdir is None and self.report_enabled:
            save_dir = self._report_dir.joinpath("mplot")
        elif outdir is not None:
            save_dir = Path(outdir)
        else:
            save_dir = None
        if save_dir is not None:
            save_dir = save_dir.expanduser().resolve()
            save_dir.mkdir(parents=True, exist_ok=True)
            asset_name = name.replace("/", "_")
            dirs = {
                "png": save_dir.joinpath("pngs"),
                "svg": save_dir.joinpath("svgs"),
            }
            for directory in dirs.values():
                directory.mkdir(parents=True, exist_ok=True)
            if verbose:
                logger.info("Saving %s plot to: %s", name, save_dir)
            for ext, directory in dirs.items():
                outfile = directory.joinpath(f"{asset_name}.{ext}")
                if outfile.exists():
                    # Never overwrite an earlier plot; disambiguate by time.
                    outfile = directory.joinpath(
                        f"{asset_name}-{get_timestamp()}.{ext}"
                    )
                fig.savefig(outfile, dpi=400, bbox_inches="tight")
                if primary_asset is None and ext == "png":
                    primary_asset = outfile
    finally:
        # Release the figure even when drawing or saving raises.
        plt.close(fig)
    return primary_asset
def _tplot_metric_group(
    self,
    name: str,
    metric_vars: dict[str, xr.DataArray],
    *,
    warmup: Optional[float | int] = 0.0,
    outdir: Optional[Path] = None,
    plot_type: Optional[str] = None,
    marker: Optional[str] = None,
    verbose: bool = False,
    logfreq: Optional[int] = None,
) -> Optional[Path]:
    """Render grouped metrics into a single text-based plot asset.

    Args:
        name: Metric name; used for plot labels and, with ``/`` replaced by
            ``_``, for the output file names.
        metric_vars: Mapping whose recognised keys are ``"raw"``,
            ``"mean"``, ``"max"``, ``"min"`` and ``"std"``.
        warmup: Not referenced in this method body.
        outdir: Output directory. Defaults to ``<report_dir>/tplot`` when
            reporting is enabled.
        plot_type: Plot type; falls back to ``$EZPZ_TPLOT_TYPE``.
        marker: Marker; falls back to ``$EZPZ_TPLOT_MARKER`` and then to
            ``DEFAULT_MARKER`` (unless the plot type is ``"hist"``).
        verbose: Forwarded to ``self._tplot``.
        logfreq: Forwarded to ``self._tplot`` (``1`` when ``None``).

    Returns:
        Path of ``<outdir>/<name>.txt``, or ``None`` when no output
        directory is available or ``plotext`` cannot be imported.

    Files written under ``outdir`` (when there is data for them):
    ``<name>.txt`` (per-statistic panes), ``<name>_summary.txt`` (single
    overlay chart) and, in the subplot layout only, ``<name>_hist.txt``.
    """
    outdir = Path(outdir) if outdir is not None else None
    if outdir is None and self.report_enabled:
        outdir = self._report_dir.joinpath("tplot")
    if outdir is None:
        return None
    outdir = outdir.expanduser().resolve()
    outdir.mkdir(parents=True, exist_ok=True)
    asset_path = outdir.joinpath(f"{name.replace('/', '_')}.txt")
    summary_path = outdir.joinpath(f"{name.replace('/', '_')}_summary.txt")
    hist_path = outdir.joinpath(f"{name.replace('/', '_')}_hist.txt")
    # (metric_vars key, axis label) in the order the panes are written by
    # the non-subplot path below.
    order = [
        ("raw", name),
        ("mean", f"{name} mean"),
        ("max", f"{name} max"),
        ("min", f"{name} min"),
        ("std", f"{name} std"),
    ]
    stats_keys = ("mean", "max", "min", "std")
    stats_present = any(key in metric_vars for key in stats_keys)
    # True when at least one of mean/max/min has a non-zero value (NaNs are
    # treated as zero). ``std`` is deliberately not part of this check.
    stats_nonzero = False
    if stats_present:
        nonzero_keys = [
            key for key in ("mean", "max", "min") if key in metric_vars
        ]
        if nonzero_keys:
            stats_nonzero = any(
                np.any(
                    np.nan_to_num(
                        self._series_from_dataarray(metric_vars[key])
                    )
                    != 0
                )
                for key in nonzero_keys
            )
    # Probe for plotext up front; the returned figure is discarded.
    try:
        _ = plotext_prepare_figure(theme="clear")
    except ModuleNotFoundError:  # pragma: no cover - optional dependency
        logger.error(
            "Unable to import `plotext` which is needed for text-based plotting."
        )
        return None
    # Explicit arguments win over the environment variables.
    resolved_plot_type = (
        plot_type
        if plot_type is not None
        else os.environ.get("EZPZ_TPLOT_TYPE")
    )
    resolved_marker = (
        marker
        if marker is not None
        else os.environ.get("EZPZ_TPLOT_MARKER")
    )
    if resolved_marker is None and resolved_plot_type != "hist":
        resolved_marker = DEFAULT_MARKER

    def _metric_marker(metric_key: str) -> Optional[str]:
        # Per-metric marker: MARKER_MAP is consulted with the last "/"
        # component of the key; a mapped value of "line" means "no marker",
        # any other mapped value is used as-is, and unmapped keys fall back
        # to the resolved default marker.
        key = metric_key.split("/")[-1]
        mapped = MARKER_MAP.get(key)
        if mapped == "line":
            return None
        if mapped is not None:
            return mapped
        return resolved_marker

    # The multi-pane layout is only used when aggregates exist and are not
    # all zero; otherwise panes are written one after another via _tplot.
    use_subplots = stats_present and stats_nonzero
    plt = None
    left = None
    right = None
    if use_subplots:
        try:
            plt, left, right = plotext_subplots(
                left_layout=(2, 1),
                right_layout=(3, 1),
                # height_scale=8.0,
            )
        except (
            ModuleNotFoundError
        ):  # pragma: no cover - optional dependency
            use_subplots = False
    wrote_any = False
    points = 0
    # Whether to also print the multi-pane grid (raw/mean | min/std/max)
    # and the 2x2 histogram grid to stdout. Off by default - these
    # take ~30 lines of console real estate each and the same data
    # is in the saved .txt file. Set EZPZ_TPLOT_STDOUT=1 to restore
    # the legacy behavior. The single-pane overlay (mean/max/min/raw
    # on one chart) IS always printed - it's the useful summary.
    _show_grids = os.environ.get("EZPZ_TPLOT_STDOUT", "0") == "1"
    if use_subplots:
        assert plt is not None and left is not None and right is not None
        # (row, metric_vars key, title/label, color) for each column.
        left_slots = [
            (1, "raw", name, None),
            (2, "mean", f"{name}/mean", "green"),
        ]
        right_slots = [
            (1, "min", f"{name}/min", "cyan"),
            (2, "std", f"{name}/std", "magenta"),
            (3, "max", f"{name}/max", "red"),
        ]
        for row, key, label, color in left_slots:
            data_array = metric_vars.get(key)
            if data_array is None:
                continue
            series = self._series_from_dataarray(data_array)
            points = max(points, len(series))
            if left is not None and hasattr(left, "subplot"):
                left.subplot(row, 1)
            plotext_plot_series(
                plt,
                series,
                label=None,
                color=color,
                plot_type=resolved_plot_type,
                marker=_metric_marker(key),
            )
            if plt is not None and hasattr(plt, "title"):
                plt.title(label)
            if plt is not None and hasattr(plt, "xlabel"):
                plt.xlabel("step")
            if plt is not None and hasattr(plt, "ylabel"):
                plt.ylabel(label)
            wrote_any = True
        for row, key, label, color in right_slots:
            data_array = metric_vars.get(key)
            if data_array is None:
                continue
            series = self._series_from_dataarray(data_array)
            points = max(points, len(series))
            if hasattr(right, "subplot"):
                right.subplot(row, 1)
            plotext_plot_series(
                plt,
                series,
                label=None,
                color=color,
                plot_type=resolved_plot_type,
                marker=_metric_marker(key),
            )
            if hasattr(plt, "title"):
                plt.title(label)
            if hasattr(plt, "xlabel"):
                plt.xlabel("step")
            if hasattr(plt, "ylabel"):
                plt.ylabel(label)
            wrote_any = True
        if wrote_any:
            if _show_grids:
                plt.show()
            # File is saved either way - `tplot/<metric>.txt` still
            # has the multi-pane grid for later inspection.
            plt.savefig(
                asset_path.as_posix(), append=False, keep_colors=True
            )
        if stats_present:
            # Single-chart overlay of max/min/mean/raw, written to
            # `<metric>_summary.txt` and always shown.
            plt = plotext_prepare_figure(theme="clear")
            plotext_set_size(plt, min_height=40)
            overlay_order = [
                ("max", f"{name}/max", "red"),
                ("min", f"{name}/min", "cyan"),
                ("mean", f"{name}/mean", "green"),
                ("raw", name, None),
            ]
            overlay_points = 0
            for key, label, color in overlay_order:
                data_array = metric_vars.get(key)
                if data_array is None:
                    continue
                series = self._series_from_dataarray(data_array)
                overlay_points = max(overlay_points, len(series))
                plotext_plot_series(
                    plt,
                    series,
                    label=label,
                    color=color,
                    marker=_metric_marker(key),
                )
            if overlay_points > 0:
                plt.show()
                plt.savefig(
                    summary_path.as_posix(),
                    append=False,
                    keep_colors=True,
                )
                if self.report_enabled:
                    self._write_plot_report(
                        f"{name} summary",
                        summary_path,
                        kind="tplot",
                        metadata={
                            "components": ", ".join(
                                key
                                for key, _, _ in overlay_order
                                if key in metric_vars
                            ),
                            "points": overlay_points,
                        },
                    )
            # 2x2 grid of histograms, one cell per aggregate.
            hist_order = [
                ("mean", f"{name}/mean"),
                ("max", f"{name}/max"),
                ("min", f"{name}/min"),
                ("std", f"{name}/std"),
            ]
            plt = plotext_prepare_figure(theme="clear")
            plotext_set_size(plt)
            plt.subplots(2, 2)
            hist_points = 0
            for idx, (key, label) in enumerate(hist_order, start=1):
                data_array = metric_vars.get(key)
                if data_array is None:
                    continue
                series = self._series_from_dataarray(data_array)
                hist_points = max(hist_points, len(series))
                # Map the 1-based position to a 1-based (row, col) cell.
                row = ((idx - 1) // 2) + 1
                col = ((idx - 1) % 2) + 1
                if hasattr(plt, "subplot"):
                    plt.subplot(row, col)
                plotext_hist_series(plt, series, label=None)
                if hasattr(plt, "title"):
                    plt.title(f"{label} hist")
            if hist_points > 0:
                if _show_grids:
                    plt.show()
                # Saved to `tplot/<metric>_hist.txt` either way.
                plt.savefig(
                    hist_path.as_posix(),
                    append=False,
                    keep_colors=True,
                )
                if self.report_enabled:
                    self._write_plot_report(
                        f"{name} hist",
                        hist_path,
                        kind="tplot-hist",
                        metadata={
                            "components": ", ".join(
                                key
                                for key, _ in hist_order
                                if key in metric_vars
                            ),
                            "points": hist_points,
                        },
                    )
    if not use_subplots:
        # Sequential panes: the first pane starts the file, later panes
        # are appended to it.
        append_flag = False
        for key, label in order:
            data_array = metric_vars.get(key)
            if data_array is None:
                continue
            series = self._series_from_dataarray(data_array)
            points = max(points, len(series))
            self._tplot(
                y=series,
                xlabel="step",
                ylabel=label,
                append=append_flag,
                outfile=asset_path.as_posix(),
                verbose=verbose,
                plot_type=resolved_plot_type,
                marker=_metric_marker(key),
                logfreq=(1 if logfreq is None else logfreq),
                record_report=False,
                # The per-stat raw/mean/min/max/std panes are what
                # used to spam stdout - gate them behind the same
                # opt-in. Files still written to `tplot/<metric>.txt`.
                quiet=not _show_grids,
            )
            append_flag = True
            wrote_any = True
        if stats_present:
            overlay_order = [
                ("max", f"{name}/max", "red"),
                ("min", f"{name}/min", "cyan"),
                ("mean", f"{name}/mean", "green"),
                ("raw", name, None),
            ]
            # Build a SINGLE combined overlay (one show, one savefig)
            # - not 4 sequential _tplot(append=True) calls, which
            # would print 4 separate panes to stdout. This matches
            # the use_subplots branch's behavior.
            overlay_plt = plotext_prepare_figure(theme="clear")
            plotext_set_size(overlay_plt, min_height=40)
            overlay_points = 0
            for key, label, color in overlay_order:
                data_array = metric_vars.get(key)
                if data_array is None:
                    continue
                series = self._series_from_dataarray(data_array)
                overlay_points = max(overlay_points, len(series))
                plotext_plot_series(
                    overlay_plt,
                    series,
                    label=label,
                    color=color,
                    marker=_metric_marker(key),
                )
            if overlay_points > 0:
                overlay_plt.show()
                overlay_plt.savefig(
                    summary_path.as_posix(),
                    append=False,
                    keep_colors=True,
                )
            if (
                overlay_points > 0
                and self.report_enabled
                and summary_path.exists()
            ):
                self._write_plot_report(
                    f"{name} summary",
                    summary_path,
                    kind="tplot",
                    metadata={
                        "components": ", ".join(
                            key
                            for key, _, _ in overlay_order
                            if key in metric_vars
                        ),
                        "points": overlay_points,
                    },
                )
    # Record the per-statistic pane file in the report (both layouts).
    if wrote_any and self.report_enabled:
        self._write_plot_report(
            name,
            asset_path,
            kind="tplot",
            metadata={
                "components": ", ".join(
                    key for key, _ in order if key in metric_vars
                ),
                "points": points,
            },
        )
    return asset_path
def _write_jsonl_entry(
self,
metrics: dict[str, Any],
aggregated: Optional[dict[str, float]] = None,
) -> None:
"""Append metrics to the configured JSONL log."""
if not self._jsonl_enabled:
return
if self._jsonl_path is None:
return
payload: dict[str, Any] = {
"timestamp": time.time(),
"rank": self._rank,
"metrics": metrics,
}
if aggregated and self._rank == 0:
payload["aggregated"] = aggregated
# Hold the JSONL lock for the open/write/close cycle so a
# concurrent finalize() can't move the file out from under us.
# Each call still does its own open/close so the lock is held
# for microseconds β no contention in practice.
with self._jsonl_lock:
try:
self._jsonl_path.parent.mkdir(parents=True, exist_ok=True)
with self._jsonl_path.open("a", encoding="utf-8") as handle:
handle.write(
json.dumps(payload, default=self._to_serializable)
)
handle.write("\n")
except OSError:
logger.warning(
"Unable to write JSONL metrics to %s", self._jsonl_path
)
@classmethod
def _to_serializable(cls, value: Any) -> Any:
"""Convert values to JSON-serializable structures."""
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, (np.floating, np.integer, np.bool_)):
return value.item()
if isinstance(value, Path):
return value.as_posix()
if torch.is_tensor(value):
tensor = value.detach()
if tensor.numel() == 1:
return tensor.item()
return tensor.cpu().tolist()
if isinstance(value, np.ndarray):
if value.shape == ():
return value.item()
return value.tolist()
if isinstance(value, dict):
return {
key: cls._to_serializable(sub_value)
for key, sub_value in value.items()
}
if isinstance(value, (list, tuple)):
return [cls._to_serializable(item) for item in value]
if hasattr(value, "item"):
try:
return value.item()
except Exception:
pass
return str(value)
@classmethod
def _sanitize_metrics(cls, metrics: dict[str, Any]) -> dict[str, Any]:
"""Return a copy of metrics with values converted to JSON-safe types."""
return {
key: cls._to_serializable(value) for key, value in metrics.items()
}
def _iter_scalar_metrics(
self, metrics: dict[str, Any]
) -> Iterable[tuple[str, float]]:
"""Yield scalar metrics suitable for distributed reductions."""
for key, value in metrics.items():
if isinstance(value, (int, float)):
yield key, float(value)
elif isinstance(value, np.ndarray) and value.shape == ():
yield key, float(value.item())
elif torch.is_tensor(value) and value.numel() == 1:
yield key, float(value.item())
def _select_metric_device(self) -> torch.device:
    """Pick the device on which distributed metric tensors are built.

    Starts from ``ezpz.get_torch_device(as_torch_device=True)`` and falls
    back to CPU for ``mps``, and for ``cuda`` / ``xpu`` when the matching
    torch backend reports itself unavailable (or, for xpu, is absent).
    """
    chosen = ezpz.get_torch_device(as_torch_device=True)
    if not isinstance(chosen, torch.device):
        chosen = torch.device(str(chosen))
    backend_name = chosen.type
    if backend_name == "mps":
        return torch.device("cpu")
    if backend_name == "cuda" and not torch.cuda.is_available():
        return torch.device("cpu")
    if backend_name == "xpu":
        xpu_backend = getattr(torch, "xpu", None)
        if not (xpu_backend and xpu_backend.is_available()):
            return torch.device("cpu")
    return chosen
def _compute_distributed_metrics(
    self, metrics: dict[str, Any]
) -> dict[str, float]:
    """Reduce scalar metrics across ranks into mean/max/min/std.

    Only the entries yielded by ``_iter_scalar_metrics`` take part.  The
    standard deviation is the population value
    ``sqrt(max(E[X^2] - E[X]^2, 0))`` computed from all-reduced sums.

    Args:
        metrics: Mapping of metric name to raw value.

    Returns:
        ``{f"{key}/{stat}": float}`` for ``stat`` in ``mean``, ``max``,
        ``min`` and ``std``.  An empty dict is returned when distributed
        history is disabled, ``self._dist`` is ``None``, the distributed
        backend is unavailable or uninitialised, or ``metrics`` holds no
        scalar entries.  With a world size of one, the local value is
        returned for mean/max/min and ``std`` is ``0.0``.
    """
    if not self.distributed_history or self._dist is None:
        return {}
    try:
        if (
            not self._dist.is_available()
            or not torch.distributed.is_initialized()  # type: ignore[attr-defined]
        ):
            return {}
    except AttributeError:
        return {}
    scalars = dict(self._iter_scalar_metrics(metrics))
    if not scalars:
        return {}
    metric_device = self._select_metric_device()
    dtype = torch.get_default_dtype()
    # One tensor slot per scalar metric, in the iteration order of `scalars`.
    values = torch.tensor(
        list(scalars.values()),
        dtype=dtype,
        device=metric_device,
    )
    # Promote to fp64 *before* squaring so the variance arithmetic
    # `E[X^2] - E[X]^2` doesn't lose its signal to fp32 cancellation.
    # For typical training metrics (`tokens_per_sec ≈ 1e5`,
    # `tflops ≈ 40`, etc.) the squared values are 8-10 orders of
    # magnitude larger than the across-rank variance, and the
    # subtraction collapses to exactly 0.0 in fp32 — `tflops/std=0.0`
    # is then indistinguishable from "all ranks identical" and the
    # console summary drops the `(±std)` parenthetical entirely.
    # fp64 has ~16 significant digits of headroom, which covers any
    # realistic training metric.  Cost: two casts + ~2x bandwidth on both
    # the sum and squared-sum all-reduces (max/min stay in the original
    # dtype since they don't suffer cancellation).  Negligible vs. the
    # metric collection itself.
    #
    # Some accelerator backends (notably some Intel XPU devices) don't
    # support fp64 natively.  We probe the cast + a trivial fp64 op
    # BEFORE starting any collectives — if it fails, we commit to the
    # original dtype for the whole reduce so we don't leave half the
    # ranks in fp64 and half in fp32 (which would deadlock the
    # all_reduce).  `copy=True` on the cast also defends against aliasing
    # when `values` is already fp64 (otherwise the in-place all_reduce
    # below would mutate it).
    try:
        sum_vals = values.to(torch.float64, copy=True)
        _ = sum_vals.square().sum().item()  # probe fp64 arithmetic
        sq_vals = sum_vals.square()
    except RuntimeError as exc:
        logger.warning(
            "fp64 promotion failed on %s (%s); falling back to "
            "%s for distributed std. `/std` for large-magnitude "
            "metrics may collapse to 0.0 due to fp32 cancellation.",
            metric_device,
            exc,
            dtype,
        )
        sum_vals = values.clone()
        sq_vals = values.square()
    max_vals = values.clone()
    min_vals = values.clone()
    # Single rank: nothing to reduce, so report the local value directly.
    if (world_size := ezpz.distributed.get_world_size()) <= 1:
        return {
            f"{key}/{suffix}": (value if suffix != "std" else 0.0)
            for key, value in scalars.items()
            for suffix in ("mean", "max", "min", "std")
        }
    ops = torch.distributed.ReduceOp  # type: ignore[attr-defined]
    # In-place collectives: every rank has to issue the same four calls
    # in the same order.
    torch.distributed.all_reduce(sum_vals, op=ops.SUM)
    torch.distributed.all_reduce(sq_vals, op=ops.SUM)
    torch.distributed.all_reduce(max_vals, op=ops.MAX)
    torch.distributed.all_reduce(min_vals, op=ops.MIN)
    mean_vals = sum_vals.div(world_size)
    var_vals = sq_vals.div(world_size).sub(mean_vals.square())
    # Rounding can push the variance slightly below zero; clamp before sqrt.
    std_vals = var_vals.clamp_min_(0.0).sqrt_()
    stats: dict[str, float] = {}
    for idx, key in enumerate(scalars):
        stats[f"{key}/mean"] = float(mean_vals[idx].item())
        stats[f"{key}/max"] = float(max_vals[idx].item())
        stats[f"{key}/min"] = float(min_vals[idx].item())
        stats[f"{key}/std"] = float(std_vals[idx].item())
    return stats
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def _update(
    self,
    key: str,
    val: Union[Any, ScalarLike, list, torch.Tensor, np.ndarray],
    *,
    prefix: str = "",
):
    """Append one value to the series stored under ``prefix`` / ``key``.

    The group for ``prefix`` and the series for ``key`` are created on
    first use, and the flattened-history cache is invalidated afterwards.

    Args:
        key (str): Metric name inside the group (without the prefix).
        val: Value to append.
        prefix (str): Group name, e.g. ``"train"``, ``"eval"`` or ``""``.

    Returns:
        ``val``, unchanged.
    """
    series_by_key = self._groups.setdefault(prefix, {})
    series_by_key.setdefault(key, []).append(val)
    self._invalidate_flat_cache()
    return val
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def update(
    self,
    metrics: dict,
    precision: int = 6,
    use_wandb: Optional[bool] = None,
    commit: Optional[bool] = True,
    summarize: Optional[bool] = True,
    step: Optional[int] = None,
) -> str:
    """Record a dictionary of metrics and forward it to the loggers.

    The metrics are appended to the in-memory history, reduced across
    ranks (when distributed history is active), sent to the tracker and
    written to the JSONL log.  On rank 0 the cross-rank aggregates are
    stored and logged alongside the raw values.

    Args:
        metrics (dict): Metric name to value.
        precision (int): Precision passed to the summary formatter.
        use_wandb (Optional[bool]): Deprecated and otherwise ignored; any
            non-``None`` value only triggers a ``DeprecationWarning``.
        commit (Optional[bool]): Forwarded to the tracker's ``log`` call.
        summarize (Optional[bool]): Whether to build the summary string.
        step (Optional[int]): Forwarded to the tracker's ``log`` call.

    Returns:
        The compact one-line summary, or ``""`` when ``summarize`` is
        falsy.
    """
    prefix, stripped = self._split_prefix(metrics)
    group = self._groups.setdefault(prefix, {})
    for key, val in stripped.items():
        group.setdefault(key, []).append(val)
    self._invalidate_flat_cache()
    # This performs all-reduce collectives, so it runs on every rank even
    # though only rank 0 keeps the result.
    aggregated_metrics = self._compute_distributed_metrics(metrics)
    log_aggregates = bool(aggregated_metrics) and self._rank == 0
    if log_aggregates:
        # Aggregated keys look like "train/loss/mean" — strip the same
        # prefix so they land in the same group as the raw metrics.
        group_prefix = f"{prefix}/" if prefix else ""
        for agg_key, agg_val in aggregated_metrics.items():
            if group_prefix and agg_key.startswith(group_prefix):
                short_agg_key = agg_key[len(group_prefix):]
            else:
                short_agg_key = agg_key
            self._update(short_agg_key, agg_val, prefix=prefix)
    metrics_for_logging = dict(metrics)
    if log_aggregates:
        metrics_for_logging.update(aggregated_metrics)
    # Sanitising converts tensors with `.item()` / `.cpu()`; do it exactly
    # once and reuse the result for tracker, JSONL and summary.  Without
    # aggregates this is the same content a second pass over `metrics`
    # would produce.
    sanitized_metrics = self._sanitize_metrics(metrics_for_logging)
    if use_wandb is not None:
        warnings.warn(
            "The 'use_wandb' parameter is deprecated. Use "
            "backends='wandb' in the History constructor instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    self._tracker.log(sanitized_metrics, step=step, commit=commit)
    self._write_jsonl_entry(sanitized_metrics, aggregated_metrics)
    if not summarize:
        return ""
    from ezpz.utils import (
        format_compact_summary,
        format_memory_summary,
    )

    # format_compact_summary handles all the noise reduction:
    #   - collapses base + */std into `key=value(±std)`
    #   - drops the */mean /min /max /avg companions
    #   - strips memory keys (handled separately below)
    #   - leaves counter-like keys (iter/step/epoch/...) bare
    base = format_compact_summary(sanitized_metrics, precision=precision)
    # The memory string is built from the RAW metrics dict; it is empty
    # when there are no memory keys (e.g. on CPU/MPS).
    #
    # `prefix` came from `_split_prefix(metrics)` and is the bare
    # namespace WITHOUT a trailing slash (e.g. "train"), whereas
    # `format_memory_summary` expects either a full prefix WITH slash
    # ("train/") or None for auto-detection.  Passing "train" directly
    # would make the lookup miss ("trainmem_alloc" matches no key) and
    # silently drop the memory= token, so let the helper infer it.
    memory_str = format_memory_summary(metrics, prefix=None)
    parts = [
        p
        for p in (base, f"memory={memory_str}" if memory_str else "")
        if p
    ]
    return " ".join(parts)
@staticmethod
def split_metrics_for_logging(
    metrics: dict[str, Any],
    debug_prefixes: tuple[str, ...] = ("hist/",),
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Partition ``metrics`` into info-level and debug-level dicts.

    Args:
        metrics: Metric name to value.
        debug_prefixes: Keys that start with any of these prefixes are
            routed to the debug dict.

    Returns:
        ``(info_metrics, debug_metrics)``; each keeps the relative order
        of the input.
    """
    info_bucket: dict[str, Any] = {}
    debug_bucket: dict[str, Any] = {}
    for name, payload in metrics.items():
        bucket = debug_bucket if name.startswith(debug_prefixes) else info_bucket
        bucket[name] = payload
    return info_bucket, debug_bucket
@staticmethod
def summarize_min_max_std(
    metrics: dict[str, Any],
) -> dict[str, float]:
    """Compute local mean/min/max/std for each scalar metric.

    Each metric contributes a single sample, so mean, min and max all
    equal the value itself and the population std is ``0.0`` (``NaN``
    for a non-finite value).  Values are reported exactly; they are not
    round-tripped through a default-dtype tensor.

    Scalars are Python ``int``/``float``, zero-dimensional NumPy arrays
    and single-element tensors — the same set ``_iter_scalar_metrics``
    accepts.  Other entries are ignored.

    Args:
        metrics: Metric name to value.

    Returns:
        ``{f"{key}/{stat}": float}`` with ``stat`` in ``mean``, ``min``,
        ``max``, ``std`` (in that order per key).
    """
    import math

    summary: dict[str, float] = {}
    for key, value in metrics.items():
        if isinstance(value, (int, float)):
            scalar = float(value)
        elif isinstance(value, np.ndarray) and value.shape == ():
            scalar = float(value.item())
        elif torch.is_tensor(value) and value.numel() == 1:
            scalar = float(value.item())
        else:
            continue
        summary[f"{key}/mean"] = scalar
        summary[f"{key}/min"] = scalar
        summary[f"{key}/max"] = scalar
        # The deviation of one finite sample is zero; inf/NaN have no
        # meaningful spread.
        summary[f"{key}/std"] = 0.0 if math.isfinite(scalar) else math.nan
    return summary
def summarize_distributed_min_max_std(
    self, metrics: dict[str, Any]
) -> dict[str, float]:
    """Return mean/min/max/std per metric, dropping uninformative entries.

    Uses ``_compute_distributed_metrics`` (an all-reduce across ranks) and
    falls back to the local ``summarize_min_max_std`` when that yields
    nothing.  A metric is dropped entirely when its ``min`` and ``max``
    are both present and equal — this covers all-zero metrics as well as
    any other metric that is identical on every rank.

    Args:
        metrics: Metric name to value.

    Returns:
        ``{f"{base}/{stat}": float}`` for the surviving metrics, ordered
        by first appearance of each base name, with stats in
        ``mean``, ``min``, ``max``, ``std`` order.
    """
    summary_stats = self._compute_distributed_metrics(metrics)
    if not summary_stats:
        summary_stats = self.summarize_min_max_std(metrics)
    stat_names = ("mean", "min", "max", "std")
    stat_suffixes = tuple(f"/{name}" for name in stat_names)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # output order is reproducible (a set would iterate in hash order).
    bases = dict.fromkeys(
        key.rsplit("/", 1)[0]
        for key in summary_stats
        if key.endswith(stat_suffixes)
    )
    pruned: dict[str, float] = {}
    for base in bases:
        min_v = summary_stats.get(f"{base}/min")
        max_v = summary_stats.get(f"{base}/max")
        # Zero spread (min == max) carries no information; this also
        # catches the all-zero case.
        if min_v is not None and max_v is not None and min_v == max_v:
            continue
        for name in stat_names:
            stat = summary_stats.get(f"{base}/{name}")
            if stat is not None:
                pruned[f"{base}/{name}"] = stat
    return pruned
def log_metrics(
    self,
    metrics: dict[str, Any],
    *,
    logger: Optional[Any] = None,
    debug_prefixes: tuple[str, ...] = ("hist/",),
    include_summary: bool = True,
    rank0_only_summary: bool = True,
    precision: int = 6,
    omit_counter_metrics: bool = True,
    counter_tokens: tuple[str, ...] = (
        "iter",
        "epoch",
        "step",
        "batch",
        "idx",
        "bidx",
    ),
) -> None:
    """Emit one info line (and one debug line) describing ``metrics``.

    Keys matching ``debug_prefixes`` are logged at debug level; the rest
    are formatted with ``format_compact_summary`` and logged at info
    level together with an optional ``memory=...`` token.

    Args:
        metrics: Dict of metric name to scalar value.
        logger: Logger to write to; defaults to ``get_logger(__name__)``.
        debug_prefixes: Key prefixes routed to the debug line.
        include_summary: If True, compute the distributed min/max/std
            summary and merge it into the info line.
        rank0_only_summary: If True, only rank 0 merges the summary.
        precision: Precision passed to the formatters.
        omit_counter_metrics: If True, counter keys are excluded from the
            summary statistics.
        counter_tokens: A key counts as a counter when its last
            ``/``-separated component equals a token or ends with
            ``_<token>``.
    """
    log = logger if logger is not None else get_logger(__name__)
    info_metrics, debug_metrics = self.split_metrics_for_logging(
        metrics, debug_prefixes=debug_prefixes
    )

    def _is_counter_key(key: str) -> bool:
        leaf = key.replace("\\", "/").rsplit("/", 1)[-1]
        return any(
            leaf == token or leaf.endswith(f"_{token}")
            for token in counter_tokens
        )

    # The distributed min/max/std stats are merged INTO the base info dict
    # so format_compact_summary can collapse them into `key=value(±std)`
    # instead of emitting a second verbose line.
    merged_for_summary: dict[str, Any] = dict(info_metrics)
    if include_summary:
        if omit_counter_metrics:
            summary_input = {
                name: value
                for name, value in info_metrics.items()
                if not _is_counter_key(name)
            }
        else:
            summary_input = info_metrics
        # Called on every rank; only the merge below is rank-gated.
        summary_stats = self.summarize_distributed_min_max_std(summary_input)
        if summary_stats and (not rank0_only_summary or self._rank == 0):
            merged_for_summary.update(summary_stats)
    from ezpz.utils import (
        format_compact_summary,
        format_memory_summary,
    )

    # format_compact_summary handles the noise reduction:
    #   - collapses base + */std into `key=value(±std)`
    #   - drops */mean /min /max /avg companions
    #   - strips memory keys (formatted separately below)
    #   - leaves counter keys (iter/step/epoch/...) bare
    compact = format_compact_summary(merged_for_summary, precision=precision)
    base = compact.replace("train/", "")
    # With no explicit prefix, format_memory_summary auto-detects
    # "train/" / "eval/" / "" from the keys.
    memory_str = format_memory_summary(info_metrics)
    segments = [base] if base else []
    if memory_str:
        segments.append(f"memory={memory_str}")
    info_msg = " ".join(segments)
    if info_msg:
        log.info(info_msg)
    debug_summary = summarize_dict(debug_metrics, precision=precision)
    debug_msg = debug_summary.replace("train/", "")
    if debug_msg:
        log.debug(debug_msg)
def _tplot(
    self,
    y: np.ndarray,
    x: Optional[np.ndarray] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    append: bool = True,
    title: Optional[str] = None,
    verbose: bool = False,
    outfile: Optional[str] = None,
    logfreq: Optional[int] = None,
    plot_type: Optional[str] = None,
    marker: Optional[str] = None,
    record_report: bool = True,
    quiet: bool = False,
):
    """Create a text plot of ``y`` via ``eztplot`` and optionally record it.

    When ``outfile`` is not given and reports are enabled, a timestamped
    file under ``<report_dir>/tplot`` is used.  Nothing is plotted when
    ``xlabel`` is set and equals ``ylabel``.  If ``ylabel`` contains
    ``"dt"``, an additional histogram plot is written next to the main
    output with a ``-hist`` suffix.

    Args:
        y (np.ndarray): The data to plot.
        x (Optional[np.ndarray]): The x-axis data; defaults to
            ``np.arange(len(y))``.
        xlabel (Optional[str]): The x-axis label.
        ylabel (Optional[str]): The y-axis label.
        append (bool): Forwarded to ``eztplot``.
        title (Optional[str]): The title of the plot.
        verbose (bool): Forwarded to ``eztplot``.
        outfile (Optional[str]): The path to save the plot to.
        logfreq (Optional[int]): Forwarded to ``eztplot`` (``1`` if None).
        plot_type (Optional[str]): The type of plot to create.
        marker (Optional[str]): Forwarded to ``eztplot``.
        record_report (bool): Whether to register written plots with
            ``_write_plot_report`` (only when reports are enabled).
        quiet (bool): Forwarded to ``eztplot``.
    """
    outfile_path: Optional[Path] = None
    if outfile is None and self.report_enabled:
        # "/" in a label would otherwise create nested directories.
        label = (ylabel or xlabel or "metric").replace("/", "_")
        default_dir = self._report_dir.joinpath("tplot")
        default_dir.mkdir(parents=True, exist_ok=True)
        outfile_path = default_dir.joinpath(
            f"{label}-{get_timestamp()}.txt"
        )
        outfile = outfile_path.as_posix()
    elif outfile is not None:
        outfile_path = Path(outfile)
    # Plotting a variable against itself is pointless.  Note the default
    # report directory above has already been created at this point.
    if xlabel is not None and ylabel == xlabel:
        return
    if len(y) > 1:
        x = x if x is not None else np.arange(len(y))
        assert x is not None
        eztplot(
            y=y,
            x=x,
            xlabel=xlabel,
            ylabel=ylabel,
            logfreq=(1 if logfreq is None else logfreq),
            append=append,
            verbose=verbose,
            outfile=outfile,
            plot_type=plot_type,
            marker=marker,
            title=title,
            quiet=quiet,
        )
        # NOTE(review): the flattened source does not show whether this
        # report block sits inside `if len(y) > 1`; it is placed here so a
        # report is only written for a plot that was produced — confirm.
        if (
            record_report
            and self.report_enabled
            and outfile_path is not None
        ):
            self._write_plot_report(
                ylabel,
                outfile_path,
                kind="tplot",
                metadata={"points": len(y)},
            )
    if ylabel is not None and "dt" in ylabel:
        # Histogram companion: "<stem>-hist<suffix>" next to the main file.
        of = Path(outfile) if outfile is not None else None
        if of is not None:
            of = Path(of.parent).joinpath(f"{of.stem}-hist{of.suffix}")
        eztplot(
            y=y,
            xlabel=ylabel,
            title=title,
            ylabel="freq",
            append=append,
            verbose=verbose,
            outfile=(of if of is not None else None),
            plot_type="hist",
            marker=marker,
            quiet=quiet,
        )
        if record_report and self.report_enabled and of is not None:
            self._write_plot_report(
                f"{ylabel}-hist",
                of,
                kind="tplot-hist",
                metadata={"points": len(y)},
            )
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot(
    self,
    val: np.ndarray,
    key: Optional[str] = None,
    warmup: Optional[float] = 0.0,
    num_chains: Optional[int] = 128,
    title: Optional[str] = None,
    outdir: Optional[os.PathLike] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
    verbose: bool = False,
):
    """Plot a single variable from the history with matplotlib.

    ``val`` is converted with ``np.array`` and handled according to its
    number of dimensions (1, 2 or 3).  Figures are saved as PNG and SVG
    under ``outdir`` or, when reports are enabled, under
    ``<report_dir>/mplot``.

    Args:
        val (np.ndarray): The data to plot.
        key (Optional[str]): Name of the data; used for labels and file
            names.
        warmup (Optional[float]): Leading samples to drop.  An ``int`` or
            a value ``>= 1`` is taken as a sample count, a value in
            ``(0, 1)`` as a fraction of the first axis.
        num_chains (Optional[int]): Maximum number of individual chains to
            draw (``16`` when ``None``).
        title (Optional[str]): The title of the plot.
        outdir (Optional[os.PathLike]): The directory to save the plot to.
        subplots_kwargs (Optional[dict[str, Any]]): Additional arguments
            for subplots.  The dict is updated in place with ``figsize``.
        plot_kwargs (Optional[dict[str, Any]]): Additional arguments for
            plotting.  The 3-D branch writes ``color`` into this dict and
            pops ``label`` from it.
        verbose (bool): Emit additional logging when saving plots.

    Returns:
        ``(fig, subfigs, axes)``; ``subfigs`` is ``None`` unless the data
        is two-dimensional.

    Raises:
        ValueError: If the data is not 1-, 2- or 3-dimensional.
    """
    import matplotlib.pyplot as plt

    LW = plt.rcParams.get("axes.linewidth", 1.75)
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
    figsize = subplots_kwargs.get("figsize", ezplot.set_size())
    subplots_kwargs.update({"figsize": figsize})
    num_chains = 16 if num_chains is None else num_chains
    arr = np.array(val)
    subfigs = None
    steps = np.arange(arr.shape[0])
    if warmup is not None and warmup > 0 and arr.size > 0:
        # Normalise both spellings (count / fraction) to a fraction first.
        if isinstance(warmup, int) or warmup >= 1:
            warmup_frac = float(warmup) / float(arr.shape[0])
        else:
            warmup_frac = float(warmup)
        warmup_frac = min(max(warmup_frac, 0.0), 1.0)
        drop = min(int(round(warmup_frac * arr.shape[0])), arr.shape[0])
        if drop > 0:
            arr = arr[drop:]
            steps = steps[drop:]
    if len(arr.shape) == 2:
        import seaborn as sns

        _ = subplots_kwargs.pop("constrained_layout", True)
        figsize = (3 * figsize[0], 1.5 * figsize[1])
        fig = plt.figure(figsize=figsize, constrained_layout=True)
        subfigs = fig.subfigures(1, 2)
        gs_kw = {"width_ratios": [1.33, 0.33]}
        # Only the right-hand subfigure receives axes here.
        (ax, ax1) = subfigs[1].subplots(
            1, 2, sharey=True, gridspec_kw=gs_kw
        )
        ax.grid(alpha=0.2)
        ax1.grid(False)
        color = plot_kwargs.get("color", None)
        label = r"$\langle$" + f" {key} " + r"$\rangle$"
        # NOTE(review): `label` is passed explicitly *and* plot_kwargs is
        # splatted; a caller-supplied "label" entry would raise TypeError.
        ax.plot(
            steps, arr.mean(-1), lw=1.5 * LW, label=label, **plot_kwargs
        )
        # NOTE(review): seaborn documents `shade` as a deprecated alias of
        # `fill` — consider switching.
        sns.kdeplot(y=arr.flatten(), ax=ax1, color=color, shade=True)
        ax1.set_xticks([])
        ax1.set_xticklabels([])
        sns.despine(ax=ax, top=True, right=True)
        sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
        ax1.set_xlabel("")
        axes = (ax, ax1)
    else:
        if len(arr.shape) == 1:
            fig, ax = plt.subplots(**subplots_kwargs)
            ax.plot(steps, arr, **plot_kwargs)
            axes = ax
        elif len(arr.shape) == 3:
            fig, ax = plt.subplots(**subplots_kwargs)
            cmap = plt.get_cmap("viridis")
            nlf = arr.shape[1]
            for idx in range(nlf):
                # NOTE(review): "label" is popped on every iteration, so
                # only the first slice can ever receive one.
                label = plot_kwargs.pop("label", None)
                if label is not None:
                    label = f"{label}-{idx}"
                y = arr[:, idx, :]
                color = cmap(idx / y.shape[1])
                plot_kwargs["color"] = cmap(idx / y.shape[1])
                if len(y.shape) == 2:
                    # TODO: Plot chains
                    if num_chains > 0:
                        # NOTE(review): this inner loop reuses the name
                        # `idx` of the enclosing loop.
                        for idx in range(min((num_chains, y.shape[1]))):
                            _ = ax.plot(
                                steps,
                                y[:, idx],
                                lw=LW / 2.0,
                                alpha=0.8,
                                **plot_kwargs,
                            )
                    _ = ax.plot(
                        steps,
                        y.mean(-1),
                        label=label,
                        **plot_kwargs,
                    )
                else:
                    _ = ax.plot(
                        steps,
                        y,
                        label=label,
                        **plot_kwargs,
                    )
            axes = ax
        else:
            raise ValueError("Unexpected shape encountered")
        # NOTE(review): nesting reconstructed from flattened source; this
        # call may belong at function level instead — confirm.
        ax.set_ylabel(key)
    if num_chains > 0 and len(arr.shape) > 1:
        for idx in range(min(num_chains, arr.shape[1])):
            # Overlay individual chains, arr[:, idx], as thin lines.
            ax.plot(
                steps, arr[:, idx], alpha=0.5, lw=LW / 2.0, **plot_kwargs
            )
    ax.set_xlabel("step")
    if title is not None:
        fig.suptitle(title)
    save_dir: Optional[Path]
    if outdir is not None:
        save_dir = Path(outdir).expanduser().resolve()
    elif self.report_enabled:
        save_dir = self._report_dir.joinpath("mplot")
    else:
        save_dir = None
    if save_dir is not None:
        save_dir.mkdir(parents=True, exist_ok=True)
        outfile = save_dir.joinpath(f"{key}.svg")
        # Timestamped copies are written only when "<key>.svg" already
        # exists directly inside save_dir.
        if outfile.is_file():
            tstamp = ezpz.get_timestamp()
            pngdir = save_dir.joinpath("pngs")
            pngdir.mkdir(exist_ok=True, parents=True)
            pngfile = pngdir.joinpath(f"{key}-{tstamp}.png")
            svgfile = save_dir.joinpath(f"{key}-{tstamp}.svg")
            plt.savefig(pngfile, dpi=400, bbox_inches="tight")
            plt.savefig(svgfile, dpi=400, bbox_inches="tight")
    primary_asset: Optional[Path] = None
    if save_dir is not None:
        dirs = {
            "png": Path(save_dir).joinpath("pngs/"),
            "svg": Path(save_dir).joinpath("svgs/"),
        }
        _ = [i.mkdir(exist_ok=True, parents=True) for i in dirs.values()]
        for ext, d in dirs.items():
            outfile = d.joinpath(f"{key}.{ext}")
            # Avoid clobbering an existing file by switching to a
            # "-subfig" name (which itself may be overwritten).
            if outfile.is_file():
                outfile = d.joinpath(f"{key}-subfig.{ext}")
            if verbose:
                logger.info(f"Saving {key} plot to: {outfile.resolve()}")
            plt.savefig(outfile, dpi=400, bbox_inches="tight")
            # The PNG is the asset referenced by the report and tracker.
            if primary_asset is None and ext == "png":
                primary_asset = outfile
    if (
        self.report_enabled
        and primary_asset is not None
        and Path(primary_asset).exists()
    ):
        self._write_plot_report(
            key,
            primary_asset,
            kind="matplotlib",
            metadata={"shape": list(arr.shape)},
        )
    # Called even when nothing was saved (primary_asset is None then).
    self._wandb_log_matplotlib_asset(key, primary_asset, kind="matplotlib")
    return fig, subfigs, axes
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot_dataArray(
    self,
    val: xr.DataArray,
    key: Optional[str] = None,
    warmup: Optional[float] = 0.0,
    num_chains: Optional[int] = 0,
    title: Optional[str] = None,
    outdir: Optional[str] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
    verbose: bool = False,
    line_labels: bool = False,
    logfreq: Optional[int] = None,
):
    """Plot a single variable from the history given as an xarray DataArray.

    Two-dimensional data is delegated to ``ezplot.plot_combined``; one-
    and three-dimensional data is drawn directly.  Figures are saved as
    PNG and SVG under ``outdir`` or, when reports are enabled, under
    ``<report_dir>/dataarray``.

    Args:
        val (xr.DataArray): The data to plot; must have a ``draw``
            coordinate.
        key (Optional[str]): Name of the data; required (asserted) for
            1-D and 3-D input.
        warmup (Optional[float]): Leading samples to drop.  An ``int`` or
            a value ``>= 1`` is taken as a count, a value in ``(0, 1)`` as
            a fraction of the first axis.
        num_chains (Optional[int]): Forwarded to ``ezplot.plot_combined``.
        title (Optional[str]): The title of the plot.
        outdir (Optional[str]): The directory to save the plot to.
        subplots_kwargs (Optional[dict[str, Any]]): Additional arguments
            for subplots.  The dict is updated in place with ``figsize``.
        plot_kwargs (Optional[dict[str, Any]]): Additional arguments for
            plotting.
        verbose (bool): Log the save directory.
        line_labels (bool): Not referenced by the body.
        logfreq (Optional[int]): If given, x tick labels are rewritten as
            ``logfreq * tick``.

    Returns:
        ``(fig, subfigs, axes)``; ``subfigs`` is always ``None`` here.

    Raises:
        ValueError: If the data is not 1-, 2- or 3-dimensional.
    """
    import matplotlib.pyplot as plt

    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
    # Both calls below change global matplotlib state.
    ezplot.set_plot_style()
    plt.rcParams["axes.labelcolor"] = "#bdbdbd"
    figsize = subplots_kwargs.get("figsize", ezplot.set_size())
    subplots_kwargs.update({"figsize": figsize})
    subfigs = None
    arr = val.values  # shape: [nchains, ndraws]
    steps = val.coords["draw"]
    if warmup is not None and warmup > 0.0 and arr.size > 0:
        if isinstance(warmup, int) or warmup >= 1:
            warmup_frac = float(warmup) / float(arr.shape[0])
        else:
            warmup_frac = float(warmup)
        warmup_frac = min(max(warmup_frac, 0.0), 1.0)
        drop = min(int(round(warmup_frac * arr.shape[0])), arr.shape[0])
        # NOTE(review): only `arr` and `steps` are trimmed; the 2-D and
        # 3-D branches below read from `val`, which is left untouched.
        if drop > 0:
            arr = arr[drop:]
            steps = steps[drop:]
    if len(arr.shape) == 2:
        fig, axes = ezplot.plot_combined(
            val,
            key=key,
            num_chains=num_chains,
            plot_kwargs=plot_kwargs,
            subplots_kwargs=subplots_kwargs,
        )
    else:
        if len(arr.shape) == 1:
            fig, ax = ezplot.subplots(**subplots_kwargs)
            try:
                ax.plot(steps, arr, **plot_kwargs)
            except ValueError:
                # Retry with NaN entries removed; give up quietly if that
                # fails as well.
                try:
                    ax.plot(steps, arr[~np.isnan(arr)], **plot_kwargs)
                except Exception:
                    logger.error(f"Unable to plot {key}! Continuing")
            _ = ax.grid(True, alpha=0.2)
            axes = ax
        elif len(arr.shape) == 3:
            fig, ax = ezplot.subplots(**subplots_kwargs)
            cmap = plt.get_cmap("viridis")
            y = val.mean("chain")
            # One line per leapfrog index, coloured along the viridis map.
            for idx in range(len(val.coords["leapfrog"])):
                pkwargs = {
                    "color": cmap(idx / len(val.coords["leapfrog"])),
                    "label": f"{idx}",
                }
                ax.plot(steps, y[idx], **pkwargs)
            axes = ax
        else:
            raise ValueError("Unexpected shape encountered")
        # NOTE(review): nesting reconstructed from flattened source; these
        # labelling calls may belong at function level — confirm.
        ax = plt.gca()
        assert key is not None
        _ = ax.set_ylabel(key)
        _ = ax.set_xlabel("step")
    if title is not None:
        fig = plt.gcf()
        _ = fig.suptitle(title)
    if logfreq is not None:
        ax = plt.gca()
        xticks = ax.get_xticks()  # type: ignore
        _ = ax.set_xticklabels(  # type: ignore
            [f"{logfreq * int(i)}" for i in xticks]  # type: ignore
        )
    save_dir: Optional[Path]
    if outdir is not None:
        save_dir = Path(outdir).expanduser().resolve()
    elif self.report_enabled:
        save_dir = self._report_dir.joinpath("dataarray")
    else:
        save_dir = None
    primary_asset: Optional[Path] = None
    if save_dir is not None:
        dirs = {
            "png": Path(save_dir).joinpath("pngs/"),
            "svg": Path(save_dir).joinpath("svgs/"),
        }
        _ = [i.mkdir(exist_ok=True, parents=True) for i in dirs.values()]
        if verbose:
            logger.info(
                f"Saving {key} plot to: {Path(save_dir).resolve()}"
            )
        for ext, d in dirs.items():
            outfile = d.joinpath(f"{key}.{ext}")
            plt.savefig(outfile, dpi=400, bbox_inches="tight")
            # The PNG is the asset referenced by the report and tracker.
            if primary_asset is None and ext == "png":
                primary_asset = outfile
    if (
        self.report_enabled
        and primary_asset is not None
        and Path(primary_asset).exists()
    ):
        metadata = {"dims": list(val.dims)}
        self._write_plot_report(
            key,
            primary_asset,
            kind="dataarray",
            metadata=metadata,
        )
    # Called even when nothing was saved (primary_asset is None then).
    self._wandb_log_matplotlib_asset(key, primary_asset, kind="dataarray")
    return (fig, subfigs, axes)
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot_dataset(
    self,
    title: Optional[str] = None,
    nchains: Optional[int] = None,
    outdir: Optional[os.PathLike] = None,
    dataset: Optional[xr.Dataset] = None,
    data: Optional[dict] = None,
    warmup: Optional[int | float] = None,
):
    """Plot a full xarray Dataset via ``ezplot.plot_dataset``.

    Args:
        title: Forwarded to ``ezplot.plot_dataset``.
        nchains: Forwarded to ``ezplot.plot_dataset``.
        outdir: Forwarded to ``ezplot.plot_dataset``.
        dataset: Dataset to plot.  When ``None`` it is built with
            ``self.get_dataset`` from ``data`` (or ``self.history`` if
            ``data`` is ``None``).
        data: Source mapping used only when ``dataset`` is ``None``.
        warmup: Passed to ``self.get_dataset`` when the dataset is built.

    Returns:
        Whatever ``ezplot.plot_dataset`` returns.
    """
    if dataset is None:
        source = self.history if data is None else data
        dataset = self.get_dataset(data=source, warmup=warmup)
    return ezplot.plot_dataset(
        dataset=dataset,
        nchains=nchains,
        title=title,
        outdir=outdir,
    )
def plot_2d_xarr(
    self,
    xarr: xr.DataArray,
    label: Optional[str] = None,
    num_chains: Optional[int] = None,
    title: Optional[str] = None,
    outdir: Optional[os.PathLike] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
):
    """Plot a 2D xarray DataArray (chain x draw) with matplotlib/seaborn.

    The left sub-figure shows ``xarr.plot()``; the right one shows the
    chain mean, the individual chains and a KDE of all values.

    Args:
        xarr: Two-dimensional array with ``chain`` and ``draw``
            coordinates.
        label: Name of the quantity.  It is shown as
            ``$\\langle$ label $\\rangle$`` in the figure and used
            undecorated in output file names.
        num_chains: Number of individual chains to draw; defaults to all
            chains and is capped at the number available.
        title: Optional figure title.
        outdir: Directory for saved figures.
        subplots_kwargs: Accepted for signature compatibility; not used.
        plot_kwargs: Extra keyword arguments for the line plots.  Entries
            that this method sets itself (``color``, ``lw``, ``label``,
            ``alpha``) are not forwarded a second time; ``color`` is still
            honoured as the colour of every line.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    LW = plt.rcParams.get("axes.linewidth", 1.75)
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    assert len(xarr.shape) == 2
    assert "draw" in xarr.coords and "chain" in xarr.coords
    available_chains = len(xarr.chain)
    num_chains = (
        available_chains
        if num_chains is None
        else min(num_chains, available_chains)
    )
    figsize = plt.rcParams.get("figure.figsize", (8, 6))
    figsize = (3 * figsize[0], 1.5 * figsize[1])
    fig = plt.figure(figsize=figsize, constrained_layout=True)
    subfigs = fig.subfigures(1, 2)
    gs_kw = {"width_ratios": [1.33, 0.33]}
    (ax, ax1) = subfigs[1].subplots(1, 2, sharey=True, gridspec_kw=gs_kw)
    ax.grid(alpha=0.2)
    ax1.grid(False)
    color = plot_kwargs.get("color", f"C{np.random.randint(6)}")
    # Keyword arguments passed explicitly below must not also arrive via
    # **plot_kwargs, otherwise matplotlib's plot() call raises TypeError
    # ("got multiple values for keyword argument").
    mean_kwargs = {
        k: v
        for k, v in plot_kwargs.items()
        if k not in ("color", "lw", "label")
    }
    chain_kwargs = {
        k: v
        for k, v in plot_kwargs.items()
        if k not in ("color", "alpha", "lw")
    }
    # Keep the caller's `label` intact for file names; only the on-figure
    # text gets the LaTeX angle brackets.
    display_label = r"$\langle$" + f" {label} " + r"$\rangle$"
    ax.plot(
        xarr.draw.values,
        xarr.mean("chain"),
        color=color,
        lw=1.5 * LW,
        label=display_label,
        **mean_kwargs,
    )
    for idx in range(num_chains):
        # Individual chain whose `chain` coordinate equals idx.
        ax.plot(
            xarr.draw.values,
            xarr[xarr.chain == idx][0],
            color=color,
            alpha=0.5,
            lw=LW / 2.0,
            **chain_kwargs,
        )
    # `fill` is the current spelling of the deprecated `shade` argument.
    sns.kdeplot(y=xarr.values.flatten(), ax=ax1, color=color, fill=True)
    ax1.set_xticks([])
    ax1.set_xticklabels([])
    sns.despine(ax=ax, top=True, right=True)
    sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
    ax1.set_xlabel("")
    sns.despine(subfigs[0])
    ax0 = subfigs[0].subplots(1, 1)
    im = xarr.plot(ax=ax0)  # type:ignore
    im.colorbar.set_label(display_label)  # type:ignore
    ax.set_xlabel("step")
    if title is not None:
        fig.suptitle(title)
    if outdir is not None:
        # File names use the undecorated label.  `label=None` is spelled
        # "None" (as it always was inside the decorated name) rather than
        # raising.
        file_stem = f"{label}"
        outfile = Path(outdir).joinpath(f"{file_stem}.svg")
        # NOTE(review): figures are only written when "<label>.svg"
        # already exists in outdir; a first call saves nothing.  Kept
        # as-is because the intent is unclear — confirm.
        if outfile.is_file():
            tstamp = get_timestamp("%Y-%m-%d-%H%M%S")
            pngdir = Path(outdir).joinpath("pngs")
            pngdir.mkdir(exist_ok=True, parents=True)
            pngfile = pngdir.joinpath(f"{file_stem}-{tstamp}.png")
            svgfile = Path(outdir).joinpath(f"{file_stem}-{tstamp}.svg")
            plt.savefig(pngfile, dpi=400, bbox_inches="tight")
            plt.savefig(svgfile, dpi=400, bbox_inches="tight")
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def tplot_all(
    self,
    outdir: Optional[os.PathLike] = None,
    warmup: Optional[float] = 0.0,
    append: bool = True,
    xkey: Optional[str] = None,
    dataset: Optional[xr.Dataset] = None,
    data: Optional[dict] = None,
    logfreq: Optional[int] = None,
    plot_type: Optional[str] = None,
    verbose: bool = False,
    group_prefix: str = "",
):
    """Create terminal plots for all metrics in a dataset.

    Metric groups come from ``self._group_metric_variables`` and are
    plotted in sorted order with ``self._tplot_metric_group``.  Counter
    metrics (``iter``, ``epoch``, ``step``, ``batch``, ``idx``, ``bidx``,
    or any name ending in ``_<token>``), the metric named by ``xkey`` and
    a metric named ``draw`` are skipped.

    Args:
        outdir: Output directory; defaults to the current working
            directory.
        warmup: Forwarded to ``get_dataset`` and to each group plot.
        append: Accepted for signature compatibility; not used.
        xkey: Name of the x-axis metric, which is not plotted against
            itself.
        dataset: Dataset to plot.  When ``None`` it is built with
            ``self.get_dataset`` from ``data`` (or ``self.history``).
        data: Source mapping used only when ``dataset`` is ``None``.
        logfreq: Forwarded to each group plot.
        plot_type: Forwarded to each group plot.
        verbose: Forwarded to each group plot.
        group_prefix: When non-empty, display names become
            ``"<group_prefix>/<metric>"``.
    """
    if dataset is None:
        dataset = self.get_dataset(
            data=(data if data is not None else self.history),
            warmup=warmup,
        )
    outdir_path = Path(os.getcwd()) if outdir is None else Path(outdir)
    counter_tokens = ("iter", "epoch", "step", "batch", "idx", "bidx")
    counter_suffixes = tuple(f"_{token}" for token in counter_tokens)
    groups = self._group_metric_variables(dataset)
    for metric_name, metric_vars in sorted(groups.items()):
        last = metric_name.replace("\\", "/").rsplit("/", 1)[-1]
        if last in counter_tokens or last.endswith(counter_suffixes):
            continue
        # Skip the x-axis variable itself and the index-like names.  The
        # test is on the *metric* name: testing `xkey` against
        # ("iter", "draw") would be loop-invariant and suppress every plot
        # whenever xkey is one of them.
        if metric_name == xkey or metric_name in ("iter", "draw"):
            continue
        display_name = (
            f"{group_prefix}/{metric_name}"
            if group_prefix
            else metric_name
        )
        self._tplot_metric_group(
            display_name,
            metric_vars,
            warmup=warmup,
            outdir=outdir_path,
            plot_type=plot_type,
            verbose=verbose,
            logfreq=logfreq,
        )
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot_all(
    self,
    num_chains: int = 128,
    warmup: Optional[float | int] = 0.0,
    title: Optional[str] = None,
    verbose: bool = False,
    outdir: Optional[os.PathLike] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
    dataset: Optional[xr.Dataset] = None,
    data: Optional[dict] = None,
    group_prefix: str = "",
):
    """Create matplotlib plots for every metric group in the dataset.

    Ridge plots are produced first via ``ezplot.make_ridgeplots``; then
    each group from ``self._group_metric_variables`` is plotted with
    ``self._plot_metric_group``, handed to the tracker and, when reports
    are enabled and the asset exists, recorded in the report.

    Args:
        num_chains: Forwarded to ``ezplot.make_ridgeplots``.
        warmup: Forwarded to ``get_dataset`` and to each group plot.
        title: Forwarded to each group plot.
        verbose: Forwarded to each group plot.
        outdir: Output directory; ridge plots are saved only when set.
        subplots_kwargs: Extra subplot arguments (copied, not mutated).
        plot_kwargs: Extra plot arguments (copied, not mutated); ``color``
            is overwritten per group.
        dataset: Dataset to plot.  When ``None`` it is built with
            ``self.get_dataset`` from ``data`` (or ``self.history``).
        data: Source mapping used only when ``dataset`` is ``None``.
        group_prefix: When non-empty, display names become
            ``"<group_prefix>/<metric>"``.

    Returns:
        The dataset that was plotted.
    """
    plot_kwargs = dict(plot_kwargs) if plot_kwargs is not None else {}
    subplots_kwargs = (
        dict(subplots_kwargs) if subplots_kwargs is not None else {}
    )
    if dataset is None:
        source = self.history if data is None else data
        dataset = self.get_dataset(data=source, warmup=warmup)
    ezplot.make_ridgeplots(
        dataset,
        outdir=outdir,
        drop_nans=True,
        drop_zeros=False,
        num_chains=num_chains,
        cmap="viridis",
        save_plot=(outdir is not None),
    )
    ordered_groups = sorted(self._group_metric_variables(dataset).items())
    for position, (metric_name, metric_vars) in enumerate(ordered_groups):
        display_name = metric_name
        if group_prefix:
            display_name = f"{group_prefix}/{metric_name}"
        # Cycle through the colour-cycle aliases C0..C8.
        plot_kwargs["color"] = f"C{position % 9}"
        asset = self._plot_metric_group(
            display_name,
            metric_vars,
            warmup=warmup,
            title=title,
            outdir=Path(outdir) if outdir is not None else None,
            subplots_kwargs=subplots_kwargs,
            plot_kwargs=plot_kwargs,
            verbose=verbose,
        )
        self._wandb_log_matplotlib_asset(
            metric_name, asset, kind="matplotlib"
        )
        if asset is None or not self.report_enabled or not asset.exists():
            continue
        components = sorted(metric_vars.keys())
        # The first component (alphabetically) supplies the point count.
        sample_series = self._series_from_dataarray(
            metric_vars[components[0]]
        )
        self._write_plot_report(
            metric_name,
            asset,
            kind="matplotlib",
            metadata={
                "components": ", ".join(components),
                "points": len(sample_series),
            },
        )
    return dataset
def history_to_dict(self) -> dict:
    """Return ``self.history`` with every value list converted to a numpy array.

    Each list is first wrapped in a ``torch.Tensor`` and then exported
    with ``numpy(force=True)``.
    """
    arrays = {}
    for metric_name, samples in self.history.items():
        arrays[metric_name] = torch.Tensor(samples).detach().numpy(force=True)
    return arrays
def to_DataArray(
    self,
    x: Union[list, np.ndarray, torch.Tensor],
    warmup: Optional[int | float] = 0.0,
) -> xr.DataArray:
    """Convert a list, array, or tensor to an xarray DataArray.

    Args:
        x: Input data; must be 1D, 2D, or 3D once converted to an array.
            Tuples are treated as lists.
        warmup: Leading samples to drop along the first axis. A float is
            interpreted as a fraction of the length; an int is an
            absolute number of samples. ``None`` or a non-positive value
            drops nothing.

    Returns:
        A DataArray with dims ``("draw",)``, ``("chain", "draw")`` or
        ``("chain", "leapfrog", "draw")``. 2D and 3D inputs are
        transposed before being wrapped.

    Raises:
        ValueError: If the converted array is not 1D, 2D, or 3D.
    """
    if isinstance(x, tuple):
        x = list(x)
    if (
        isinstance(x, list)
        and len(x) > 0
        and isinstance(x[0], torch.Tensor)
    ):
        # NOTE(review): torch.Tensor(list_of_tensors) presumably only
        # works for single-element tensors -- confirm whether
        # multi-element entries can reach this path.
        x = torch.Tensor(x).detach().numpy(force=True)
    try:
        arr = grab_tensor(x)
    except ValueError:
        # Fall back to a plain numpy conversion and log what we got.
        arr = np.array(x).real
        logger.info(f"len(x): {len(x)}")
        x0_shape = getattr(x[0], "shape", None) if len(x) > 0 else None
        logger.info(f"x[0].shape: {x0_shape}")
        logger.info(f"arr.shape: {arr.shape}")
    assert isinstance(arr, np.ndarray)
    if warmup is not None and warmup > 0 and len(arr) > 0:
        if isinstance(warmup, int):
            # Use the count directly. Round-tripping through a fraction
            # (k / n * n) can land just below k and truncate to k - 1,
            # e.g. int(1 / 49 * 49) == 0.
            drop = warmup
        else:
            drop = int(warmup * len(arr))
        arr = arr[drop:]
    if arr.ndim == 1:  # [ndraws]
        dims = ["draw"]
        coords = [np.arange(len(arr))]
        return xr.DataArray(arr, dims=dims, coords=coords)  # type:ignore
    if arr.ndim == 2:  # transposed to [nchains, ndraws]
        arr = arr.T
        nchains, ndraws = arr.shape
        dims = ("chain", "draw")
        coords = [np.arange(nchains), np.arange(ndraws)]
        return xr.DataArray(arr, dims=dims, coords=coords)  # type:ignore
    if arr.ndim == 3:  # transposed to [nchains, nlf, ndraws]
        arr = arr.T
        nchains, nlf, ndraws = arr.shape
        dims = ("chain", "leapfrog", "draw")
        coords = [np.arange(nchains), np.arange(nlf), np.arange(ndraws)]
        return xr.DataArray(arr, dims=dims, coords=coords)  # type:ignore
    raise ValueError(f"Invalid shape encountered: {arr.shape}")
def get_grouped_datasets(
    self,
    warmup: Optional[float] = 0.0,
) -> dict[str, xr.Dataset]:
    """Return one xarray Dataset for every metric group in ``self._groups``.

    Because each group is converted on its own, metrics recorded under
    different prefixes (e.g. ``train/`` and ``eval/``) keep their own
    ``draw`` lengths rather than being padded to a common length.

    Args:
        warmup: Forwarded to :meth:`to_DataArray` for every metric.

    Returns:
        Dict mapping group prefix (``""`` for unprefixed) to Dataset.
        Metrics that fail to convert are logged and skipped; groups
        left with no variables are omitted.
    """
    per_group: dict[str, xr.Dataset] = {}
    for group_name, metrics in self._groups.items():
        arrays: dict[str, xr.DataArray] = {}
        for metric_key, samples in metrics.items():
            var_name = metric_key.replace("/", "_")
            try:
                values = torch.Tensor(samples).detach().numpy(force=True)
                arrays[var_name] = self.to_DataArray(values, warmup)
            except (ValueError, RuntimeError):
                logger.error(
                    "Unable to create DataArray for %s/%s! Skipping!",
                    group_name,
                    metric_key,
                )
        if not arrays:
            continue
        per_group[group_name] = xr.Dataset(arrays)
    return per_group
def get_dataset(
    self,
    data: Optional[
        dict[str, Union[list, np.ndarray, torch.Tensor]]
    ] = None,
    warmup: Optional[float] = 0.0,
) -> xr.Dataset:
    """Build a single xarray Dataset from the history data.

    For grouped datasets with independent dimensions, use
    :meth:`get_grouped_datasets` instead.

    Args:
        data: Dict of metric arrays; defaults to the result of
            :meth:`history_to_dict`.
        warmup: Forwarded to :meth:`to_DataArray` for every metric.

    Returns:
        Dataset with one variable per metric that converted cleanly;
        ``/`` in metric names is replaced by ``_``. Metrics for which
        :meth:`to_DataArray` raises ``ValueError`` are logged and
        skipped.
    """
    data = self.history_to_dict() if data is None else data
    data_vars = {}
    for key, val in data.items():
        name = key.replace("/", "_")
        try:
            data_vars[name] = self.to_DataArray(val, warmup)
        except ValueError:
            logger.error(
                f"Unable to create DataArray for {key}! Skipping!"
            )
            # The shape is diagnostics only. np.stack can itself fail
            # on the same malformed input (ragged, empty, scalar), and
            # that must not turn a skipped metric into a crash.
            try:
                shape = np.stack(val).shape  # type:ignore
            except (TypeError, ValueError, RuntimeError):
                shape = "<unavailable>"
            logger.error(f"{key}.shape= {shape}")
    return xr.Dataset(data_vars)
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def save_dataset(
    self,
    outdir: PathLike,
    fname: str = "dataset",
    use_hdf5: bool = True,
    data: Optional[
        dict[str, Union[list, np.ndarray, torch.Tensor]]
    ] = None,
    dataset: Optional[xr.Dataset] = None,
    warmup: Optional[int | float] = None,
    **kwargs,
) -> Path:
    """Log the dataset as a tracker table, then write it to disk.

    Args:
        outdir: Directory to write the dataset file.
        fname: Base filename (default ``"dataset"``).
        use_hdf5: If True, save as HDF5; otherwise NetCDF.
        data: Raw metrics used to build the dataset when ``dataset`` is
            None; falls back to ``self.history``.
        dataset: Pre-built dataset to save.
        warmup: Forwarded to :meth:`get_dataset` when the dataset has to
            be built here.
        **kwargs: Forwarded to the module-level ``save_dataset`` helper.

    Returns:
        Path to the saved file.
    """
    if dataset is None:
        source = self.history if data is None else data
        dataset = self.get_dataset(data=source, warmup=warmup)

    if dataset is not None:
        table_name = fname if fname == "dataset" else f"{fname}_dataset"
        # Table logging is best-effort; a tracker failure must not
        # prevent the dataset from being written below.
        try:
            frame = dataset.to_dataframe()
            header = list(frame.columns)
            rows = frame.values.tolist()
            self._tracker.log_table(table_name, columns=header, data=rows)
        except Exception as e:
            logger.warning(
                "Unable to log dataset table via tracker: %s", e
            )

    return save_dataset(
        dataset,
        outdir=outdir,
        fname=fname,
        use_hdf5=use_hdf5,
        **kwargs,
    )
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def finalize(
    self,
    outdir: Optional[PathLike] = None,
    run_name: Optional[str] = None,
    dataset_fname: Optional[str] = None,
    num_chains: int = 128,
    warmup: Optional[int | float] = 0.0,
    verbose: bool = False,
    save: bool = True,
    plot: bool = True,
    append_tplot: bool = True,
    title: Optional[str] = None,
    data: Optional[
        dict[str, Union[list, np.ndarray, torch.Tensor]]
    ] = None,
    dataset: Optional[xr.Dataset] = None,
    xkey: Optional[str] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    tplot_type: Optional[str] = None,
    env_info: Optional[dict[str, Any]] = None,
    timings: Optional[dict[str, float]] = None,
) -> dict[str, xr.Dataset] | xr.Dataset:
    """End-of-training cleanup: save dataset, generate plots, log artifacts.

    Args:
        outdir: Output directory. Defaults to
            ``<cwd>/outputs/<run_name>/<timestamp>``.
        run_name: Only used to build the default output directory.
        dataset_fname: Label used for the report file name, the report
            and plot sub-directories, and the saved dataset file.
        num_chains: Forwarded to :meth:`plot_all`.
        warmup: Forwarded to :meth:`get_dataset` and
            :meth:`get_grouped_datasets`.
        verbose: Forwarded to the plotting methods.
        save: Whether to write the dataset(s) to disk.
        plot: Whether to generate matplotlib and terminal plots.
        append_tplot: Forwarded to :meth:`tplot_all` as ``append``.
        title: Plot title; the group prefix is appended per group.
        data: Raw metrics used when ``dataset`` is None; falls back to
            ``self.history``.
        dataset: Pre-built flat dataset.
        xkey: Forwarded to :meth:`tplot_all`.
        plot_kwargs: Forwarded to :meth:`plot_all`.
        subplots_kwargs: Forwarded to :meth:`plot_all`.
        tplot_type: Forwarded to :meth:`tplot_all` as ``plot_type``.
        env_info: Environment details for the report. The dict is
            copied; the caller's object is not modified.
        timings: Timings in seconds, written to the report under
            ``"Timings"`` formatted as ``"<value>s"``.

    Returns:
        Dict mapping group prefix to xarray Dataset when more than one
        metric group exists. Otherwise (zero or one group) the single
        flat Dataset, for backward compat.
    """
    dataset = (
        dataset
        if dataset is not None
        else (
            self.get_dataset(
                data=(data if data is not None else self.history),
                warmup=warmup,
            )
        )
    )
    run_name = (
        f"History-{get_timestamp()}" if run_name is None else run_name
    )
    if outdir is None:
        base_dir = (
            Path(os.getcwd())
            .joinpath("outputs", run_name, get_timestamp())
            .expanduser()
            .resolve()
        )
    else:
        base_dir = Path(outdir).expanduser().resolve()
    base_dir.mkdir(parents=True, exist_ok=True)
    # Redirect CSV backend to base_dir so all output is co-located
    _csv_be = self._tracker.get_backend("csv")
    if _csv_be is not None and hasattr(_csv_be, "_csv_path"):
        old_csv = _csv_be._csv_path
        _csv_be._outdir = base_dir  # type: ignore[attr-defined]
        _csv_be._csv_path = base_dir / "metrics.csv"  # type: ignore[attr-defined]
        if old_csv.exists() and old_csv != _csv_be._csv_path:
            # Best-effort removal of the file at the previous location.
            try:
                old_csv.unlink()
            except OSError:
                pass
    # Redirect JSONL to base_dir so all output is co-located.
    #
    # Hold the JSONL lock for the entire move so a concurrent
    # _write_jsonl_record() either appends to the old path before
    # the move (and we relocate that record) or to the new path
    # after the swap.  Without this lock the writer could open a
    # handle to old_jsonl, the move could complete between
    # open/write, and the record would land in the deleted inode.
    if self._jsonl_enabled and self._jsonl_path is not None:
        with self._jsonl_lock:
            old_jsonl = self._jsonl_path
            new_jsonl = base_dir / old_jsonl.name
            if old_jsonl != new_jsonl:
                if old_jsonl.exists():
                    try:
                        new_jsonl.parent.mkdir(
                            parents=True, exist_ok=True,
                        )
                        # shutil.move falls back to copy+remove on
                        # cross-filesystem moves (e.g. /tmp ->
                        # Lustre), which is non-atomic.  That's
                        # OK because the lock prevents writers
                        # from racing it.  We swap the path
                        # *only after* the move succeeds so a
                        # mid-move failure leaves the writer
                        # pointing at the still-valid old file.
                        shutil.move(str(old_jsonl), str(new_jsonl))
                        self._jsonl_path = new_jsonl
                    except OSError as exc:
                        logger.warning(
                            "Failed to relocate JSONL %s -> %s: %s",
                            old_jsonl, new_jsonl, exc,
                        )
                else:
                    # Nothing to move (no metrics ever logged).
                    # Still update the path so future writes go
                    # to the co-located location.
                    self._jsonl_path = new_jsonl
    dataset_label = (
        dataset_fname if dataset_fname is not None else "dataset"
    )
    report_dir = (
        base_dir.joinpath(dataset_label)
        if dataset_fname is not None
        else base_dir
    )
    if dataset_fname is not None:
        self._report_filename = f"report-{dataset_label}.md"
    self._configure_report_destination(report_dir)
    # Shallow-copy the caller's dict: "Timings" and "Paths" are written
    # into it below and must not leak back into the argument.
    env_details = (
        dict(env_info)
        if env_info is not None
        else self._default_environment_info()
    )
    if timings:
        env_details["Timings"] = {
            k: f"{v:.2f}s" for k, v in timings.items()
        }
    paths: dict[str, str] = {}
    existing_paths = env_details.get("Paths")
    if isinstance(existing_paths, dict):
        paths.update(existing_paths)
    paths.setdefault("Working Directory", str(Path.cwd()))
    paths["Output Directory"] = str(base_dir)
    # Label -> path of everything produced here; logged and handed to
    # the tracker as artifacts at the end.
    output_files: dict[str, str] = {
        "Output Directory": str(base_dir),
    }
    if self.report_enabled:
        paths["Report"] = str(report_dir / self._report_filename)
        output_files["Report"] = paths["Report"]
    plotdir = None
    if plot:
        plotdir = (
            base_dir.joinpath("plots", dataset_label)
            if dataset_fname is not None
            else base_dir.joinpath("plots")
        )
        paths["Plots (matplotlib)"] = str(plotdir / "mplot")
        paths["Plots (terminal)"] = str(plotdir / "tplot")
        output_files["Plots (matplotlib)"] = paths["Plots (matplotlib)"]
        output_files["Plots (terminal)"] = paths["Plots (terminal)"]
    json_log = get_json_log_file()
    if json_log is not None and json_log.exists():
        link_path = base_dir / json_log.name
        if not link_path.exists():
            try:
                link_path.symlink_to(json_log.resolve())
            except OSError:
                pass
        # Report the symlink inside the output dir (co-located)
        reported = link_path if link_path.exists() else json_log
        paths["JSON Log"] = str(reported)
        output_files["JSON Log"] = paths["JSON Log"]
    if self._jsonl_path is not None:
        paths["Metrics JSONL"] = str(self._jsonl_path)
        output_files["Metrics JSONL"] = paths["Metrics JSONL"]
    if _csv_be is not None and hasattr(_csv_be, "_csv_path"):
        paths["Metrics CSV"] = str(_csv_be._csv_path)
        output_files["Metrics CSV"] = paths["Metrics CSV"]
    env_details["Paths"] = paths
    self._write_environment_section(env_details)
    self._write_metric_summary(dataset)
    if plot and plotdir is not None:
        logger.info(
            "Saving plots to %s (matplotlib) and %s (tplot)",
            plotdir.joinpath("mplot"),
            plotdir.joinpath("tplot"),
        )
        tplotdir = plotdir.joinpath("tplot")
        mplotdir = plotdir.joinpath("mplot")
        tplotdir.mkdir(exist_ok=True, parents=True)
        mplotdir.mkdir(exist_ok=True, parents=True)
        # Plot each metric group independently so train/ and eval/
        # metrics get their own x-axis dimension.
        grouped = self.get_grouped_datasets(warmup=warmup)
        if not grouped:
            # Fallback: use the flat dataset if no groups exist
            grouped = {"": dataset}
        for group_prefix, group_ds in sorted(grouped.items()):
            group_tplotdir = (
                tplotdir / group_prefix if group_prefix else tplotdir
            )
            group_mplotdir = (
                mplotdir / group_prefix if group_prefix else mplotdir
            )
            group_tplotdir.mkdir(exist_ok=True, parents=True)
            group_mplotdir.mkdir(exist_ok=True, parents=True)
            group_title = (
                f"{title} [{group_prefix}]"
                if title and group_prefix
                else (group_prefix or title)
            )
            _ = self.plot_all(
                dataset=group_ds,
                outdir=group_mplotdir,
                verbose=verbose,
                num_chains=num_chains,
                warmup=0.0,  # already applied in get_grouped_datasets
                title=group_title or None,
                plot_kwargs=plot_kwargs,
                subplots_kwargs=subplots_kwargs,
                group_prefix=group_prefix,
            )
            _ = self.tplot_all(
                dataset=group_ds,
                outdir=group_tplotdir,
                warmup=0.0,  # already applied
                append=append_tplot,
                plot_type=tplot_type,
                xkey=xkey,
                verbose=verbose,
                group_prefix=group_prefix,
            )
    if save:
        # The import is only an availability probe for HDF5 support.
        try:
            import h5py  # noqa: F401

            use_hdf5 = True
        except ImportError:
            logger.warning(
                "h5py not found! Saving dataset as netCDF instead."
            )
            use_hdf5 = False
        ext = ".h5" if use_hdf5 else ".nc"
        grouped = self.get_grouped_datasets(warmup=warmup)
        if len(grouped) > 1:
            # Save one dataset per group (no NaN padding)
            for gprefix, gds in grouped.items():
                label = gprefix if gprefix else (dataset_fname or "dataset")
                _ = self.save_dataset(
                    dataset=gds,
                    outdir=base_dir,
                    fname=label,
                    use_hdf5=use_hdf5,
                )
                output_files[f"Dataset ({label})"] = str(
                    base_dir / f"{label}{ext}"
                )
        else:
            # Single group or no groups: save as a single flat dataset
            fname = "dataset" if dataset_fname is None else dataset_fname
            ds_to_save = (
                next(iter(grouped.values()))
                if grouped
                else dataset
            )
            _ = self.save_dataset(
                dataset=ds_to_save,
                outdir=base_dir,
                fname=fname,
                use_hdf5=use_hdf5,
            )
            output_files["Dataset"] = str(base_dir / f"{fname}{ext}")
    if self.report_enabled:
        logger.info(
            "Saving history report to %s",
            self._report_dir.joinpath(self._report_filename),
        )
    _wandb_run = self._tracker.wandb_run
    if _wandb_run is not None:
        logger.info(f"wandb.run=[{_wandb_run.name}]({_wandb_run.url})")
    _mlflow_be = self._tracker.get_backend("mlflow")
    if _mlflow_be is not None and getattr(_mlflow_be, "_active", False):
        _run_url = getattr(_mlflow_be, "run_url", None)
        _run_id = getattr(_mlflow_be, "_run_id", "?")
        if _run_url:
            logger.info("mlflow.run=[%s](%s)", _run_id, _run_url)
        else:
            logger.info(
                "mlflow.run=%s (tracking_uri=%s)",
                _run_id,
                getattr(_mlflow_be, "_tracking_uri", "?"),
            )
    if self.history:
        # Best-effort: log the whole history as one table, padding
        # shorter columns with None.
        try:
            columns = list(self.history.keys())
            max_len = max(len(v) for v in self.history.values())
            table_data = []
            for i in range(max_len):
                row = [
                    self.history[col][i]
                    if i < len(self.history[col])
                    else None
                    for col in columns
                ]
                table_data.append(row)
            self._tracker.log_table(
                "training_history", columns=columns, data=table_data
            )
        except Exception as exc:
            # Keep the cause visible instead of dropping it silently.
            logger.warning(
                "Failed to log training history table via tracker: %s",
                exc,
            )
    if output_files:
        # Upload output files as artifacts (MLflow, etc.) before finish
        self._tracker.log_artifacts(output_files)
        logger.info("Output files:")
        for label, fpath in output_files.items():
            logger.info("  %s: %s", label, fpath)
    self._tracker.finish()
    grouped = self.get_grouped_datasets(warmup=warmup)
    if len(grouped) > 1:
        for gname, gds in sorted(grouped.items()):
            label = gname if gname else "default"
            logger.info("[%s] %s", label, gds)
        return grouped
    logger.info("%s", dataset)
    return dataset
data
property
writable
⚓︎
Alias for :attr:history (backward compat).
groups
property
⚓︎
Access the prefix-grouped metrics directly.
history
property
writable
⚓︎
Flattened view of all metric groups.
Returns a dict keyed by full metric names (e.g. "train/loss",
"loss"), preserving backward compatibility with code that reads
history.history["key"].
tracker
property
⚓︎
The internal Tracker instance.
Use this to access backend-specific features::
wb = history.tracker.get_backend("wandb")
if wb is not None:
wb.watch(model, log="all")
__init__(keys=None, *, report_dir=None, report_enabled=True, jsonl_path=None, jsonl_overwrite=False, distributed_history=AUTO_USE_DISTRIBUTED_HISTORY, project_name=None, backends=None, config=None, outdir=None, tracker=None)
⚓︎
Initialize the History object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
keys
|
Optional[list[str]]
|
List of keys to initialize the history with. If None, initializes with an empty list. |
None
|
report_dir
|
Optional[PathLike]
|
Directory for markdown reports. Defaults to `OUTPUTS_DIR/history`. |
None
|
report_enabled
|
bool
|
Toggle automatic markdown generation. |
True
|
jsonl_path
|
Optional[PathLike]
|
Destination for JSONL metric logging. |
None
|
jsonl_overwrite
|
bool
|
Whether to truncate an existing JSONL log. |
False
|
distributed_history
|
bool
|
Enable distributed history tracking. |
AUTO_USE_DISTRIBUTED_HISTORY
|
project_name
|
Optional[str]
|
Project name for tracker backends (e.g. wandb). |
None
|
backends
|
Optional[str | Sequence[str]]
|
Comma-separated string or list of tracker backend names (e.g. `"wandb,csv"`). |
None
|
config
|
Optional[dict[str, Any]]
|
Run-level config (hyperparameters) to log via the tracker. |
None
|
outdir
|
Optional[PathLike]
|
Output directory for file-based tracker backends (e.g. CSV). |
None
|
tracker
|
Optional[Tracker]
|
Inject a pre-built Tracker instance directly. |
None
|
Source code in src/ezpz/history.py
def __init__(
    self,
    keys: Optional[list[str]] = None,
    *,
    report_dir: Optional[PathLike] = None,
    report_enabled: bool = True,
    jsonl_path: Optional[PathLike] = None,
    jsonl_overwrite: bool = False,
    distributed_history: bool = AUTO_USE_DISTRIBUTED_HISTORY,
    project_name: Optional[str] = None,
    backends: Optional[str | Sequence[str]] = None,
    config: Optional[dict[str, Any]] = None,
    outdir: Optional[PathLike] = None,
    tracker: Optional[Tracker] = None,
) -> None:
    """
    Initialize the History object.

    Args:
        keys (Optional[list[str]]): List of keys to initialize the history with.
            If None, initializes with an empty list.
        report_dir (Optional[PathLike]): Directory for markdown reports. Defaults
            to ``OUTPUTS_DIR/history``.
        report_enabled (bool): Toggle automatic markdown generation.
        jsonl_path (Optional[PathLike]): Destination for JSONL metric logging.
            Defaults to ``<run_id>.jsonl`` in the report directory (or in
            ``OUTPUTS_DIR`` when reports are disabled).
        jsonl_overwrite (bool): Whether to truncate an existing JSONL log.
        distributed_history (bool): Enable distributed history tracking.
            Forced to False when ``EZPZ_NO_DISTRIBUTED_HISTORY`` or
            ``EZPZ_LOCAL_HISTORY`` is set, or the world size is <= 1.
        project_name (Optional[str]): Project name for tracker backends (e.g. wandb).
        backends (Optional[str | Sequence[str]]): Comma-separated string or list
            of tracker backend names (e.g. ``"wandb,csv"``).
        config (Optional[dict[str, Any]]): Run-level config (hyperparameters) to
            log via the tracker.
        outdir (Optional[PathLike]): Output directory for file-based tracker
            backends (e.g. CSV).
        tracker (Optional[Tracker]): Inject a pre-built Tracker instance directly.
    """
    self.keys = [] if keys is None else keys
    self._groups: dict[str, dict[str, list[Any]]] = {}
    self._flat_cache: dict[str, list[Any]] | None = None
    if (
        os.environ.get("EZPZ_NO_DISTRIBUTED_HISTORY", None)
        or os.environ.get("EZPZ_LOCAL_HISTORY", False)
        or ezpz.distributed.get_world_size() <= 1
    ):
        logger.info(
            "Not using distributed metrics! Will only be tracked from a single rank..."
        )
        distributed_history = False
    self.distributed_history = distributed_history
    logger.info(
        f"Using {self.__class__.__name__} with distributed_history={self.distributed_history}"
    )
    self._rank = get_rank()
    # UTC timestamp (second resolution) used as the run identifier.
    now = datetime.now(timezone.utc)
    self._run_id = now.strftime("%Y%m%d-%H%M%S")
    self.report_enabled = report_enabled
    base_report_root = (
        Path(report_dir)
        if report_dir is not None
        else Path(OUTPUTS_DIR).joinpath("history")
    )
    self._report_root = Path(base_report_root).expanduser().resolve()
    self._report_dir = self._report_root.joinpath(self._run_id)
    self._report_path: Optional[Path] = None
    self._asset_dir: Optional[Path] = None
    self._report_filename = "report.md"
    self._report_initialized = False
    self._jsonl_path_explicit = jsonl_path is not None
    if jsonl_path is None:
        default_jsonl_dir = (
            self._report_dir if report_enabled else Path(OUTPUTS_DIR)
        )
        self._jsonl_path = (
            Path(default_jsonl_dir)
            .expanduser()
            .resolve()
            .joinpath(f"{self._run_id}.jsonl")
        )
    else:
        self._jsonl_path = Path(jsonl_path).expanduser().resolve()
    if jsonl_overwrite and self._jsonl_path.exists():
        try:
            self._jsonl_path.unlink()
        except OSError:
            logger.warning(
                "Unable to remove existing JSONL log at %s",
                self._jsonl_path,
            )
    self._jsonl_enabled = True
    # Serializes JSONL writes against the finalize() move so a
    # background update() can't race the cross-FS shutil.move
    # and either (a) write to a half-moved file or (b) lose the
    # in-flight record altogether.
    self._jsonl_lock = threading.Lock()
    self._dist = torch.distributed
    self._environment_written = False
    self._metric_summary_written = False
    # -- Tracker integration --
    # Set only in the branch that actually hands `config` to
    # setup_tracker(). Re-deriving this afterwards from the environment
    # can disagree with the branch taken and silently drop the config.
    _tracker_got_config = False
    if tracker is not None:
        self._tracker: Tracker = tracker
    elif any(
        arg is not None for arg in (project_name, backends)
    ) or os.environ.get("EZPZ_TRACKER_BACKENDS"):
        self._tracker = setup_tracker(
            project_name=project_name,
            backends=backends,
            config=config,
            outdir=str(outdir) if outdir is not None else None,
        )
        _tracker_got_config = True
    else:
        # Backward compat: auto-detect existing wandb.run
        if wandb is not None and getattr(wandb, "run", None) is not None:
            warnings.warn(
                "History detected an active wandb.run but no 'backends' "
                "argument was provided. Automatically using "
                "backends='wandb'. In a future version, pass "
                "backends='wandb' explicitly.",
                DeprecationWarning,
                stacklevel=2,
            )
            self._tracker = setup_tracker(backends="wandb")
        else:
            self._tracker = NullTracker()
    # Forward config to the tracker when the backends didn't receive it
    # in their constructors.  This covers:
    #   - Injected tracker= (backends never saw config)
    #   - Auto-detect wandb.run path (setup_tracker called without config)
    #   - NullTracker fallback
    # The setup_tracker(config=...) path already handles config internally,
    # so we skip it there to avoid duplicates.
    if config is not None and not _tracker_got_config:
        self._tracker.log_config(config)
finalize(outdir=None, run_name=None, dataset_fname=None, num_chains=128, warmup=0.0, verbose=False, save=True, plot=True, append_tplot=True, title=None, data=None, dataset=None, xkey=None, plot_kwargs=None, subplots_kwargs=None, tplot_type=None, env_info=None, timings=None)
⚓︎
End-of-training cleanup: save dataset, generate plots, log artifacts.
Returns:
| Type | Description |
|---|---|
dict[str, Dataset] | Dataset
|
Dict mapping group prefix to xarray Dataset (one per group). If no groups exist, returns a single flat Dataset for backward compat. |
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def finalize(
    self,
    outdir: Optional[PathLike] = None,
    run_name: Optional[str] = None,
    dataset_fname: Optional[str] = None,
    num_chains: int = 128,
    warmup: Optional[int | float] = 0.0,
    verbose: bool = False,
    save: bool = True,
    plot: bool = True,
    append_tplot: bool = True,
    title: Optional[str] = None,
    data: Optional[
        dict[str, Union[list, np.ndarray, torch.Tensor]]
    ] = None,
    dataset: Optional[xr.Dataset] = None,
    xkey: Optional[str] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    tplot_type: Optional[str] = None,
    env_info: Optional[dict[str, Any]] = None,
    timings: Optional[dict[str, float]] = None,
) -> dict[str, xr.Dataset] | xr.Dataset:
    """End-of-training cleanup: save dataset, generate plots, log artifacts.

    Args:
        outdir: Output directory. Defaults to
            ``<cwd>/outputs/<run_name>/<timestamp>``.
        run_name: Only used to build the default output directory.
        dataset_fname: Label used for the report file name, the report
            and plot sub-directories, and the saved dataset file.
        num_chains: Forwarded to :meth:`plot_all`.
        warmup: Forwarded to :meth:`get_dataset` and
            :meth:`get_grouped_datasets`.
        verbose: Forwarded to the plotting methods.
        save: Whether to write the dataset(s) to disk.
        plot: Whether to generate matplotlib and terminal plots.
        append_tplot: Forwarded to :meth:`tplot_all` as ``append``.
        title: Plot title; the group prefix is appended per group.
        data: Raw metrics used when ``dataset`` is None; falls back to
            ``self.history``.
        dataset: Pre-built flat dataset.
        xkey: Forwarded to :meth:`tplot_all`.
        plot_kwargs: Forwarded to :meth:`plot_all`.
        subplots_kwargs: Forwarded to :meth:`plot_all`.
        tplot_type: Forwarded to :meth:`tplot_all` as ``plot_type``.
        env_info: Environment details for the report. The dict is
            copied; the caller's object is not modified.
        timings: Timings in seconds, written to the report under
            ``"Timings"`` formatted as ``"<value>s"``.

    Returns:
        Dict mapping group prefix to xarray Dataset when more than one
        metric group exists. Otherwise (zero or one group) the single
        flat Dataset, for backward compat.
    """
    dataset = (
        dataset
        if dataset is not None
        else (
            self.get_dataset(
                data=(data if data is not None else self.history),
                warmup=warmup,
            )
        )
    )
    run_name = (
        f"History-{get_timestamp()}" if run_name is None else run_name
    )
    if outdir is None:
        base_dir = (
            Path(os.getcwd())
            .joinpath("outputs", run_name, get_timestamp())
            .expanduser()
            .resolve()
        )
    else:
        base_dir = Path(outdir).expanduser().resolve()
    base_dir.mkdir(parents=True, exist_ok=True)
    # Redirect CSV backend to base_dir so all output is co-located
    _csv_be = self._tracker.get_backend("csv")
    if _csv_be is not None and hasattr(_csv_be, "_csv_path"):
        old_csv = _csv_be._csv_path
        _csv_be._outdir = base_dir  # type: ignore[attr-defined]
        _csv_be._csv_path = base_dir / "metrics.csv"  # type: ignore[attr-defined]
        if old_csv.exists() and old_csv != _csv_be._csv_path:
            # Best-effort removal of the file at the previous location.
            try:
                old_csv.unlink()
            except OSError:
                pass
    # Redirect JSONL to base_dir so all output is co-located.
    #
    # Hold the JSONL lock for the entire move so a concurrent
    # _write_jsonl_record() either appends to the old path before
    # the move (and we relocate that record) or to the new path
    # after the swap.  Without this lock the writer could open a
    # handle to old_jsonl, the move could complete between
    # open/write, and the record would land in the deleted inode.
    if self._jsonl_enabled and self._jsonl_path is not None:
        with self._jsonl_lock:
            old_jsonl = self._jsonl_path
            new_jsonl = base_dir / old_jsonl.name
            if old_jsonl != new_jsonl:
                if old_jsonl.exists():
                    try:
                        new_jsonl.parent.mkdir(
                            parents=True, exist_ok=True,
                        )
                        # shutil.move falls back to copy+remove on
                        # cross-filesystem moves (e.g. /tmp ->
                        # Lustre), which is non-atomic.  That's
                        # OK because the lock prevents writers
                        # from racing it.  We swap the path
                        # *only after* the move succeeds so a
                        # mid-move failure leaves the writer
                        # pointing at the still-valid old file.
                        shutil.move(str(old_jsonl), str(new_jsonl))
                        self._jsonl_path = new_jsonl
                    except OSError as exc:
                        logger.warning(
                            "Failed to relocate JSONL %s -> %s: %s",
                            old_jsonl, new_jsonl, exc,
                        )
                else:
                    # Nothing to move (no metrics ever logged).
                    # Still update the path so future writes go
                    # to the co-located location.
                    self._jsonl_path = new_jsonl
    dataset_label = (
        dataset_fname if dataset_fname is not None else "dataset"
    )
    report_dir = (
        base_dir.joinpath(dataset_label)
        if dataset_fname is not None
        else base_dir
    )
    if dataset_fname is not None:
        self._report_filename = f"report-{dataset_label}.md"
    self._configure_report_destination(report_dir)
    # Shallow-copy the caller's dict: "Timings" and "Paths" are written
    # into it below and must not leak back into the argument.
    env_details = (
        dict(env_info)
        if env_info is not None
        else self._default_environment_info()
    )
    if timings:
        env_details["Timings"] = {
            k: f"{v:.2f}s" for k, v in timings.items()
        }
    paths: dict[str, str] = {}
    existing_paths = env_details.get("Paths")
    if isinstance(existing_paths, dict):
        paths.update(existing_paths)
    paths.setdefault("Working Directory", str(Path.cwd()))
    paths["Output Directory"] = str(base_dir)
    # Label -> path of everything produced here; logged and handed to
    # the tracker as artifacts at the end.
    output_files: dict[str, str] = {
        "Output Directory": str(base_dir),
    }
    if self.report_enabled:
        paths["Report"] = str(report_dir / self._report_filename)
        output_files["Report"] = paths["Report"]
    plotdir = None
    if plot:
        plotdir = (
            base_dir.joinpath("plots", dataset_label)
            if dataset_fname is not None
            else base_dir.joinpath("plots")
        )
        paths["Plots (matplotlib)"] = str(plotdir / "mplot")
        paths["Plots (terminal)"] = str(plotdir / "tplot")
        output_files["Plots (matplotlib)"] = paths["Plots (matplotlib)"]
        output_files["Plots (terminal)"] = paths["Plots (terminal)"]
    json_log = get_json_log_file()
    if json_log is not None and json_log.exists():
        link_path = base_dir / json_log.name
        if not link_path.exists():
            try:
                link_path.symlink_to(json_log.resolve())
            except OSError:
                pass
        # Report the symlink inside the output dir (co-located)
        reported = link_path if link_path.exists() else json_log
        paths["JSON Log"] = str(reported)
        output_files["JSON Log"] = paths["JSON Log"]
    if self._jsonl_path is not None:
        paths["Metrics JSONL"] = str(self._jsonl_path)
        output_files["Metrics JSONL"] = paths["Metrics JSONL"]
    if _csv_be is not None and hasattr(_csv_be, "_csv_path"):
        paths["Metrics CSV"] = str(_csv_be._csv_path)
        output_files["Metrics CSV"] = paths["Metrics CSV"]
    env_details["Paths"] = paths
    self._write_environment_section(env_details)
    self._write_metric_summary(dataset)
    if plot and plotdir is not None:
        logger.info(
            "Saving plots to %s (matplotlib) and %s (tplot)",
            plotdir.joinpath("mplot"),
            plotdir.joinpath("tplot"),
        )
        tplotdir = plotdir.joinpath("tplot")
        mplotdir = plotdir.joinpath("mplot")
        tplotdir.mkdir(exist_ok=True, parents=True)
        mplotdir.mkdir(exist_ok=True, parents=True)
        # Plot each metric group independently so train/ and eval/
        # metrics get their own x-axis dimension.
        grouped = self.get_grouped_datasets(warmup=warmup)
        if not grouped:
            # Fallback: use the flat dataset if no groups exist
            grouped = {"": dataset}
        for group_prefix, group_ds in sorted(grouped.items()):
            group_tplotdir = (
                tplotdir / group_prefix if group_prefix else tplotdir
            )
            group_mplotdir = (
                mplotdir / group_prefix if group_prefix else mplotdir
            )
            group_tplotdir.mkdir(exist_ok=True, parents=True)
            group_mplotdir.mkdir(exist_ok=True, parents=True)
            group_title = (
                f"{title} [{group_prefix}]"
                if title and group_prefix
                else (group_prefix or title)
            )
            _ = self.plot_all(
                dataset=group_ds,
                outdir=group_mplotdir,
                verbose=verbose,
                num_chains=num_chains,
                warmup=0.0,  # already applied in get_grouped_datasets
                title=group_title or None,
                plot_kwargs=plot_kwargs,
                subplots_kwargs=subplots_kwargs,
                group_prefix=group_prefix,
            )
            _ = self.tplot_all(
                dataset=group_ds,
                outdir=group_tplotdir,
                warmup=0.0,  # already applied
                append=append_tplot,
                plot_type=tplot_type,
                xkey=xkey,
                verbose=verbose,
                group_prefix=group_prefix,
            )
    if save:
        # The import is only an availability probe for HDF5 support.
        try:
            import h5py  # noqa: F401

            use_hdf5 = True
        except ImportError:
            logger.warning(
                "h5py not found! Saving dataset as netCDF instead."
            )
            use_hdf5 = False
        ext = ".h5" if use_hdf5 else ".nc"
        grouped = self.get_grouped_datasets(warmup=warmup)
        if len(grouped) > 1:
            # Save one dataset per group (no NaN padding)
            for gprefix, gds in grouped.items():
                label = gprefix if gprefix else (dataset_fname or "dataset")
                _ = self.save_dataset(
                    dataset=gds,
                    outdir=base_dir,
                    fname=label,
                    use_hdf5=use_hdf5,
                )
                output_files[f"Dataset ({label})"] = str(
                    base_dir / f"{label}{ext}"
                )
        else:
            # Single group or no groups: save as a single flat dataset
            fname = "dataset" if dataset_fname is None else dataset_fname
            ds_to_save = (
                next(iter(grouped.values()))
                if grouped
                else dataset
            )
            _ = self.save_dataset(
                dataset=ds_to_save,
                outdir=base_dir,
                fname=fname,
                use_hdf5=use_hdf5,
            )
            output_files["Dataset"] = str(base_dir / f"{fname}{ext}")
    if self.report_enabled:
        logger.info(
            "Saving history report to %s",
            self._report_dir.joinpath(self._report_filename),
        )
    _wandb_run = self._tracker.wandb_run
    if _wandb_run is not None:
        logger.info(f"wandb.run=[{_wandb_run.name}]({_wandb_run.url})")
    _mlflow_be = self._tracker.get_backend("mlflow")
    if _mlflow_be is not None and getattr(_mlflow_be, "_active", False):
        _run_url = getattr(_mlflow_be, "run_url", None)
        _run_id = getattr(_mlflow_be, "_run_id", "?")
        if _run_url:
            logger.info("mlflow.run=[%s](%s)", _run_id, _run_url)
        else:
            logger.info(
                "mlflow.run=%s (tracking_uri=%s)",
                _run_id,
                getattr(_mlflow_be, "_tracking_uri", "?"),
            )
    if self.history:
        # Best-effort: log the whole history as one table, padding
        # shorter columns with None.
        try:
            columns = list(self.history.keys())
            max_len = max(len(v) for v in self.history.values())
            table_data = []
            for i in range(max_len):
                row = [
                    self.history[col][i]
                    if i < len(self.history[col])
                    else None
                    for col in columns
                ]
                table_data.append(row)
            self._tracker.log_table(
                "training_history", columns=columns, data=table_data
            )
        except Exception as exc:
            # Keep the cause visible instead of dropping it silently.
            logger.warning(
                "Failed to log training history table via tracker: %s",
                exc,
            )
    if output_files:
        # Upload output files as artifacts (MLflow, etc.) before finish
        self._tracker.log_artifacts(output_files)
        logger.info("Output files:")
        for label, fpath in output_files.items():
            logger.info("  %s: %s", label, fpath)
    self._tracker.finish()
    grouped = self.get_grouped_datasets(warmup=warmup)
    if len(grouped) > 1:
        for gname, gds in sorted(grouped.items()):
            label = gname if gname else "default"
            logger.info("[%s] %s", label, gds)
        return grouped
    logger.info("%s", dataset)
    return dataset
get_dataset(data=None, warmup=0.0)
⚓︎
Build a single xarray Dataset from the history data.
For grouped datasets with independent dimensions, use
:meth:get_grouped_datasets instead.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
Optional[dict[str, Union[list, ndarray, Tensor]]]
|
Dict of metric arrays; defaults to `self.history`. |
None
|
warmup
|
Optional[float]
|
Fraction of initial samples to drop. |
0.0
|
Source code in src/ezpz/history.py
def get_dataset(
    self,
    data: Optional[
        dict[str, Union[list, np.ndarray, torch.Tensor]]
    ] = None,
    warmup: Optional[float] = 0.0,
) -> xr.Dataset:
    """Build a single xarray Dataset from the history data.

    For grouped datasets with independent dimensions, use
    :meth:`get_grouped_datasets` instead.

    Args:
        data: Dict of metric arrays; defaults to the result of
            :meth:`history_to_dict`.
        warmup: Forwarded to :meth:`to_DataArray` for every metric.

    Returns:
        Dataset with one variable per metric that converted cleanly;
        ``/`` in metric names is replaced by ``_``. Metrics for which
        :meth:`to_DataArray` raises ``ValueError`` are logged and
        skipped.
    """
    data = self.history_to_dict() if data is None else data
    data_vars = {}
    for key, val in data.items():
        name = key.replace("/", "_")
        try:
            data_vars[name] = self.to_DataArray(val, warmup)
        except ValueError:
            logger.error(
                f"Unable to create DataArray for {key}! Skipping!"
            )
            # The shape is diagnostics only. np.stack can itself fail
            # on the same malformed input (ragged, empty, scalar), and
            # that must not turn a skipped metric into a crash.
            try:
                shape = np.stack(val).shape  # type:ignore
            except (TypeError, ValueError, RuntimeError):
                shape = "<unavailable>"
            logger.error(f"{key}.shape= {shape}")
    return xr.Dataset(data_vars)
get_grouped_datasets(warmup=0.0)
⚓︎
Build one xarray Dataset per metric group (prefix).
Each group's metrics share the same draw dimension, so
train/ and eval/ metrics get independent lengths instead
of being padded to the longest array.
Returns:
| Type | Description |
|---|---|
dict[str, Dataset]
|
Dict mapping group prefix (`""` for unprefixed) to Dataset. |
Source code in src/ezpz/history.py
def get_grouped_datasets(
    self,
    warmup: Optional[float] = 0.0,
) -> dict[str, xr.Dataset]:
    """Assemble a separate xarray Dataset for every metric group.

    Metrics are read from ``self._groups`` (group prefix -> metric name ->
    list of values). Because each group is converted on its own, e.g.
    ``train/`` and ``eval/`` metrics keep their own ``draw`` length rather
    than being padded to a common one.

    Args:
        warmup: Forwarded to :meth:`to_DataArray` for every metric.

    Returns:
        Dict mapping group prefix (``""`` for unprefixed) to Dataset.
        Groups in which no metric could be converted are omitted.
    """
    result: dict[str, xr.Dataset] = {}
    for group_name, metrics in self._groups.items():
        converted: dict[str, xr.DataArray] = {}
        for metric_key, values in metrics.items():
            var_name = metric_key.replace("/", "_")
            try:
                as_numpy = torch.Tensor(values).detach().numpy(force=True)
                converted[var_name] = self.to_DataArray(as_numpy, warmup)
            except (ValueError, RuntimeError):
                # A metric that cannot be converted is reported and left out.
                logger.error(
                    "Unable to create DataArray for %s/%s! Skipping!",
                    group_name,
                    metric_key,
                )
        if not converted:
            continue
        result[group_name] = xr.Dataset(converted)
    return result
history_to_dict()
⚓︎
Convert internal history to a dictionary of numpy arrays.
log_metrics(metrics, *, logger=None, debug_prefixes=('hist/',), include_summary=True, rank0_only_summary=True, precision=6, omit_counter_metrics=True, counter_tokens=('iter', 'epoch', 'step', 'batch', 'idx', 'bidx'))
⚓︎
Log metrics, routing debug-prefixed keys to debug level.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics
|
dict[str, Any]
|
Dict of metric name to scalar value. |
required |
include_summary
|
bool
|
If True, append distributed min/max/std summary. |
True
|
omit_counter_metrics
|
bool
|
If True, skip counter keys (iter, epoch, etc.). |
True
|
Source code in src/ezpz/history.py
def log_metrics(
    self,
    metrics: dict[str, Any],
    *,
    logger: Optional[Any] = None,
    debug_prefixes: tuple[str, ...] = ("hist/",),
    include_summary: bool = True,
    rank0_only_summary: bool = True,
    precision: int = 6,
    omit_counter_metrics: bool = True,
    counter_tokens: tuple[str, ...] = (
        "iter",
        "epoch",
        "step",
        "batch",
        "idx",
        "bidx",
    ),
) -> None:
    """Emit one info-level and one debug-level line for ``metrics``.

    Args:
        metrics: Dict of metric name to scalar value.
        logger: Logger to write to; ``get_logger(__name__)`` when omitted.
        debug_prefixes: Keys starting with one of these go to the debug line.
        include_summary: If True, merge the distributed min/max/std summary
            into the info line.
        rank0_only_summary: If True, only rank 0 merges that summary.
        precision: Number formatting precision passed to the formatters.
        omit_counter_metrics: If True, counter keys (iter, epoch, etc.) are
            excluded from the summary computation.
        counter_tokens: Names identifying counter keys; a key matches when
            its last ``/``-separated component equals a token or ends with
            ``_<token>``.
    """
    active_log = get_logger(__name__) if logger is None else logger
    info_metrics, debug_metrics = self.split_metrics_for_logging(
        metrics, debug_prefixes=debug_prefixes
    )

    def _is_counter_key(key: str) -> bool:
        # Only the final path component decides (backslashes count as "/").
        leaf = key.replace("\\", "/").split("/")[-1]
        return any(
            leaf == token or leaf.endswith(f"_{token}")
            for token in counter_tokens
        )

    # The distributed min/max/std stats are folded INTO the base info dict
    # so format_compact_summary can render `key=value(±std)` on a single
    # line rather than producing a second, verbose one.
    merged_for_summary: dict[str, Any] = dict(info_metrics)
    if include_summary:
        if omit_counter_metrics:
            summary_input = {
                name: value
                for name, value in info_metrics.items()
                if not _is_counter_key(name)
            }
        else:
            summary_input = info_metrics
        summary_stats = self.summarize_distributed_min_max_std(summary_input)
        if summary_stats and (not rank0_only_summary or self._rank == 0):
            merged_for_summary.update(summary_stats)

    from ezpz.utils import (
        format_compact_summary,
        format_memory_summary,
    )

    # format_compact_summary takes care of the noise reduction:
    #   - base + */std are collapsed into `key=value(±std)`
    #   - */mean /min /max /avg companions are dropped
    #   - memory keys are stripped (rendered separately below)
    #   - counter keys (iter/step/epoch/...) stay bare
    compact = format_compact_summary(merged_for_summary, precision=precision)
    base = compact.replace("train/", "")
    # With no explicit prefix, format_memory_summary auto-detects "train/" /
    # "eval/" / "" from the keys, so there is no need to probe twice.
    memory_str = format_memory_summary(info_metrics)

    pieces = [base] if base else []
    if memory_str:
        pieces.append(f"memory={memory_str}")
    info_msg = " ".join(pieces)
    if info_msg:
        active_log.info(info_msg)

    debug_summary = summarize_dict(debug_metrics, precision=precision)
    debug_msg = debug_summary.replace("train/", "")
    if debug_msg:
        active_log.debug(debug_msg)
plot(val, key=None, warmup=0.0, num_chains=128, title=None, outdir=None, subplots_kwargs=None, plot_kwargs=None, verbose=False)
⚓︎
Plot a single variable from the history.
NOTE: The warmup argument can be used to drop the first warmup
iterations (as a percent of the total number of iterations) from the
plot.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
val
|
ndarray
|
The data to plot. |
required |
key
|
Optional[str]
|
The key for the data. |
None
|
warmup
|
Optional[float]
|
The percentage of iterations to drop from the beginning of the plot. |
0.0
|
num_chains
|
Optional[int]
|
The number of chains to plot. |
128
|
title
|
Optional[str]
|
The title of the plot. |
None
|
outdir
|
Optional[PathLike]
|
The directory to save the plot to. |
None
|
subplots_kwargs
|
Optional[dict[str, Any]]
|
Additional arguments for subplots. |
None
|
plot_kwargs
|
Optional[dict[str, Any]]
|
Additional arguments for plotting. |
None
|
verbose
|
bool
|
Emit additional logging when saving plots. |
False
|
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot(
    self,
    val: np.ndarray,
    key: Optional[str] = None,
    warmup: Optional[float] = 0.0,
    num_chains: Optional[int] = 128,
    title: Optional[str] = None,
    outdir: Optional[os.PathLike] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
    verbose: bool = False,
):
    """
    Plot a single variable from the history.

    NOTE: The `warmup` argument can be used to drop the first `warmup`
    iterations from the plot. A float below 1 is read as a fraction of the
    total number of iterations; an int, or any value >= 1, is read as a
    number of iterations.

    Args:
        val (np.ndarray): The data to plot. 1D, 2D and 3D inputs are
            supported; axis 0 is the step axis.
        key (Optional[str]): The key for the data (y-label and file stem).
        warmup (Optional[float]): Amount of initial iterations to drop.
        num_chains (Optional[int]): Maximum number of individual chains to
            draw; ``None`` means 16.
        title (Optional[str]): The title of the plot.
        outdir (Optional[os.PathLike]): The directory to save the plot to.
            When omitted and reports are enabled, ``<report_dir>/mplot``.
        subplots_kwargs (Optional[dict[str, Any]]): Additional arguments for
            subplots. Not modified.
        plot_kwargs (Optional[dict[str, Any]]): Additional arguments for
            plotting. Not modified.
        verbose (bool): Emit additional logging when saving plots.

    Returns:
        Tuple ``(fig, subfigs, axes)``; ``subfigs`` is ``None`` unless the
        input is 2D.

    Raises:
        ValueError: If ``val`` is not 1D, 2D or 3D.
    """
    import matplotlib.pyplot as plt

    LW = plt.rcParams.get("axes.linewidth", 1.75)
    # Work on copies: this method adds/removes keys ("figsize", "label",
    # "color") and must not leak those changes into the caller's dicts.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    subplots_kwargs = {} if subplots_kwargs is None else dict(subplots_kwargs)
    figsize = subplots_kwargs.get("figsize", ezplot.set_size())
    subplots_kwargs.update({"figsize": figsize})
    num_chains = 16 if num_chains is None else num_chains
    arr = np.array(val)
    subfigs = None
    steps = np.arange(arr.shape[0])
    if warmup is not None and warmup > 0 and arr.size > 0:
        if isinstance(warmup, int) or warmup >= 1:
            # Absolute number of steps -> convert to a fraction.
            warmup_frac = float(warmup) / float(arr.shape[0])
        else:
            warmup_frac = float(warmup)
        warmup_frac = min(max(warmup_frac, 0.0), 1.0)
        drop = min(int(round(warmup_frac * arr.shape[0])), arr.shape[0])
        if drop > 0:
            arr = arr[drop:]
            steps = steps[drop:]
    if len(arr.shape) == 2:
        import seaborn as sns

        # Wide figure: subfigs[1] holds the trace (ax) and its KDE (ax1).
        figsize = (3 * figsize[0], 1.5 * figsize[1])
        fig = plt.figure(figsize=figsize, constrained_layout=True)
        subfigs = fig.subfigures(1, 2)
        gs_kw = {"width_ratios": [1.33, 0.33]}
        (ax, ax1) = subfigs[1].subplots(
            1, 2, sharey=True, gridspec_kw=gs_kw
        )
        ax.grid(alpha=0.2)
        ax1.grid(False)
        color = plot_kwargs.get("color", None)
        label = r"$\langle$" + f" {key} " + r"$\rangle$"
        # Defaults first so a caller-supplied lw/label overrides them
        # instead of raising "multiple values for keyword argument".
        mean_kwargs = {"lw": 1.5 * LW, "label": label, **plot_kwargs}
        ax.plot(steps, arr.mean(-1), **mean_kwargs)
        # `shade` is deprecated in seaborn; `fill` is its replacement.
        sns.kdeplot(y=arr.flatten(), ax=ax1, color=color, fill=True)
        ax1.set_xticks([])
        ax1.set_xticklabels([])
        sns.despine(ax=ax, top=True, right=True)
        sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
        ax1.set_xlabel("")
        axes = (ax, ax1)
    else:
        if len(arr.shape) == 1:
            fig, ax = plt.subplots(**subplots_kwargs)
            ax.plot(steps, arr, **plot_kwargs)
            axes = ax
        elif len(arr.shape) == 3:
            fig, ax = plt.subplots(**subplots_kwargs)
            cmap = plt.get_cmap("viridis")
            nlf = arr.shape[1]
            # Take the label out once so every line along axis 1 can be
            # labelled "<label>-<idx>" (popping inside the loop labelled
            # only the first one).
            base_label = plot_kwargs.pop("label", None)
            for lf_idx in range(nlf):
                label = (
                    None if base_label is None else f"{base_label}-{lf_idx}"
                )
                y = arr[:, lf_idx, :]
                # Spread colours over the axis-1 entries being iterated.
                plot_kwargs["color"] = cmap(lf_idx / nlf)
                if num_chains > 0:
                    chain_kwargs = {
                        **plot_kwargs,
                        "lw": LW / 2.0,
                        "alpha": 0.8,
                    }
                    for chain_idx in range(min(num_chains, y.shape[1])):
                        _ = ax.plot(steps, y[:, chain_idx], **chain_kwargs)
                _ = ax.plot(steps, y.mean(-1), label=label, **plot_kwargs)
            axes = ax
        else:
            raise ValueError("Unexpected shape encountered")
        ax.set_ylabel(key)
    if num_chains > 0 and len(arr.shape) > 1:
        # Overlay individual chains, arr[:, idx], as thin translucent lines.
        # They never carry the legend label.
        chain_kwargs = {
            k: v for k, v in plot_kwargs.items() if k != "label"
        }
        chain_kwargs.update({"alpha": 0.5, "lw": LW / 2.0})
        for idx in range(min(num_chains, arr.shape[1])):
            ax.plot(steps, arr[:, idx], **chain_kwargs)
    ax.set_xlabel("step")
    if title is not None:
        fig.suptitle(title)
    save_dir: Optional[Path]
    if outdir is not None:
        save_dir = Path(outdir).expanduser().resolve()
    elif self.report_enabled:
        save_dir = self._report_dir.joinpath("mplot")
    else:
        save_dir = None
    if save_dir is not None:
        save_dir.mkdir(parents=True, exist_ok=True)
        outfile = save_dir.joinpath(f"{key}.svg")
        if outfile.is_file():
            # A plot for this key already exists at the top level: also
            # keep timestamped copies.
            tstamp = ezpz.get_timestamp()
            pngdir = save_dir.joinpath("pngs")
            pngdir.mkdir(exist_ok=True, parents=True)
            pngfile = pngdir.joinpath(f"{key}-{tstamp}.png")
            svgfile = save_dir.joinpath(f"{key}-{tstamp}.svg")
            plt.savefig(pngfile, dpi=400, bbox_inches="tight")
            plt.savefig(svgfile, dpi=400, bbox_inches="tight")
    primary_asset: Optional[Path] = None
    if save_dir is not None:
        dirs = {
            "png": Path(save_dir).joinpath("pngs/"),
            "svg": Path(save_dir).joinpath("svgs/"),
        }
        for directory in dirs.values():
            directory.mkdir(exist_ok=True, parents=True)
        for ext, d in dirs.items():
            outfile = d.joinpath(f"{key}.{ext}")
            if outfile.is_file():
                outfile = d.joinpath(f"{key}-subfig.{ext}")
            if verbose:
                logger.info(f"Saving {key} plot to: {outfile.resolve()}")
            plt.savefig(outfile, dpi=400, bbox_inches="tight")
            if primary_asset is None and ext == "png":
                primary_asset = outfile
    if (
        self.report_enabled
        and primary_asset is not None
        and Path(primary_asset).exists()
    ):
        self._write_plot_report(
            key,
            primary_asset,
            kind="matplotlib",
            metadata={"shape": list(arr.shape)},
        )
    self._wandb_log_matplotlib_asset(key, primary_asset, kind="matplotlib")
    return fig, subfigs, axes
plot_2d_xarr(xarr, label=None, num_chains=None, title=None, outdir=None, subplots_kwargs=None, plot_kwargs=None)
⚓︎
Plot a 2D xarray DataArray (chain x draw) with matplotlib/seaborn.
Source code in src/ezpz/history.py
def plot_2d_xarr(
    self,
    xarr: xr.DataArray,
    label: Optional[str] = None,
    num_chains: Optional[int] = None,
    title: Optional[str] = None,
    outdir: Optional[os.PathLike] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
):
    """Plot a 2D xarray DataArray (chain x draw) with matplotlib/seaborn.

    The figure has two halves: an image of the full array (left) and, on
    the right, the chain-mean trace with individual chains overlaid plus a
    KDE of all values.

    Args:
        xarr: 2D DataArray with ``draw`` and ``chain`` coordinates.
        label: Name of the quantity; shown as ``<label>`` in the legend and
            on the colorbar, and used verbatim as the file stem.
        num_chains: Number of individual chains to overlay; defaults to,
            and is capped at, the number of chains in ``xarr``.
        title: Optional figure title.
        outdir: Directory checked for an existing ``<label>.svg``; only if
            one exists are timestamped PNG/SVG copies written.
        subplots_kwargs: Accepted for signature compatibility; unused.
        plot_kwargs: Extra keyword arguments for the line plots. Not
            modified. A ``color`` entry selects the colour of all elements.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    LW = plt.rcParams.get("axes.linewidth", 1.75)
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    assert len(xarr.shape) == 2
    assert "draw" in xarr.coords and "chain" in xarr.coords
    total_chains = len(xarr.chain)
    # Cap at what exists: indexing a missing chain raises IndexError.
    num_chains = (
        total_chains if num_chains is None else min(num_chains, total_chains)
    )
    figsize = plt.rcParams.get("figure.figsize", (8, 6))
    figsize = (3 * figsize[0], 1.5 * figsize[1])
    fig = plt.figure(figsize=figsize, constrained_layout=True)
    subfigs = fig.subfigures(1, 2)
    gs_kw = {"width_ratios": [1.33, 0.33]}
    (ax, ax1) = subfigs[1].subplots(1, 2, sharey=True, gridspec_kw=gs_kw)
    ax.grid(alpha=0.2)
    ax1.grid(False)
    # Pop (not get) so `color` is passed exactly once below; leaving it in
    # plot_kwargs raised "multiple values for keyword argument 'color'".
    color = plot_kwargs.pop("color", f"C{np.random.randint(6)}")
    # Keep the raw `label` intact: it is needed later as a file stem.
    display_label = r"$\langle$" + f" {label} " + r"$\rangle$"
    mean_kwargs = {"lw": 1.5 * LW, "label": display_label, **plot_kwargs}
    ax.plot(
        xarr.draw.values,
        xarr.mean("chain"),
        color=color,
        **mean_kwargs,
    )
    # Individual chains: thin, translucent, never labelled.
    chain_kwargs = {k: v for k, v in plot_kwargs.items() if k != "label"}
    chain_kwargs.update({"alpha": 0.5, "lw": LW / 2.0})
    for idx in range(num_chains):
        ax.plot(
            xarr.draw.values,
            xarr[xarr.chain == idx][0],
            color=color,
            **chain_kwargs,
        )
    # `shade` is deprecated in seaborn; `fill` is its replacement.
    sns.kdeplot(y=xarr.values.flatten(), ax=ax1, color=color, fill=True)
    ax1.set_xticks([])
    ax1.set_xticklabels([])
    sns.despine(ax=ax, top=True, right=True)
    sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
    ax1.set_xlabel("")
    sns.despine(subfigs[0])
    ax0 = subfigs[0].subplots(1, 1)
    im = xarr.plot(ax=ax0)  # type:ignore
    im.colorbar.set_label(display_label)  # type:ignore
    ax.set_xlabel("step")
    if title is not None:
        fig.suptitle(title)
    if outdir is not None:
        if label is None:
            # Without a label there is no sensible file stem.
            logger.warning(
                "plot_2d_xarr: `outdir` given without `label`; not saving."
            )
            return
        outfile = Path(outdir).joinpath(f"{label}.svg")
        if outfile.is_file():
            tstamp = get_timestamp("%Y-%m-%d-%H%M%S")
            pngdir = Path(outdir).joinpath("pngs")
            pngdir.mkdir(exist_ok=True, parents=True)
            pngfile = pngdir.joinpath(f"{label}-{tstamp}.png")
            svgfile = Path(outdir).joinpath(f"{label}-{tstamp}.svg")
            plt.savefig(pngfile, dpi=400, bbox_inches="tight")
            plt.savefig(svgfile, dpi=400, bbox_inches="tight")
plot_all(num_chains=128, warmup=0.0, title=None, verbose=False, outdir=None, subplots_kwargs=None, plot_kwargs=None, dataset=None, data=None, group_prefix='')
⚓︎
Create matplotlib ridge plots for all metrics in the dataset.
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot_all(
    self,
    num_chains: int = 128,
    warmup: Optional[float | int] = 0.0,
    title: Optional[str] = None,
    verbose: bool = False,
    outdir: Optional[os.PathLike] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
    dataset: Optional[xr.Dataset] = None,
    data: Optional[dict] = None,
    group_prefix: str = "",
):
    """Render matplotlib plots for every metric in the dataset.

    Ridge plots are produced for the dataset as a whole; afterwards each
    metric group (as returned by ``self._group_metric_variables``) is
    plotted, handed to the W&B asset logger, and, when reports are enabled
    and the asset file exists, added to the plot report.

    The dataset is built from ``data`` (or ``self.history`` when ``data``
    is None) unless one is passed in. ``group_prefix``, if non-empty, is
    prepended to the displayed metric names. The caller's ``plot_kwargs``
    and ``subplots_kwargs`` are copied, not modified.

    Returns:
        The dataset that was plotted.
    """
    plot_kwargs = dict(plot_kwargs) if plot_kwargs is not None else {}
    subplots_kwargs = (
        dict(subplots_kwargs) if subplots_kwargs is not None else {}
    )
    if dataset is None:
        source = self.history if data is None else data
        dataset = self.get_dataset(data=source, warmup=warmup)
    _ = ezplot.make_ridgeplots(
        dataset,
        outdir=outdir,
        drop_nans=True,
        drop_zeros=False,
        num_chains=num_chains,
        cmap="viridis",
        save_plot=(outdir is not None),
    )
    groups = self._group_metric_variables(dataset)
    for position, metric_name in enumerate(sorted(groups)):
        metric_vars = groups[metric_name]
        if group_prefix:
            display_name = f"{group_prefix}/{metric_name}"
        else:
            display_name = metric_name
        # Cycle through matplotlib's "C<n>" colours, one per metric.
        plot_kwargs["color"] = f"C{position % 9}"
        asset = self._plot_metric_group(
            display_name,
            metric_vars,
            warmup=warmup,
            title=title,
            outdir=Path(outdir) if outdir is not None else None,
            subplots_kwargs=subplots_kwargs,
            plot_kwargs=plot_kwargs,
            verbose=verbose,
        )
        self._wandb_log_matplotlib_asset(
            metric_name, asset, kind="matplotlib"
        )
        if asset is None or not self.report_enabled or not asset.exists():
            continue
        components = sorted(metric_vars.keys())
        sample_series = self._series_from_dataarray(
            metric_vars[components[0]]
        )
        self._write_plot_report(
            metric_name,
            asset,
            kind="matplotlib",
            metadata={
                "components": ", ".join(components),
                "points": len(sample_series),
            },
        )
    return dataset
plot_dataArray(val, key=None, warmup=0.0, num_chains=0, title=None, outdir=None, subplots_kwargs=None, plot_kwargs=None, verbose=False, line_labels=False, logfreq=None)
⚓︎
Plot a single variable from the history as an xarray DataArray.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
val
|
DataArray
|
The data to plot. |
required |
key
|
Optional[str]
|
The key for the data. |
None
|
warmup
|
Optional[float]
|
The percentage of iterations to drop from the beginning of the plot. |
0.0
|
num_chains
|
Optional[int]
|
The number of chains to plot. |
0
|
title
|
Optional[str]
|
The title of the plot. |
None
|
outdir
|
Optional[str]
|
The directory to save the plot to. |
None
|
subplots_kwargs
|
Optional[dict[str, Any]]
|
Additional arguments for subplots. |
None
|
plot_kwargs
|
Optional[dict[str, Any]]
|
Additional arguments for plotting. |
None
|
verbose
|
bool
|
Whether to print the plot. |
False
|
line_labels
|
bool
|
Whether to label lines in the plot. |
False
|
logfreq
|
Optional[int]
|
The log frequency of the plot. |
None
|
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot_dataArray(
    self,
    val: xr.DataArray,
    key: Optional[str] = None,
    warmup: Optional[float] = 0.0,
    num_chains: Optional[int] = 0,
    title: Optional[str] = None,
    outdir: Optional[str] = None,
    subplots_kwargs: Optional[dict[str, Any]] = None,
    plot_kwargs: Optional[dict[str, Any]] = None,
    verbose: bool = False,
    line_labels: bool = False,
    logfreq: Optional[int] = None,
):
    """
    Plot a single variable from the history as an xarray DataArray.

    Dispatches on the number of dimensions of ``val``: 1D data is drawn as
    a line, 2D data is delegated to ``ezplot.plot_combined``, and 3D data
    is drawn as one chain-averaged line per ``leapfrog`` entry. The figure
    is saved as PNG and SVG when a save directory is available.

    Args:
        val (xr.DataArray): The data to plot. Must carry a ``draw``
            coordinate; 3D input must also have ``chain`` and ``leapfrog``.
        key (Optional[str]): The key for the data. Required (asserted) for
            1D and 3D input, where it is used as the y-label.
        warmup (Optional[float]): Amount of initial iterations to drop; a
            float below 1 is read as a fraction, an int or a value >= 1 as
            a count.
        num_chains (Optional[int]): The number of chains to plot (only
            forwarded to ``ezplot.plot_combined`` for 2D input).
        title (Optional[str]): The title of the plot.
        outdir (Optional[str]): The directory to save the plot to. When
            omitted and reports are enabled, ``<report_dir>/dataarray``.
        subplots_kwargs (Optional[dict[str, Any]]): Additional arguments for
            subplots. NOTE: a ``figsize`` entry is written into this dict.
        plot_kwargs (Optional[dict[str, Any]]): Additional arguments for plotting.
        verbose (bool): Log the save directory.
        line_labels (bool): Not read by this method.
        logfreq (Optional[int]): If given, x tick labels are rewritten as
            ``logfreq * tick``.

    Returns:
        Tuple ``(fig, subfigs, axes)``; ``subfigs`` is always ``None`` here.

    Raises:
        ValueError: If ``val`` is not 1D, 2D or 3D.
    """
    import matplotlib.pyplot as plt

    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    subplots_kwargs = {} if subplots_kwargs is None else subplots_kwargs
    ezplot.set_plot_style()
    # NOTE: this changes a process-wide matplotlib setting.
    plt.rcParams["axes.labelcolor"] = "#bdbdbd"
    figsize = subplots_kwargs.get("figsize", ezplot.set_size())
    subplots_kwargs.update({"figsize": figsize})
    subfigs = None
    # if key == 'dt':
    #     warmup = 0.2
    arr = val.values  # shape: [nchains, ndraws]
    # steps = np.arange(len(val.coords['draw']))
    steps = val.coords["draw"]
    # NOTE(review): the trimming below slices `arr` and `steps` along axis 0.
    # The 2D branch passes the untrimmed `val` to plot_combined, and the 3D
    # branch plots `val.mean("chain")`, so warmup appears to take effect only
    # for 1D input -- confirm whether that is intended.
    if warmup is not None and warmup > 0.0 and arr.size > 0:
        if isinstance(warmup, int) or warmup >= 1:
            # Absolute count -> fraction of axis 0.
            warmup_frac = float(warmup) / float(arr.shape[0])
        else:
            warmup_frac = float(warmup)
        warmup_frac = min(max(warmup_frac, 0.0), 1.0)
        drop = min(int(round(warmup_frac * arr.shape[0])), arr.shape[0])
        if drop > 0:
            arr = arr[drop:]
            steps = steps[drop:]
    if len(arr.shape) == 2:
        fig, axes = ezplot.plot_combined(
            val,
            key=key,
            num_chains=num_chains,
            plot_kwargs=plot_kwargs,
            subplots_kwargs=subplots_kwargs,
        )
    else:
        if len(arr.shape) == 1:
            fig, ax = ezplot.subplots(**subplots_kwargs)
            try:
                ax.plot(steps, arr, **plot_kwargs)
            except ValueError:
                # Second attempt with NaNs removed; if that fails as well
                # the metric is reported and the (empty) axes are kept.
                try:
                    ax.plot(steps, arr[~np.isnan(arr)], **plot_kwargs)
                except Exception:
                    logger.error(f"Unable to plot {key}! Continuing")
            _ = ax.grid(True, alpha=0.2)
            axes = ax
        elif len(arr.shape) == 3:
            fig, ax = ezplot.subplots(**subplots_kwargs)
            cmap = plt.get_cmap("viridis")
            # Average over chains, then draw one coloured line per leapfrog
            # index.
            y = val.mean("chain")
            for idx in range(len(val.coords["leapfrog"])):
                pkwargs = {
                    "color": cmap(idx / len(val.coords["leapfrog"])),
                    "label": f"{idx}",
                }
                ax.plot(steps, y[idx], **pkwargs)
            axes = ax
        else:
            raise ValueError("Unexpected shape encountered")
        # Label whichever axes are current after the 1D/3D plotting above.
        ax = plt.gca()
        # assert isinstance(ax, plt.Axes)
        assert key is not None
        _ = ax.set_ylabel(key)
        _ = ax.set_xlabel("step")
    # if num_chains > 0 and len(arr.shape) > 1:
    #     lw = LW / 2.
    #     #for idx in range(min(num_chains, arr.shape[1])):
    #     nchains = len(val.coords['chains'])
    #     for idx in range(min(nchains, num_chains)):
    #         # ax = subfigs[0].subplots(1, 1)
    #         # plot values of individual chains, arr[:, idx]
    #         # where arr[:, idx].shape = [ndraws, 1]
    #         ax.plot(steps, val
    #                 alpha=0.5, lw=lw/2., **plot_kwargs)
    if title is not None:
        fig = plt.gcf()
        _ = fig.suptitle(title)
    if logfreq is not None:
        # Relabel x ticks in units of logged steps (tick value * logfreq).
        ax = plt.gca()
        xticks = ax.get_xticks()  # type: ignore
        _ = ax.set_xticklabels(  # type: ignore
            [f"{logfreq * int(i)}" for i in xticks]  # type: ignore
        )
    # Explicit outdir wins; otherwise fall back to the report directory
    # when reports are enabled; otherwise nothing is saved.
    save_dir: Optional[Path]
    if outdir is not None:
        save_dir = Path(outdir).expanduser().resolve()
    elif self.report_enabled:
        save_dir = self._report_dir.joinpath("dataarray")
    else:
        save_dir = None
    primary_asset: Optional[Path] = None
    if save_dir is not None:
        dirs = {
            "png": Path(save_dir).joinpath("pngs/"),
            "svg": Path(save_dir).joinpath("svgs/"),
        }
        _ = [i.mkdir(exist_ok=True, parents=True) for i in dirs.values()]
        if verbose:
            logger.info(
                f"Saving {key} plot to: {Path(save_dir).resolve()}"
            )
        for ext, d in dirs.items():
            outfile = d.joinpath(f"{key}.{ext}")
            # plt.savefig writes pyplot's *current* figure.
            plt.savefig(outfile, dpi=400, bbox_inches="tight")
            # The PNG is the asset referenced by the report / tracker.
            if primary_asset is None and ext == "png":
                primary_asset = outfile
    if (
        self.report_enabled
        and primary_asset is not None
        and Path(primary_asset).exists()
    ):
        metadata = {"dims": list(val.dims)}
        self._write_plot_report(
            key,
            primary_asset,
            kind="dataarray",
            metadata=metadata,
        )
    # Called even when nothing was saved (primary_asset is then None).
    self._wandb_log_matplotlib_asset(key, primary_asset, kind="dataarray")
    return (fig, subfigs, axes)
plot_dataset(title=None, nchains=None, outdir=None, dataset=None, data=None, warmup=None)
⚓︎
Plot the full xarray Dataset via ezplot.plot_dataset.
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def plot_dataset(
    self,
    title: Optional[str] = None,
    nchains: Optional[int] = None,
    outdir: Optional[os.PathLike] = None,
    dataset: Optional[xr.Dataset] = None,
    data: Optional[dict] = None,
    warmup: Optional[int | float] = None,
):
    """Plot the full xarray Dataset via ``ezplot.plot_dataset``.

    If no ``dataset`` is supplied, one is built with :meth:`get_dataset`
    from ``data`` (or from ``self.history`` when ``data`` is None), using
    ``warmup``. The remaining arguments are forwarded unchanged, and the
    result of ``ezplot.plot_dataset`` is returned.
    """
    if dataset is None:
        source = self.history if data is None else data
        dataset = self.get_dataset(data=source, warmup=warmup)
    return ezplot.plot_dataset(
        dataset=dataset,
        nchains=nchains,
        title=title,
        outdir=outdir,
    )
save_dataset(outdir, fname='dataset', use_hdf5=True, data=None, dataset=None, warmup=None, **kwargs)
⚓︎
Save the history dataset to disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
outdir
|
PathLike
|
Directory to write the dataset file. |
required |
fname
|
str
|
Base filename (default `"dataset"`). |
'dataset'
|
use_hdf5
|
bool
|
If True, save as HDF5; otherwise NetCDF. |
True
|
Returns:
| Type | Description |
|---|---|
Path
|
Path to the saved file. |
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def save_dataset(
    self,
    outdir: PathLike,
    fname: str = "dataset",
    use_hdf5: bool = True,
    data: Optional[
        dict[str, Union[list, np.ndarray, torch.Tensor]]
    ] = None,
    dataset: Optional[xr.Dataset] = None,
    warmup: Optional[int | float] = None,
    **kwargs,
) -> Path:
    """Write the history dataset to disk and log it as a tracker table.

    Args:
        outdir: Directory to write the dataset file.
        fname: Base filename (default ``"dataset"``).
        use_hdf5: If True, save as HDF5; otherwise NetCDF.
        data: Metric arrays to build the dataset from when ``dataset`` is
            not given; ``self.history`` is used if this is None too.
        dataset: Pre-built dataset to save.
        warmup: Forwarded to :meth:`get_dataset` when a dataset is built.
        **kwargs: Passed through to the module-level ``save_dataset``.

    Returns:
        Path to the saved file.
    """
    if dataset is None:
        source = self.history if data is None else data
        dataset = self.get_dataset(data=source, warmup=warmup)
    if dataset is not None:
        table_name = fname if fname == "dataset" else f"{fname}_dataset"
        # Logging the table is best-effort: a failure is reported as a
        # warning and does not stop the dataset from being written.
        try:
            frame = dataset.to_dataframe()
            self._tracker.log_table(
                table_name,
                columns=list(frame.columns),
                data=frame.values.tolist(),
            )
        except Exception as e:
            logger.warning(
                "Unable to log dataset table via tracker: %s", e
            )
    # The bare name resolves to the module-level helper, not this method.
    return save_dataset(
        dataset,
        outdir=outdir,
        fname=fname,
        use_hdf5=use_hdf5,
        **kwargs,
    )
split_metrics_for_logging(metrics, debug_prefixes=('hist/',))
staticmethod
⚓︎
Split metrics into info-level and debug-level groups.
Keys starting with any of debug_prefixes are placed in the
debug dict; everything else goes into the info dict.
Source code in src/ezpz/history.py
@staticmethod
def split_metrics_for_logging(
    metrics: dict[str, Any],
    debug_prefixes: tuple[str, ...] = ("hist/",),
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split metrics into info-level and debug-level groups.

    Keys starting with any of ``debug_prefixes`` are placed in the
    debug dict; everything else goes into the info dict.

    Args:
        metrics: Mapping of metric name to value.
        debug_prefixes: A single prefix string or any iterable of prefix
            strings (tuple, list, set, ...). An empty iterable sends every
            metric to the info dict.

    Returns:
        ``(info_metrics, debug_metrics)``; each preserves the relative
        order of ``metrics``.
    """
    # ``str.startswith`` accepts only a str or a tuple of str, so other
    # iterables are normalised here rather than failing with TypeError.
    if isinstance(debug_prefixes, str):
        prefixes: tuple[str, ...] = (debug_prefixes,)
    else:
        prefixes = tuple(debug_prefixes)
    info_metrics: dict[str, Any] = {}
    debug_metrics: dict[str, Any] = {}
    for key, value in metrics.items():
        target = debug_metrics if key.startswith(prefixes) else info_metrics
        target[key] = value
    return info_metrics, debug_metrics
summarize_distributed_min_max_std(metrics)
⚓︎
Compute distributed mean/min/max/std via all-reduce.
Falls back to local summarize_min_max_std when distributed
stats are unavailable. All-zero entries are pruned.
Source code in src/ezpz/history.py
def summarize_distributed_min_max_std(
    self, metrics: dict[str, Any]
) -> dict[str, float]:
    """Compute distributed mean/min/max/std for ``metrics``.

    Statistics come from ``self._compute_distributed_metrics``; when that
    returns nothing, ``self.summarize_min_max_std`` is used instead. Only
    keys ending in ``/mean``, ``/min``, ``/max`` or ``/std`` are kept.

    A metric is dropped entirely when its ``min`` and ``max`` are both
    present and equal (zero spread -- this also covers all-zero entries).

    Returns:
        Dict of ``<metric>/<stat>`` to value. Metrics appear in the order
        they are first seen in the statistics, each as mean, min, max, std
        (missing statistics are omitted).
    """
    summary_stats = self._compute_distributed_metrics(metrics)
    if not summary_stats:
        summary_stats = self.summarize_min_max_std(metrics)
    suffixes = ("mean", "min", "max", "std")
    filtered: dict[str, float] = {
        k: v
        for k, v in summary_stats.items()
        if k.endswith(("/mean", "/min", "/max", "/std"))
    }
    # dict.fromkeys de-duplicates while keeping first-seen order. Iterating
    # a set here made the output order depend on string hash randomisation.
    bases = dict.fromkeys(k.rsplit("/", 1)[0] for k in filtered)
    pruned: dict[str, float] = {}
    for base in bases:
        min_v = filtered.get(f"{base}/min")
        max_v = filtered.get(f"{base}/max")
        # Skip zero-variance metrics (min == max) — no useful info.
        if min_v is not None and max_v is not None and min_v == max_v:
            continue
        for suffix in suffixes:
            stat_key = f"{base}/{suffix}"
            stat_val = filtered.get(stat_key)
            if stat_val is not None:
                pruned[stat_key] = stat_val
    return pruned
summarize_min_max_std(metrics)
staticmethod
⚓︎
Compute mean/min/max/std for each numeric metric.
Source code in src/ezpz/history.py
@staticmethod
def summarize_min_max_std(
    metrics: dict[str, Any],
) -> dict[str, float]:
    """Return ``<key>/mean|min|max|std`` entries for every numeric metric.

    A metric counts as numeric when it is an ``int``/``float`` or a tensor
    holding exactly one element; anything else is ignored. Each statistic
    is computed over that single value via a one-element tensor, so the
    standard deviation (population form) is always ``0.0``.
    """
    summary: dict[str, float] = {}
    for name, raw in metrics.items():
        if isinstance(raw, (int, float)):
            scalar = float(raw)
        elif torch.is_tensor(raw) and raw.numel() == 1:
            scalar = float(raw.item())
        else:
            continue
        sample = torch.tensor([scalar])
        summary[f"{name}/mean"] = float(sample.mean().item())
        summary[f"{name}/min"] = float(sample.min().item())
        summary[f"{name}/max"] = float(sample.max().item())
        summary[f"{name}/std"] = float(sample.std(unbiased=False).item())
    return summary
to_DataArray(x, warmup=0.0)
⚓︎
Convert a list, array, or tensor to an xarray DataArray.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Union[list, ndarray, Tensor]
|
Input data (1D, 2D, or 3D). |
required |
warmup
|
Optional[float]
|
Fraction of initial samples to drop (0.0 to 1.0). |
0.0
|
Source code in src/ezpz/history.py
def to_DataArray(
    self,
    x: Union[list, np.ndarray, torch.Tensor],
    warmup: Optional[float] = 0.0,
) -> xr.DataArray:
    """Convert a list, array, or tensor to an xarray DataArray.

    Axis 0 of the input is the draw axis. 2D and 3D inputs are transposed
    (all axes reversed), so the resulting dims are ``("chain", "draw")``
    and ``("chain", "leapfrog", "draw")`` respectively; 1D input yields
    ``("draw",)``. Coordinates are ``0..n-1`` along every dim.

    Args:
        x: Input data (1D, 2D, or 3D). Tuples are treated as lists. A list
            of tensors is combined into one float32 array; tensors with
            more than one element are stacked along a new leading axis.
        warmup: Amount of leading draws to drop. A float is read as a
            fraction of the draws, an int as a number of draws.

    Raises:
        ValueError: If a list of tensors cannot be combined, or if the
            data is not 1D, 2D or 3D.
    """
    if isinstance(x, tuple):
        x = list(x)
    if (
        isinstance(x, list)
        and len(x) > 0
        and isinstance(x[0], torch.Tensor)
    ):
        try:
            x = torch.Tensor(x).detach().numpy(force=True)
        except (ValueError, TypeError, RuntimeError):
            # torch.Tensor(list) converts each element to a Python scalar
            # and so rejects tensors with more than one element; stack
            # those instead. Failures (e.g. mismatched shapes) surface as
            # ValueError, which callers already treat as "cannot convert".
            try:
                x = torch.stack(
                    [
                        torch.as_tensor(item).detach().to(torch.float32)
                        for item in x
                    ]
                ).numpy(force=True)
            except (RuntimeError, TypeError) as exc:
                raise ValueError(
                    "Unable to combine list of tensors into a single array"
                ) from exc
    try:
        arr = grab_tensor(x)
    except ValueError:
        # Fall back to a plain numpy conversion (real part only) and log
        # some shape information to help diagnose the input.
        arr = np.array(x).real
        logger.info("len(x): %s", len(x))
        x0_shape = getattr(x[0], "shape", None) if len(x) > 0 else None
        logger.info("x[0].shape: %s", x0_shape)
        logger.info("arr.shape: %s", arr.shape)
    assert isinstance(arr, np.ndarray)
    # 0-d arrays have no len(); they fall through to the ValueError below.
    if (
        warmup is not None
        and warmup > 0
        and arr.ndim > 0
        and len(arr) > 0
    ):
        if isinstance(warmup, int):
            warmup = warmup / len(arr)
        drop = int(warmup * len(arr))
        arr = arr[drop:]
    if arr.ndim == 1:  # [ndraws]
        dims = ["draw"]
        coords = [np.arange(len(arr))]
        return xr.DataArray(arr, dims=dims, coords=coords)  # type:ignore
    if arr.ndim == 2:  # -> [nchains, ndraws]
        arr = arr.T
        nchains, ndraws = arr.shape
        dims = ("chain", "draw")
        coords = [np.arange(nchains), np.arange(ndraws)]
        return xr.DataArray(arr, dims=dims, coords=coords)  # type:ignore
    if arr.ndim == 3:  # -> [nchains, nlf, ndraws]
        arr = arr.T
        nchains, nlf, ndraws = arr.shape
        dims = ("chain", "leapfrog", "draw")
        coords = [np.arange(nchains), np.arange(nlf), np.arange(ndraws)]
        return xr.DataArray(arr, dims=dims, coords=coords)  # type:ignore
    logger.error("arr.shape: %s", arr.shape)
    raise ValueError("Invalid shape encountered")
tplot_all(outdir=None, warmup=0.0, append=True, xkey=None, dataset=None, data=None, logfreq=None, plot_type=None, verbose=False, group_prefix='')
⚓︎
Create terminal plots for all metrics using plotext.
Counter metrics (iter, epoch, step, etc.) are skipped.
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def tplot_all(
    self,
    outdir: Optional[os.PathLike] = None,
    warmup: Optional[float] = 0.0,
    append: bool = True,
    xkey: Optional[str] = None,
    dataset: Optional[xr.Dataset] = None,
    data: Optional[dict] = None,
    logfreq: Optional[int] = None,
    plot_type: Optional[str] = None,
    verbose: bool = False,
    group_prefix: str = "",
):
    """Create terminal plots for all metrics using plotext.

    Counter metrics (iter, epoch, step, etc.) are skipped, as are the
    metric named by ``xkey`` and metrics named ``iter`` or ``draw``.

    Args:
        outdir: Output directory; defaults to the current working directory.
        warmup: Forwarded to :meth:`get_dataset` (when a dataset has to be
            built) and to the per-metric plotting helper.
        append: Not read by this method.
        xkey: Name of the metric serving as x-axis; it is not plotted
            itself.
        dataset: Pre-built dataset; built from ``data`` (or
            ``self.history``) when omitted.
        data: Metric arrays used to build the dataset.
        logfreq: Forwarded to the per-metric plotting helper.
        plot_type: Forwarded to the per-metric plotting helper.
        verbose: Forwarded to the per-metric plotting helper.
        group_prefix: If non-empty, prepended (``<prefix>/<metric>``) to
            the displayed metric names.
    """
    counter_tokens = ("iter", "epoch", "step", "batch", "idx", "bidx")
    counter_suffixes = tuple(f"_{token}" for token in counter_tokens)
    if dataset is None:
        source = self.history if data is None else data
        dataset = self.get_dataset(data=source, warmup=warmup)
    outdir_path = Path(os.getcwd()) if outdir is None else Path(outdir)
    groups = self._group_metric_variables(dataset)
    for metric_name, metric_vars in sorted(groups.items()):
        # Only the final path component decides whether this is a counter.
        last = metric_name.replace("\\", "/").split("/")[-1]
        if last in counter_tokens or last.endswith(counter_suffixes):
            continue
        # Skip the x-axis metric and index-like names. This tests the
        # *metric name*: testing `xkey in ["iter", "draw"]` would skip every
        # metric as soon as xkey was "iter" or "draw".
        if metric_name == xkey or metric_name in ("iter", "draw"):
            continue
        display_name = (
            f"{group_prefix}/{metric_name}"
            if group_prefix
            else metric_name
        )
        self._tplot_metric_group(
            display_name,
            metric_vars,
            warmup=warmup,
            outdir=outdir_path,
            plot_type=plot_type,
            verbose=verbose,
            logfreq=logfreq,
        )
update(metrics, precision=6, use_wandb=None, commit=True, summarize=True, step=None)
⚓︎
Update the history with a dictionary of metrics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics
|
dict
|
Dictionary of metrics to update the history with. |
required |
precision
|
int
|
Precision for summarizing the metrics. |
6
|
use_wandb
|
Optional[bool]
|
Whether to log the metrics to Weights & Biases. |
None
|
commit
|
Optional[bool]
|
Whether to commit the log to Weights & Biases. |
True
|
summarize
|
Optional[bool]
|
Whether to summarize the metrics. |
True
|
Source code in src/ezpz/history.py
@timeitlogit(rank=get_rank(), record=True, verbose=False, prefix="history")
def update(
    self,
    metrics: dict,
    precision: int = 6,
    use_wandb: Optional[bool] = None,
    commit: Optional[bool] = True,
    summarize: Optional[bool] = True,
    step: Optional[int] = None,
) -> str:
    """
    Record one dictionary of metrics and optionally return a summary line.

    The raw values are appended to the per-prefix group, distributed
    aggregates are computed, and the sanitized result is sent to the
    tracker and the JSONL log.

    Args:
        metrics (dict): Dictionary of metrics to update the history with.
        precision (int): Precision for summarizing the metrics.
        use_wandb (Optional[bool]): Deprecated. Any non-None value only
            triggers a ``DeprecationWarning``; use ``backends='wandb'`` in
            the ``History`` constructor instead.
        commit (Optional[bool]): Forwarded to the tracker's ``log`` call.
        summarize (Optional[bool]): Whether to summarize the metrics.
        step (Optional[int]): Forwarded to the tracker's ``log`` call.

    Returns:
        str: The compact summary line, or an empty string when
        ``summarize`` is falsy.
    """
    prefix, stripped = self._split_prefix(metrics)
    group = self._groups.setdefault(prefix, {})
    for name, value in stripped.items():
        group.setdefault(name, []).append(value)
    self._invalidate_flat_cache()

    aggregated_metrics = self._compute_distributed_metrics(metrics)
    # Aggregates are only recorded and logged on rank 0.
    record_aggregates = bool(aggregated_metrics) and self._rank == 0
    if record_aggregates:
        group_marker = f"{prefix}/"
        for agg_key, agg_val in aggregated_metrics.items():
            # Aggregated keys look like "train/loss/mean" — strip the
            # same prefix so they land in the same group as raw metrics.
            short_agg_key = (
                agg_key.removeprefix(group_marker) if prefix else agg_key
            )
            self._update(short_agg_key, agg_val, prefix=prefix)

    metrics_for_logging = dict(metrics)
    if record_aggregates:
        metrics_for_logging.update(aggregated_metrics)
    sanitized_metrics = self._sanitize_metrics(metrics_for_logging)
    if record_aggregates:
        summary_source = sanitized_metrics
    else:
        summary_source = self._sanitize_metrics(metrics)

    if use_wandb is not None:
        warnings.warn(
            "The 'use_wandb' parameter is deprecated. Use "
            "backends='wandb' in the History constructor instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    self._tracker.log(sanitized_metrics, step=step, commit=commit)
    self._write_jsonl_entry(sanitized_metrics, aggregated_metrics)

    if not summarize:
        return ""

    from ezpz.utils import (
        format_compact_summary,
        format_memory_summary,
    )

    # format_compact_summary handles all the noise reduction:
    #   - collapses base + */std into `key=value(±std)`
    #   - drops the */mean /min /max /avg companions
    #   - strips memory keys (handled separately below)
    #   - leaves counter-like keys (iter/step/epoch/...) bare
    base = format_compact_summary(summary_source, precision=precision)

    # The memory string is built from the RAW metrics dict, which still
    # carries the memory keys. It is empty when there are no memory keys,
    # e.g. on CPU/MPS.
    #
    # `prefix` came from `_split_prefix(metrics)` and is the bare
    # namespace WITHOUT a trailing slash (e.g. "train"), whereas
    # `format_memory_summary` expects either a full prefix WITH slash
    # ("train/") or None for auto-detection. Passing "train" directly
    # would make the lookup miss ("trainmem_alloc" — no key matches) and
    # silently drop the memory= token from the line, so pass None and let
    # the helper infer the prefix from the *mem_alloc keys it scans.
    memory_str = format_memory_summary(metrics, prefix=None)

    tokens = [base] if base else []
    if memory_str:
        tokens.append(f"memory={memory_str}")
    return " ".join(tokens)
StopWatch
⚓︎
Bases: ContextDecorator
A simple stopwatch context manager for measuring time taken by a block of code.
Source code in src/ezpz/history.py
class StopWatch(ContextDecorator):
    """
    Time a block of code, as a context manager or as a decorator.

    On exit the elapsed wall-clock time is optionally sent to the active
    Weights & Biases run and optionally written to the logger.
    """

    def __init__(
        self,
        msg: str,
        wbtag: Optional[str] = None,
        iter: Optional[int] = None,
        commit: Optional[bool] = False,
        prefix: str = "StopWatch/",
        log_output: bool = True,
    ) -> None:
        """
        Set up the stopwatch.

        Args:
            msg (str): Label used in the "... took N seconds" log message
                emitted when the stopwatch stops.
            wbtag (Optional[str]): Tag under which the timing is logged to
                Weights & Biases. Nothing is sent when this is None.
            iter (Optional[int]): Iteration number logged alongside the
                timing (only when ``wbtag`` is given).
            commit (Optional[bool]): ``commit`` flag passed to the run's
                ``log`` call.
            prefix (str): Key under which the payload dict is logged.
            log_output (bool): Whether to log the elapsed-time message.
        """
        self.msg = msg
        self.iter = iter
        self.prefix = prefix
        self.wbtag = wbtag
        self.log_output = log_output
        self.commit = commit

        # The payload stays empty without a tag; __exit__ treats an empty
        # payload as "do not log to wandb".
        payload = {}
        if wbtag is not None:
            payload[f"{wbtag}/dt"] = None
            if iter is not None:
                payload[f"{wbtag}/iter"] = iter
        self.data = payload

    def __enter__(self):
        """Record the start time and return the stopwatch."""
        self.time = time.perf_counter()
        return self

    def __exit__(self, t, v, traceback):
        """Compute the elapsed time, then log it to wandb and the logger."""
        elapsed = time.perf_counter() - self.time
        try:
            if self.data and wandb is not None:
                active_run = getattr(wandb, "run", None)
                if active_run is not None:
                    self.data[f"{self.wbtag}/dt"] = elapsed
                    active_run.log(
                        {self.prefix: self.data}, commit=self.commit
                    )
        except Exception as e:
            # Best effort: a wandb failure must not break the timed code.
            logger.error(f"Unable to log to wandb: {e}")
        if self.log_output:
            logger.info(f"{self.msg} took {elapsed:.3f} seconds")
__enter__()
⚓︎
__exit__(t, v, traceback)
⚓︎
Stop the stopwatch and log the time taken.
Source code in src/ezpz/history.py
def __exit__(self, t, v, traceback):
"""Stop the stopwatch and log the time taken."""
dt = time.perf_counter() - self.time
# if self.wbtag is not None and wandb.run is not None:
# if len(self.data) > 0 and wandb.run is not None:
try:
if (
len(self.data) > 0
and wandb is not None
and (wbrun := getattr(wandb, "run", None)) is not None
):
self.data |= {f"{self.wbtag}/dt": dt}
wbrun.log({self.prefix: self.data}, commit=self.commit)
except Exception as e:
logger.error(f"Unable to log to wandb: {e}")
if self.log_output:
logger.info(f"{self.msg} took {dt:.3f} seconds")
__init__(msg, wbtag=None, iter=None, commit=False, prefix='StopWatch/', log_output=True)
⚓︎
Initialize the StopWatch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
msg
|
str
|
Label used in the elapsed-time message logged when the stopwatch stops. |
required |
wbtag
|
Optional[str]
|
Optional tag for logging to Weights & Biases. |
None
|
iter
|
Optional[int]
|
Optional iteration number to log. |
None
|
commit
|
Optional[bool]
|
Whether to commit the log to Weights & Biases. |
False
|
prefix
|
str
|
Prefix for the log data. |
'StopWatch/'
|
log_output
|
bool
|
Whether to log the output message. |
True
|
Source code in src/ezpz/history.py
def __init__(
self,
msg: str,
wbtag: Optional[str] = None,
iter: Optional[int] = None,
commit: Optional[bool] = False,
prefix: str = "StopWatch/",
log_output: bool = True,
) -> None:
"""
Initialize the StopWatch.
Args:
msg (str): Message to log when the stopwatch is started.
wbtag (Optional[str]): Optional tag for logging to Weights & Biases.
iter (Optional[int]): Optional iteration number to log.
commit (Optional[bool]): Whether to commit the log to Weights & Biases.
prefix (str): Prefix for the log data.
log_output (bool): Whether to log the output message.
"""
self.msg = msg
self.data = {}
self.iter = iter if iter is not None else None
self.prefix = prefix
self.wbtag = wbtag if wbtag is not None else None
self.log_output = log_output
self.commit = commit
if wbtag is not None:
self.data = {
f"{self.wbtag}/dt": None,
}
if iter is not None:
self.data |= {
f"{self.wbtag}/iter": self.iter,
}