diff --git a/apex/normalization/instance_norm.py b/apex/normalization/instance_norm.py index 7f9ee1fd4..ce76ea6dd 100644 --- a/apex/normalization/instance_norm.py +++ b/apex/normalization/instance_norm.py @@ -129,6 +129,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def forward(self, input: Tensor) -> Tensor: assert input.is_cuda, "NVFuser InstanceNorm is CUDA only" self._check_input_dim(input) + if self.dummy.device != input.device: + self.dummy = torch.empty([], device=input.device) if self.running_mean is not None: out = InstanceNormNVFuserFunction.apply( input, self.weight if self.weight is not None else self.dummy,