diff --git a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py index 7eb286b496fc..4911ecb571af 100755 --- a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py +++ b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py @@ -48,9 +48,6 @@ class ModelArguments: freeze_feature_extractor: Optional[bool] = field( default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."} ) - gradient_checkpointing: Optional[bool] = field( - default=False, metadata={"help": "Whether to freeze the feature extractor layers of the model."} - ) verbose_logging: Optional[bool] = field( default=False, metadata={"help": "Whether to log verbose messages or not."}, @@ -356,7 +353,6 @@ def normalize(batch): config = Wav2Vec2Config.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, - gradient_checkpointing=model_args.gradient_checkpointing, ) if not config.do_stable_layer_norm or config.feat_extract_norm != "layer": @@ -366,6 +362,10 @@ def normalize(batch): model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) + # Activate gradient checkpointing if needed + if training_args.gradient_checkpointing: + model.gradient_checkpointing_enable() + data_collator = FlaxDataCollatorForWav2Vec2Pretraining( model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of )