Skip to content

Commit

Permalink
fix device for dummy tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
eqy authored and crcrpar committed Apr 19, 2023
1 parent d94d55a commit 62a2ff9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion apex/normalization/instance_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
factory_kwargs = {'device': device, 'dtype': dtype}
super(_InstanceNormNVFuser, self).__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
self.dummy = torch.empty([], device='cuda')
self.dummy = torch.empty([], device=device)

def _check_input_dim(self, input):
raise NotImplementedError
Expand Down

0 comments on commit 62a2ff9

Please sign in to comment.