diff --git a/apex/normalization/instance_norm.py b/apex/normalization/instance_norm.py index 48870b33e..7f9ee1fd4 100644 --- a/apex/normalization/instance_norm.py +++ b/apex/normalization/instance_norm.py @@ -91,7 +91,7 @@ def __init__( factory_kwargs = {'device': device, 'dtype': dtype} super(_InstanceNormNVFuser, self).__init__( num_features, eps, momentum, affine, track_running_stats, **factory_kwargs) - self.dummy = torch.empty([], device='cuda') + self.dummy = torch.empty([], device=device) def _check_input_dim(self, input): raise NotImplementedError