diff --git a/detectron2/utils/events.py b/detectron2/utils/events.py index c4f9dadfd8..f6f9d0e7a2 100644 --- a/detectron2/utils/events.py +++ b/detectron2/utils/events.py @@ -10,6 +10,7 @@ from typing import Optional import torch from fvcore.common.history_buffer import HistoryBuffer +import wandb from detectron2.utils.file_io import PathManager @@ -191,6 +192,62 @@ def close(self): if "_writer" in self.__dict__: self._writer.close() +class WandbWriter(EventWriter): + def __init__(self, project_name, run_name=None, window_size=20, **kwargs): + """ + Args: + project_name (str): The name of the W&B project. + run_name (str): The name of the W&B run. + window_size (int): The window size for smoothing metrics. + kwargs: Additional arguments for wandb.init(). + """ + self._window_size = window_size + self._last_write = -1 + wandb.init(project=project_name, name=run_name, **kwargs) + + def write(self): + storage = get_event_storage() + new_last_write = self._last_write + metrics_dict = storage.latest_with_smoothing_hint(self._window_size).items() + wandb_metrics = {} + new_last_write = self._last_write + for k, (v, iter) in metrics_dict: + if iter > self._last_write: + wandb_metrics[k] = v + new_last_write = max(new_last_write, iter) + self._last_write = new_last_write + + if len(storage._vis_data) >= 1: + # Create a list to store all images for this step + images_dict = {} + + for img_name, img, step_num in storage._vis_data: + # Transpose from C,H,W to H,W,C + img = img.transpose(1, 2, 0) + # Add image to dictionary + images_dict[img_name] = wandb.Image(img) + + # Log both metrics and all images for this step + log_dict = { + **wandb_metrics, # Unpack all metrics + **images_dict # Unpack all images + } + wandb.log(log_dict, step=iter) + + # Storage stores all image data and rely on this writer to clear them. + # As a result it assumes only one writer will use its image data. + # An alternative design is to let storage store limited recent + # data (e.g. only the most recent image) that all writers can access. + # In that case a writer may not see all image data if its period is long. + storage.clear_images() + else: + wandb.log(wandb_metrics, step=new_last_write) + + self._last_write = new_last_write + + def close(self): + wandb.finish() + class CommonMetricPrinter(EventWriter): """