diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index e2235fc69055..43beb69f10f2 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -566,8 +566,12 @@ def forward( attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) do_cross_attention = self.is_decoder and encoder_hidden_states is not None @@ -593,8 +597,12 @@ def forward( hidden_states = cross_attention_outputs[0] # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Combine self attn and cross attn key value states @@ -608,8 +616,12 @@ def forward( hidden_states = self.layer[-1](hidden_states) # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 19cb83dac352..c056ef73cc1f 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -703,8 +703,12 @@ def forward( attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) do_cross_attention = self.is_decoder and encoder_hidden_states is not None @@ -730,8 +734,12 @@ def forward( hidden_states = cross_attention_outputs[0] # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Combine self attn and cross attn key value states @@ -745,8 +753,12 @@ def forward( hidden_states = self.layer[-1](hidden_states) # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,)