2 changes: 1 addition & 1 deletion README_ko.md
@@ -414,4 +414,4 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
url = "https://www.aclweb.org/anthology/2020.emnlp-demos.6",
pages = "38--45"
}
```
```
src/transformers/models/deformable_detr/configuration_deformable_detr.py
@@ -116,6 +116,8 @@ class DeformableDetrConfig(PretrainedConfig):
based on the predictions from the previous layer.
focal_alpha (`float`, *optional*, defaults to 0.25):
Alpha parameter in the focal loss.
use_custom_kernel (`bool`, *optional*, defaults to `False`):
Whether to use a custom CUDA kernel to speed up inference and training on GPU.
Comment on lines +119 to +120
Collaborator

As said in the thread, let's not use a config parameter for this, but just try to use the fast version and fall back to the slow one if the fast one fails. If a user sets it to True and then pushes their model, anyone without a GPU, or with a PyTorch install that has no CUDA support, will get a failure when using this model.

Config parameters are usually a bad fit for flags that depend on the hardware or setup.
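A minimal sketch of that fallback pattern, reusing the helpers already defined in modeling_deformable_detr.py (the wrapper function name here is just for illustration, not part of the PR):

```python
# Sketch only: try the fused CUDA kernel first and quietly fall back to the
# pure-PyTorch implementation when it is unavailable (CPU-only machine,
# PyTorch built without CUDA, or the kernel failed to compile).
# MultiScaleDeformableAttentionFunction and ms_deform_attn_core_pytorch are
# the existing helpers in modeling_deformable_detr.py; no config flag involved.
def multi_scale_deformable_attention(
    value, spatial_shapes, level_start_index, sampling_locations, attention_weights, im2col_step
):
    try:
        # Fast path: custom CUDA kernel loaded via load_cuda_kernels().
        return MultiScaleDeformableAttentionFunction.apply(
            value, spatial_shapes, level_start_index, sampling_locations, attention_weights, im2col_step
        )
    except Exception:
        # Slow path: works everywhere.
        return ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
```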

Contributor Author

OK, I don't think any changes to the main branch are needed then. Should I go ahead and close the PR?

Collaborator

If it's all good on main, then yes :-)


Examples:

@@ -177,6 +179,7 @@ def __init__(
giou_loss_coefficient=2,
eos_coefficient=0.1,
focal_alpha=0.25,
use_custom_kernel=False,
**kwargs
):
self.num_queries = num_queries
@@ -220,6 +223,7 @@ def __init__(
self.giou_loss_coefficient = giou_loss_coefficient
self.eos_coefficient = eos_coefficient
self.focal_alpha = focal_alpha
self.use_custom_kernel = use_custom_kernel
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

@property
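For illustration, how the flag added in this diff would be used from the config side — note that `use_custom_kernel` exists only on this PR's branch, not on main, where the fast kernel is simply tried first:

```python
from transformers import DeformableDetrConfig, DeformableDetrModel

# Assumes this PR's branch: on main there is no `use_custom_kernel` field and
# the custom CUDA kernel is attempted automatically when available.
config = DeformableDetrConfig(use_custom_kernel=True)  # opt in to the fused CUDA kernel
model = DeformableDetrModel(config)  # each deformable attention layer reads the flag
```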
70 changes: 38 additions & 32 deletions src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -47,6 +47,15 @@
from .load_custom import load_cuda_kernels


if is_vision_available():
from transformers.image_transforms import center_to_corners_format

if is_scipy_available():
from scipy.optimize import linear_sum_assignment

if is_timm_available():
from timm import create_model

logger = logging.get_logger(__name__)

# Move this to not compile only when importing, this needs to happen later, like in __init__.
@@ -60,8 +69,14 @@
else:
MultiScaleDeformableAttention = None

if is_vision_available():
from transformers.image_transforms import center_to_corners_format

_CONFIG_FOR_DOC = "DeformableDetrConfig"
_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"

DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
"sensetime/deformable-detr",
# See all Deformable DETR models at https://huggingface.co/models?filter=deformable-detr
]


class MultiScaleDeformableAttentionFunction(Function):
@@ -112,23 +127,6 @@ def backward(context, grad_output):
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


if is_scipy_available():
from scipy.optimize import linear_sum_assignment

if is_timm_available():
from timm import create_model

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DeformableDetrConfig"
_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"

DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
"sensetime/deformable-detr",
# See all Deformable DETR models at https://huggingface.co/models?filter=deformable-detr
]


@dataclass
class DeformableDetrDecoderOutput(ModelOutput):
"""
@@ -561,7 +559,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
Multiscale deformable attention as proposed in Deformable DETR.
"""

def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int, use_custom_kernel: bool):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
@@ -582,6 +580,7 @@ def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int)
self.n_levels = n_levels
self.n_heads = num_heads
self.n_points = n_points
self.use_custom_kernel = use_custom_kernel

self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points)
@@ -664,19 +663,24 @@ def forward(
)
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
try:
# GPU
output = MultiScaleDeformableAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
except Exception:
# CPU

# Use the custom CUDA kernel to speed up inference and training on GPU
if self.use_custom_kernel:
try:
output = MultiScaleDeformableAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
# Fall back to PyTorch implementation
except Exception:
output = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
else:
output = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)

output = self.output_proj(output)

return output, attention_weights
@@ -808,6 +812,7 @@ def __init__(self, config: DeformableDetrConfig):
num_heads=config.encoder_attention_heads,
n_levels=config.num_feature_levels,
n_points=config.encoder_n_points,
use_custom_kernel=config.use_custom_kernel,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
@@ -909,6 +914,7 @@ def __init__(self, config: DeformableDetrConfig):
num_heads=config.decoder_attention_heads,
n_levels=config.num_feature_levels,
n_points=config.decoder_n_points,
use_custom_kernel=config.use_custom_kernel,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
# feedforward neural networks