diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py index 4b3104b6bdea..084587ba2a3b 100755 --- a/deepspeed/ops/transformer/transformer.py +++ b/deepspeed/ops/transformer/transformer.py @@ -218,7 +218,7 @@ def forward(ctx, output_b, norm_w, norm_b, - config.training, + config.training and config.is_grad_enabled, config.pre_layer_norm, config.attn_dropout_checkpoint, config.normalize_invertible,