Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/transformers/models/swin2sr/modeling_swin2sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[
relative_coords_table = (
torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
)
relative_coords_table = relative_coords_table.to(
next(self.continuous_position_bias_mlp.parameters()).dtype
) # set to same dtype as mlp weight
self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

# get pair-wise relative position index for each token inside the window
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/swinv2/modeling_swinv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,9 @@ def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[
relative_coords_table = (
torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
)
relative_coords_table = relative_coords_table.to(
next(self.continuous_position_bias_mlp.parameters()).dtype
) # set to same dtype as mlp weight
self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

# get pair-wise relative position index for each token inside the window
Expand Down