diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index ec467ee43..1d7aed1b9 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -301,6 +301,7 @@ def __init__( ): super().__init__() + self.use_sdpa = use_sdpa self.tie_word_embedding = tie_word_embedding self.left_decoder = TransformerDecoder( vocab_size,