Merged
Changes from 2 commits
12 changes: 6 additions & 6 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -1029,6 +1029,12 @@ def forward(
         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(

Collaborator

Suggested change
logger.warning(
logger.warning(

This should be `warn_once` here (or `warning_once`); I can't remember the syntax.

@gante gante Mar 8, 2023

Contributor

warning_once :) (@soma2000-lang your fork is probably a bit stale, we moved from logger.warning to logger.warning_once on this statement ~2 weeks ago)

Contributor Author

@gante done!

+                    "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..."
gante marked this conversation as resolved.
+                )
+                use_cache = False
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -1053,12 +1059,6 @@ def forward(
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache ="
-                        " False`..."
-                    )
-                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
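
For readers following the thread, here is a minimal sketch of the two ideas discussed: hoisting the `use_cache` check above the layer loop so it runs once per forward call, and emitting the warning only once. It is not the transformers implementation. `warn_once` and `run_layers` are hypothetical names, and deduplicating with `functools.lru_cache` is an assumption about one simple way to get warn-once behavior.

```python
import functools
import logging

logger = logging.getLogger("whisper_sketch")


@functools.lru_cache(maxsize=None)
def warn_once(message: str) -> None:
    """Hypothetical helper: log `message` only the first time it is seen."""
    logger.warning(message)


def run_layers(num_layers: int, use_cache: bool, gradient_checkpointing: bool, training: bool) -> bool:
    """Toy stand-in for the decoder forward; returns the effective use_cache."""
    # Check hoisted above the loop, mirroring the diff: it runs once per call
    # instead of once per decoder layer.
    if gradient_checkpointing and training:
        if use_cache:
            warn_once(
                "`use_cache = True` is incompatible with gradient checkpointing. "
                "Setting `use_cache = False`..."
            )
            use_cache = False

    for _ in range(num_layers):
        pass  # each layer would see the already-corrected use_cache

    return use_cache
```

Calling `run_layers(4, True, True, True)` repeatedly returns `False` each time, while the warning is logged only on the first call because the cached helper is not re-executed for an identical message.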