diff --git a/rl_games/algos_torch/running_mean_std.py b/rl_games/algos_torch/running_mean_std.py index 9bedc544..b6996fd1 100644 --- a/rl_games/algos_torch/running_mean_std.py +++ b/rl_games/algos_torch/running_mean_std.py @@ -45,7 +45,7 @@ def _update_mean_var_count_from_moments(self, mean, var, count, batch_mean, batc def forward(self, input, denorm=False, mask=None): if self.training: if mask is not None: - mean, var = torch_ext.get_mean_std_with_masks(input, mask) + mean, var = torch_ext.get_mean_var_with_masks(input, mask) else: mean = input.mean(self.axis) # along channel axis var = input.var(self.axis)