From 324cee00dca82ac6a88266371383b538f3f44233 Mon Sep 17 00:00:00 2001 From: eqy Date: Fri, 6 Jan 2023 11:54:07 -0800 Subject: [PATCH] Update instance_norm.py --- apex/normalization/instance_norm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/apex/normalization/instance_norm.py b/apex/normalization/instance_norm.py index 7f9ee1fd4..ce76ea6dd 100644 --- a/apex/normalization/instance_norm.py +++ b/apex/normalization/instance_norm.py @@ -129,6 +129,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def forward(self, input: Tensor) -> Tensor: assert input.is_cuda, "NVFuser InstanceNorm is CUDA only" self._check_input_dim(input) + if self.dummy.device != input.device: + self.dummy = torch.empty([], device=input.device) if self.running_mean is not None: out = InstanceNormNVFuserFunction.apply( input, self.weight if self.weight is not None else self.dummy,