
Fix broken gradient logging and add lr logging to TensorBoard (#1158)
Summary:
Pull Request resolved: #1158

This should help to monitor the learning rate (lr) when using warmup/annealing, etc.

Reviewed By: geof90

Differential Revision: D18624642

fbshipit-source-id: 53f3bbf73c285fb88cd81f260771e31c0083e4c9
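
Not part of the commit: a minimal sketch of the behavior the summary describes — reading the current learning rate out of each entry in optimizer.param_groups and writing it to TensorBoard once per epoch, so a warmup/annealing schedule shows up as a curve. It assumes plain PyTorch with torch.utils.tensorboard; the model, scheduler, and loop below are illustrative only, and just the scalar tag layout mirrors the one added in channel.py.

import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Linear(4, 2)
# Two param groups with different base learning rates.
optimizer = torch.optim.SGD(
    [{"params": model.weight}, {"params": model.bias, "lr": 0.01}], lr=0.1
)
# Linear warmup over the first 5 epochs, then constant.
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / 5))
writer = SummaryWriter()

for epoch in range(10):
    # ... forward/backward/optimizer.step() for the epoch would go here ...
    for idx, param_group in enumerate(optimizer.param_groups):
        # Same tag layout as the channel.py change: one scalar per param group.
        writer.add_scalar(f"optimizer.lr.param_group.{idx}", param_group["lr"], epoch)
    scheduler.step()
writer.close()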
arbabu123 authored and facebook-github-bot committed Nov 21, 2019
1 parent d63edff commit 3a0ec8a
Showing 3 changed files with 20 additions and 8 deletions.
15 changes: 9 additions & 6 deletions pytext/metric_reporters/channel.py
@@ -170,6 +170,7 @@ def report(
         context,
         meta,
         model,
+        optimizer,
         *args,
     ):
         """
@@ -213,17 +214,19 @@
         self.add_scalars(prefix, metrics, epoch)

         if stage == Stage.TRAIN:
+            if optimizer is not None:
+                for idx, param_group in enumerate(optimizer.param_groups):
+                    self.summary_writer.add_scalar(
+                        f"optimizer.lr.param_group.{idx}", param_group["lr"], epoch
+                    )
             for key, val in model.named_parameters():
                 if val is not None and len(val) > 0 and not (val == 0).all():
                     limit = 9.9e19
+                    grad = val.grad
                     val = torch.clamp(val.float(), -limit, limit)
                     self.summary_writer.add_histogram(key, val, epoch)
-                    if (
-                        val.grad is not None
-                        and len(val.grad) > 0
-                        and not (val.grad == 0).all()
-                    ):
-                        grad = torch.clamp(val.grad.float(), -limit, limit)
+                    if grad is not None and len(grad) > 0 and not (grad == 0).all():
+                        grad = torch.clamp(grad.float(), -limit, limit)
                         self.summary_writer.add_histogram(
                             key + "_gradients", grad, epoch
                         )
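
This hunk is the fix named in the commit title: the old code reassigned val to torch.clamp(val.float(), ...) before checking val.grad, but torch.clamp returns a new tensor that carries no .grad, so the gradient-histogram branch could never fire. A small sketch of that behavior, not from the repo, assuming plain PyTorch:

import torch

# A leaf parameter with an accumulated gradient.
param = torch.nn.Parameter(torch.randn(3))
param.sum().backward()
print(param.grad)                                   # tensor([1., 1., 1.])

# clamp() produces a fresh (non-leaf) tensor: its .grad is None,
# which is why the pre-fix check on the clamped value never passed.
clamped = torch.clamp(param.float(), -9.9e19, 9.9e19)
print(clamped.grad)                                 # None (newer PyTorch also warns here)

# The fix: capture the gradient from the original parameter first.
grad = param.grad
clamped = torch.clamp(param.float(), -9.9e19, 9.9e19)
print(grad)                                         # tensor([1., 1., 1.]) -- still available for the histogram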
5 changes: 4 additions & 1 deletion pytext/metric_reporters/metric_reporter.py
@@ -206,7 +206,9 @@ def get_meta(self):
         """
         return {}

-    def report_metric(self, model, stage, epoch, reset=True, print_to_channels=True):
+    def report_metric(
+        self, model, stage, epoch, reset=True, print_to_channels=True, optimizer=None
+    ):
         """
         Calculate metrics and average loss, report all statistic data to channels
@@ -241,6 +243,7 @@ def report_metric(self, model, stage, epoch, reset=True, print_to_channels=True)
                     self.all_context,
                     self.get_meta(),
                     model,
+                    optimizer,
                 )

         if reset:
8 changes: 7 additions & 1 deletion pytext/trainers/trainer.py
@@ -509,7 +509,13 @@ def run_epoch(
         if report_metric:
             with timing.time("report metrics"):
                 metrics = metric_reporter.report_metric(
-                    model, state.stage, state.epoch, print_to_channels=(state.rank == 0)
+                    model,
+                    state.stage,
+                    state.epoch,
+                    print_to_channels=(state.rank == 0),
+                    optimizer=getattr(
+                        state, "optimizer", None
+                    ),  # optimizer is not present during test
                 )
         else:
             metric_reporter._reset()
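
One more note, not from the commit: the trainer reads the optimizer with getattr(state, "optimizer", None) rather than state.optimizer because, as the inline comment says, the attribute is absent during test-only runs. A tiny sketch of that pattern, using a hypothetical stand-in for the training-state object:

# "State" is a hypothetical stand-in, not pytext's actual training-state class.
class State:
    def __init__(self, optimizer=None):
        if optimizer is not None:
            self.optimizer = optimizer  # only set while training

train_state = State(optimizer=object())
test_state = State()

print(getattr(train_state, "optimizer", None) is not None)  # True
print(getattr(test_state, "optimizer", None))                # None, no AttributeError
# test_state.optimizer would raise AttributeError instead.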
