
Commit 5dc07fd

Andrew Grebenisan authored and facebook-github-bot committed
Cadence ops: Support strongly typed softmax (pytorch#15201)
Summary: As titled.

Differential Revision: D84845481
1 parent 5d71c9b · commit 5dc07fd
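In short, the commit removes `cadence::_softmax_f32_f32` from the `_WARN_ONLY` set (ops temporarily exempt from the reference-implementation check), gives `half_to_float` a default of `None` in both schema variants, fixes the fake (meta) kernel signature (renaming `self` to `input_tensor` and dropping an unused `dtype` parameter), adds a reference implementation that delegates to `torch.nn.functional.softmax`, and covers it with a smoke test.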

File tree (3 files changed: +26 −6 lines)

- backends/cadence/aot/ops_registrations.py
- backends/cadence/aot/ref_implementations.py
- backends/cadence/aot/tests/test_ref_implementations.py

backends/cadence/aot/ops_registrations.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -53,7 +53,6 @@ def _validate_ref_impl_exists() -> None:
     # 1. be removed
     # 2. have a reference implementation added to ref_implementations.py
     _WARN_ONLY = {
-        "cadence::_softmax_f32_f32",
         "cadence::quantized_softmax.per_tensor",
         "cadence::quantized_softmax",
         "cadence::quantized_w8a32_gru",
@@ -640,10 +639,10 @@ def register_fake(
     "int sampling_ratio, bool aligned) -> (Tensor out)"
 )
 lib.define(
-    "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float) -> (Tensor out)"
+    "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float = None) -> (Tensor out)"
 )
 lib.define(
-    "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
+    "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float = None, *, Tensor(a!) out) -> Tensor(a!)"
 )

 lib.define(
@@ -2652,12 +2651,13 @@ def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_meta(

 @register_fake("cadence::_softmax_f32_f32")
 def softmax_f32_f32_meta(
-    self: torch.Tensor,
+    input_tensor: torch.Tensor,
     dim: int,
-    dtype: torch.dtype,
     half_to_float: Optional[bool] = None,
 ) -> torch.Tensor:
-    return self.new_empty(self.size(), dtype=self.dtype)
+    assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
+    assert half_to_float is None, "half_to_float is not supported"
+    return input_tensor.new_empty(input_tensor.size(), dtype=torch.float32)


 @register_fake("cadence::quantized_softmax")
```

backends/cadence/aot/ref_implementations.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -1979,3 +1979,14 @@ def linalg_svd(
     assert compute_uv
     U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver)
     return U.contiguous(), S.contiguous(), Vh.contiguous()
+
+
+@impl_tracked(m, "_softmax_f32_f32")
+def softmax_f32_f32(
+    input_tensor: torch.Tensor,
+    dim: int,
+    half_to_float: bool | None = None,
+) -> torch.Tensor:
+    assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
+    assert half_to_float is None, "half_to_float is not supported"
+    return torch.nn.functional.softmax(input_tensor, dim=dim, dtype=torch.float32)
```
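Since the reference implementation simply delegates to `torch.nn.functional.softmax`, a quick numerical cross-check is possible. A minimal sketch, assuming the same repo import path as above and that importing the module wires the reference kernel into eager `torch.ops.cadence` calls (both are assumptions):

```python
import torch

# Assumed import path; importing is assumed to register the reference
# kernel so the op can be called eagerly.
import executorch.backends.cadence.aot.ref_implementations  # noqa: F401

x = torch.randn(2, 3, dtype=torch.float32)
out = torch.ops.cadence._softmax_f32_f32(x, 1)
expected = torch.nn.functional.softmax(x, dim=1)
assert torch.allclose(out, expected)
```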

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 9 additions & 0 deletions

```diff
@@ -2885,3 +2885,12 @@ def test_quantized_layer_norm(self) -> None:
             output_scale,
             output_zero_point,
         )
+
+    def test_softmax_f32_f32(self) -> None:
+        # Just a wrapper around torch.nn.functional.softmax, so just ensure that it runs
+        input_tensor = torch.tensor(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32
+        )
+        output = torch.ops.cadence._softmax_f32_f32(input_tensor, dim=1)
+        self.assertEqual(output.dtype, torch.float32)
+        self.assertEqual(output.shape, input_tensor.shape)
```
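As a hand check of what the test exercises: softmax along `dim=1` normalizes each row of the 2×3 input independently, so every output row sums to 1. A plain eager-mode sketch using only stock PyTorch:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.nn.functional.softmax(x, dim=1)
# Each row is normalized independently, so the rows sum to 1.
print(y.sum(dim=1))  # tensor([1., 1.])
```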
