@@ -378,39 +378,35 @@ def _compute_self_attention_mask(
378378 decoder_sequence ,
379379 decoder_padding_mask ,
380380 decoder_attention_mask ,
381- use_causal_mask ,
382381 self_attention_cache ,
383382 self_attention_cache_update_index ,
384383 ):
385384 decoder_mask = merge_padding_and_attention_mask (
386385 decoder_sequence , decoder_padding_mask , decoder_attention_mask
387386 )
388- if use_causal_mask :
389- batch_size = ops .shape (decoder_sequence )[0 ]
390- input_length = output_length = ops .shape (decoder_sequence )[1 ]
391- # We need to handle a rectangular causal mask when doing cached
392- # decoding. For generative inference, `decoder_sequence` will
393- # generally be length 1, and `cache` will be the full generation
394- # length.
395- if self_attention_cache is not None :
396- input_length = ops .shape (self_attention_cache )[2 ]
397-
398- causal_mask = compute_causal_mask (
399- batch_size ,
400- input_length ,
401- output_length ,
402- (
403- 0
404- if self_attention_cache_update_index is None
405- else self_attention_cache_update_index
406- ),
407- )
408- return (
409- ops .minimum (decoder_mask , causal_mask )
410- if decoder_mask is not None
411- else causal_mask
412- )
413- return decoder_mask
387+ batch_size = ops .shape (decoder_sequence )[0 ]
388+ input_length = output_length = ops .shape (decoder_sequence )[1 ]
389+ # We need to handle a rectangular causal mask when doing cached
390+ # decoding. For generative inference, `decoder_sequence` will
391+ # generally be length 1, and `cache` will be the full generation length.
392+ if self_attention_cache is not None :
393+ input_length = ops .shape (self_attention_cache )[2 ]
394+
395+ cache_update_index = (
396+ 0
397+ if self_attention_cache_update_index is None
398+ else self_attention_cache_update_index
399+ )
400+
401+ causal_mask = compute_causal_mask (
402+ batch_size , input_length , output_length , cache_update_index
403+ )
404+
405+ return (
406+ ops .minimum (decoder_mask , causal_mask )
407+ if decoder_mask is not None
408+ else causal_mask
409+ )
414410
415411 def build (self , input_shape ):
416412 """
0 commit comments