diff --git a/captum/attr/_core/integrated_gradients.py b/captum/attr/_core/integrated_gradients.py
index 3f00b612d1..7f9b29085a 100644
--- a/captum/attr/_core/integrated_gradients.py
+++ b/captum/attr/_core/integrated_gradients.py
@@ -359,7 +359,7 @@ def _attribute(
         # calling contiguous to avoid `memory whole` problems
         scaled_grads = [
             grad.contiguous().view(n_steps, -1)
-            * torch.tensor(step_sizes).view(n_steps, 1).to(grad.device)
+            * torch.tensor(step_sizes).float().view(n_steps, 1).to(grad.device)
             for grad in grads
         ]
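
The change is a dtype cast: if the `step_sizes` sequence holds float64 values (e.g. produced by a numpy-based quadrature rule), `torch.tensor(step_sizes)` infers `torch.float64`, while the gradients are typically `torch.float32`, so the broadcasted multiply either raises a dtype-mismatch error or silently promotes. Below is a minimal repro sketch of that mismatch, not captum's code; the numpy-sourced `step_sizes` and the tensor shapes are assumptions for illustration.

```python
import numpy as np
import torch

n_steps = 4
# Assumption for illustration: step sizes arrive as np.float64 scalars,
# e.g. from a numpy-based integration rule.
step_sizes = [np.float64(1.0 / n_steps)] * n_steps
grads = [torch.randn(n_steps, 3)]  # gradients are float32 by default

# Before the patch: torch.tensor infers float64 from the step sizes.
weights = torch.tensor(step_sizes).view(n_steps, 1)
print(weights.dtype)  # torch.float64 -- does not match the float32 grads

# After the patch: the explicit .float() cast keeps the multiply in float32.
scaled_grads = [
    grad.contiguous().view(n_steps, -1)
    * torch.tensor(step_sizes).float().view(n_steps, 1).to(grad.device)
    for grad in grads
]
print(scaled_grads[0].dtype)  # torch.float32
```

Note that `.float()` hard-codes float32, so a double-precision model's step weights would be downcast; casting with `.to(grad.dtype)` instead would be a dtype-agnostic alternative, though that is not what this patch does.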