Skip to content

Commit

Permalink
Update instance_norm.py
Browse files Browse the repository at this point in the history
  • Loading branch information
eqy authored and crcrpar committed Apr 19, 2023
1 parent 988cafe commit 324cee0
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions apex/normalization/instance_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
def forward(self, input: Tensor) -> Tensor:
assert input.is_cuda, "NVFuser InstanceNorm is CUDA only"
self._check_input_dim(input)
if self.dummy.device != input.device:
self.dummy = torch.empty([], device=input.device)
if self.running_mean is not None:
out = InstanceNormNVFuserFunction.apply(
input, self.weight if self.weight is not None else self.dummy,
Expand Down

0 comments on commit 324cee0

Please sign in to comment.