Skip to content
Draft
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3538,10 +3538,67 @@ def aten_feature_alpha_dropout(input: TensorType, p: float, train: bool) -> Tens
raise NotImplementedError()


def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType:
@torch_op("aten::feature_dropout", trace_only=True)
def aten_feature_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat:
"""feature_dropout(Tensor input, float p, bool train) -> Tensor"""

raise NotImplementedError()
# Feature dropout applies dropout to entire feature maps/channels
# rather than individual elements

# Use ONNX operations to handle control flow
# In inference mode or when p=0, return input unchanged
should_dropout = op.And(train, p > 0.0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do

if p == 0 or not train:
    return input

# Do this other thing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added early return conditions as requested. The function now returns input directly when p == 0 or not train, avoiding unnecessary computations. Commit: af8d622


# Get input shape
input_shape = op.Shape(input)
ndim = op.Size(input_shape)

# Create mask shape for feature dropout
# For 2D tensors [N, C]: mask shape is [N, C]
# For higher dim tensors [N, C, ...]: mask shape is [N, C, 1, 1, ...]
batch_size = op.Gather(input_shape, [0])
channel_size = op.Gather(input_shape, [1])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
batch_size = op.Gather(input_shape, [0])
channel_size = op.Gather(input_shape, [1])
batch_size = op.Shape(input, start=0, end=0)
batch_size = op.Shape(input, start=1, end=1)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to use op.Shape with start/end parameters for more efficient shape extraction. Fixed the typo in the suggestion where both lines assigned to batch_size. Commit: 96f78cd


# Create the appropriate mask shape based on tensor dimensions
is_2d = op.Equal(ndim, 2)

# For 2D case, mask_shape = [N, C]
mask_shape_2d = op.Concat(batch_size, channel_size, axis=0)

# For higher dimensions, mask_shape = [N, C, 1, 1, ...]
spatial_dims_count = op.Sub(ndim, 2)
ones_for_spatial = op.ConstantOfShape(
op.Reshape(spatial_dims_count, [1]),
value=1
)
mask_shape_nd = op.Concat(batch_size, channel_size, ones_for_spatial, axis=0)

# Select appropriate mask shape
mask_shape = op.Where(is_2d, mask_shape_2d, mask_shape_nd)

# Generate random uniform values between 0 and 1
random_vals = op.RandomUniformLike(
op.ConstantOfShape(mask_shape, value=0.0),
dtype=1, # float32
low=0.0,
high=1.0
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use op.Dropout for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great suggestion! I've refactored the implementation to use op.Dropout instead of manual random generation. The new approach creates a dummy tensor with the appropriate feature dropout shape [N, C, 1, 1, ...], applies op.Dropout to it to get the scaled mask, then multiplies with the input. This is much cleaner and leverages the existing ONNX Dropout operation for all the logic (training mode, scaling, random generation). Commit: 29f1968


# Create binary mask: 1 where random_vals >= p, 0 otherwise
mask = op.Cast(random_vals >= p, to=input.dtype)

# Scale by 1/(1-p) to maintain expected value
scale = op.Div(1.0, op.Sub(1.0, p))
scaled_mask = op.Mul(mask, scale)

# Apply dropout only if we should dropout, otherwise use all-ones mask
ones_mask = op.ConstantOfShape(mask_shape, value=1.0)
final_mask = op.Where(should_dropout, scaled_mask, ones_mask)

# Apply mask to input (broadcasting will handle different shapes)
result = op.Mul(input, final_mask)

return result


@torch_op(("aten::fill.Tensor", "aten::fill.Scalar"))
Expand Down