diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 35f8bbd3473a..303267f0494d 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -417,7 +417,7 @@ def _link_hp_params(self, hp_frag_address.numel) for key, value in self.optimizer.state[flat_hp_partition].items() - if torch.is_tensor(value) + if torch.is_tensor(value) and value.dim() > 0 } lp_frag_address = fragment_address(start=fragment_start - lp_start,