diff --git a/fairseq/modules/native_multihead_attention.py b/fairseq/modules/native_multihead_attention.py index fac4c28873..d6b8607588 100644 --- a/fairseq/modules/native_multihead_attention.py +++ b/fairseq/modules/native_multihead_attention.py @@ -111,17 +111,6 @@ def __init__( self.init_incremental_state() - def is_first_step(self, saved_state): - if saved_state is None: - return True - elif ( - saved_state["prev_key"] is not None - and saved_state["prev_key"].shape[2] == 1 - ): - return True - else: - return False - def forward( self, query: Tensor,