Commit

set step in tensorboard
bdvllrs committed Jan 31, 2025
1 parent c542be0 commit dd4afc3
Showing 1 changed file with 32 additions and 13 deletions.
shimmer_ssd/logging.py (45 changes: 32 additions & 13 deletions)
@@ -33,13 +33,16 @@


 def log_image(
-    logger: Logger, key: str, image: torch.Tensor | Image.Image, step: int | None = None
+    logger: Logger,
+    key: str,
+    image: torch.Tensor | Image.Image,
+    tensorboard_step: int | None = None,
 ):
     if isinstance(logger, WandbLogger):
-        logger.log_image(key, [image], step)
+        logger.log_image(key, [image])
     elif isinstance(logger, TensorBoardLogger):
         torch_image = to_tensor(image) if isinstance(image, Image.Image) else image
-        logger.experiment.add_image(key, torch_image, step)
+        logger.experiment.add_image(key, torch_image, tensorboard_step)
     else:
         LOGGER.warning(
             "[Sample Logger] Only logging to tensorboard or wandb is supported"
@@ -52,14 +55,14 @@ def log_text(
     key: str,
     columns: list[str],
     data: list[list[str]],
-    step: int | None = None,
+    tensorboard_step: int | None = None,
 ):
     if isinstance(logger, WandbLogger):
-        logger.log_text(key, columns, data, step=step)
+        logger.log_text(key, columns, data)
     elif isinstance(logger, TensorBoardLogger):
         text = ", ".join(columns) + "\n"
         text += "\n".join([", ".join(d) for d in data])
-        logger.experiment.add_image(key, text, step)
+        logger.experiment.add_text(key, text, tensorboard_step)
     else:
         LOGGER.warning(
             "[Sample Logger] Only logging to tensorboard or wandb is supported"
@@ -80,6 +83,11 @@ def __init__(
         self.every_n_epochs = every_n_epochs
         self.log_key = log_key
         self.mode = mode
+        self._global_step = 0
+
+    def get_step(self) -> int:
+        self._global_step += 1
+        return self._global_step - 1

     def to(self, samples: _T, device: torch.device) -> _T:
         raise NotImplementedError
@@ -281,7 +289,7 @@ def log_samples(
         image = attribute_image_grid(
             samples, image_size=self.image_size, ncols=self.ncols
         )
-        log_image(logger, f"{self.log_key}_{mode}", image)
+        log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())


 class LogTextCallback(LogSamplesCallback[Mapping[str, torch.Tensor]]):
@@ -350,7 +358,7 @@ def log_samples(
             samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True
         )
         text = [[t.replace("<pad>", "")] for t in text]
-        log_text(logger, f"{self.log_key}_{mode}_str", ["text"], text)
+        log_text(logger, f"{self.log_key}_{mode}_str", ["text"], text, self.get_step())


 class LogVisualCallback(LogSamplesCallback[torch.Tensor]):
@@ -456,15 +464,21 @@ def log_samples(
                 domain["tokens"].detach().cpu().tolist(), skip_special_tokens=True
             )
             text = [[t.replace("<pad>", "")] for t in text]
-            log_text(logger, f"{self.log_key}_{mode}_str", ["text"], text)
+            log_text(
+                logger,
+                f"{self.log_key}_{mode}_str",
+                ["text"],
+                text,
+                self.get_step(),
+            )
         elif domain_name == "attr":
             assert isinstance(domain, list)
             image = attribute_image_grid(
                 domain,
                 image_size=self.image_size,
                 ncols=self.ncols,
             )
-            log_image(logger, f"{self.log_key}_{mode}", image)
+            log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())


 def batch_to_device(
@@ -522,6 +536,11 @@ def __init__(
         self.tokenizer = None
         if vocab is not None and merges is not None:
             self.tokenizer = ByteLevelBPETokenizer(vocab, merges)
+        self._global_step = 0
+
+    def get_step(self):
+        self._global_step += 1
+        return self._global_step - 1

     def to(
         self,
@@ -695,7 +714,7 @@ def log_visual_samples(
         mode: str,
     ) -> None:
         images = make_grid(samples, nrow=self.ncols, pad_value=1)
-        log_image(logger, f"{self.log_key}/{mode}", images)
+        log_image(logger, f"{self.log_key}/{mode}", images, self.get_step())

     def log_attribute_samples(
         self,
@@ -708,7 +727,7 @@ def log_attribute_samples(
             image_size=self.image_size,
             ncols=self.ncols,
         )
-        log_image(logger, f"{self.log_key}/{mode}", image)
+        log_image(logger, f"{self.log_key}/{mode}", image, self.get_step())

     def log_text_samples(
         self,
@@ -721,4 +740,4 @@ def log_text_samples(
             samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True
         )
         text = [[t.replace("<pad>", "")] for t in text]
-        log_text(logger, f"{self.log_key}/{mode}", ["text"], text)
+        log_text(logger, f"{self.log_key}/{mode}", ["text"], text, self.get_step())
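
For context, a minimal usage sketch of the updated helpers (not part of the commit): the new get_step() counters return 0, 1, 2, ... on successive calls, so each sample log lands on its own TensorBoard step instead of overwriting the previous one, while the WandbLogger path now omits the explicit step, presumably letting W&B track it internally. The logger setup, keys, and data below are illustrative assumptions; the TensorBoardLogger import path may be pytorch_lightning.loggers or lightning.pytorch.loggers depending on the installed Lightning version.

# Usage sketch only: the logger setup, keys, and data here are assumed, not from the repo.
import torch
from lightning.pytorch.loggers import TensorBoardLogger

from shimmer_ssd.logging import log_image, log_text

logger = TensorBoardLogger("tb_logs", name="samples")  # hypothetical save dir and run name

step = 0  # mirrors the _global_step counter the callbacks now keep
for epoch in range(3):
    image = torch.rand(3, 32, 32)  # fake CHW image tensor
    log_image(logger, "val_samples", image, tensorboard_step=step)
    log_text(
        logger,
        "val_samples_str",
        ["text"],
        [["a red square"], ["a blue circle"]],
        tensorboard_step=step,
    )
    step += 1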
