Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cast to >=float32 tensor when passing scalar to self.log #19046

Merged
merged 7 commits into from
Nov 24, 2023
Merged
12 changes: 11 additions & 1 deletion src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,8 +621,18 @@ def __check_not_nested(value: dict, name: str) -> None:
def __check_allowed(v: Any, name: str, value: Any) -> None:
raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

@staticmethod
MF-FOOM marked this conversation as resolved.
Show resolved Hide resolved
def __get_default_high_precision_dtype() -> torch.dtype:
"""The default dtype for new tensors, but no lower than float32."""
dtype = torch.get_default_dtype()
return dtype if dtype in (torch.float32, torch.float64) else torch.float32

def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value, device=self.device)
value = (
value.clone().detach()
if isinstance(value, Tensor)
else torch.tensor(value, device=self.device, dtype=self.__get_default_high_precision_dtype())
)
if not torch.numel(value) == 1:
raise ValueError(
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
Expand Down
Loading