Skip to content

Commit

Permalink
Add .detach() in scalar_to_float() (#551)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #551

Calling from an integration test with full training loop that used in mem logger was throwing this exception:
```
RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
```

This diff adds the `.detach()`

Reviewed By: JKSenthil

Differential Revision: D49600572

fbshipit-source-id: 996d5216612fe741a078c8b5fea94b76e7a5c84e
  • Loading branch information
gunchu authored and facebook-github-bot committed Sep 25, 2023
1 parent 0e4c97d commit cd2c4f8
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchtnt/utils/loggers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def scalar_to_float(scalar: Scalar) -> float:
f"Scalar tensor must contain a single item, {numel} given."
)

return float(scalar.cpu().numpy().item())
return float(scalar.cpu().detach().numpy().item())
elif isinstance(scalar, ndarray):
numel = scalar.size
if numel != 1:
Expand Down

0 comments on commit cd2c4f8

Please sign in to comment.