diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 3b096c36db8e..0ea405ed6e7a 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -640,6 +640,11 @@ def forward(
         hidden_states, present_key_value_state = self_attention_outputs[:2]
         attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

+        # clamp inf values to enable fp16 training
+        if torch.isinf(hidden_states).any():
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
         do_cross_attention = self.is_decoder and encoder_hidden_states is not None
         if do_cross_attention:
             # the actual query length is unknown for cross attention
@@ -661,6 +666,10 @@ def forward(
                 output_attentions=output_attentions,
             )
             hidden_states = cross_attention_outputs[0]
+            if torch.isinf(hidden_states).any():
+                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
             # Combine self attn and cross attn key value states
             if present_key_value_state is not None:
                 present_key_value_state = present_key_value_state + cross_attention_outputs[1]
@@ -670,6 +679,9 @@ def forward(

         # Apply Feed Forward layer
         hidden_states = self.layer[-1](hidden_states)
+        if torch.isinf(hidden_states).any():
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
         outputs = (hidden_states,)

         outputs = outputs + (present_key_value_state,) + attention_outputs
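
For context, the patch guards against fp16 overflow: T5 activations can exceed the float16 maximum (~65504), become inf, and then propagate NaNs through later layers. Below is a minimal standalone sketch of the same clamping pattern on a plain tensor; the helper name clamp_inf_for_fp16 is illustrative only and not part of the patch.

import torch

def clamp_inf_for_fp16(hidden_states: torch.Tensor) -> torch.Tensor:
    # Replace any inf produced by fp16 overflow with a large finite value,
    # slightly below the dtype maximum, so downstream ops stay finite.
    if torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states

# Example: float16 maxes out at 65504, so the clamp lands at +/-64504.
x = torch.tensor([1.0, float("inf"), -float("inf")], dtype=torch.float16)
print(clamp_inf_for_fp16(x))  # tensor([ 1.0000e+00,  6.4504e+04, -6.4504e+04], dtype=torch.float16)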