@@ -53,7 +53,6 @@ def _validate_ref_impl_exists() -> None:
     # 1. be removed
     # 2. have a reference implementation added to ref_implementations.py
     _WARN_ONLY = {
-        "cadence::_softmax_f32_f32",
         "cadence::quantized_softmax.per_tensor",
         "cadence::quantized_softmax",
         "cadence::quantized_w8a32_gru",
@@ -640,10 +639,10 @@ def register_fake(
640639 "int sampling_ratio, bool aligned) -> (Tensor out)"
641640)
642641lib .define (
643- "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float) -> (Tensor out)"
642+ "_softmax_f32_f32(Tensor self, int dim, bool? half_to_float = None ) -> (Tensor out)"
644643)
645644lib .define (
646- "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
645+ "_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float = None , *, Tensor(a!) out) -> Tensor(a!)"
647646)
648647
649648lib .define (
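
Giving `half_to_float` a schema-level default of `None` lets call sites omit the argument entirely. A sketch of the difference this makes, assuming this module has been imported so the op schemas and the fake kernel below are registered:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode, dispatch hits the registered fake kernel, so the
# schema and shape propagation can be checked without a device kernel.
with FakeTensorMode():
    x = torch.randn(2, 8, dtype=torch.float32)
    y = torch.ops.cadence._softmax_f32_f32(x, 1)        # half_to_float defaults to None
    z = torch.ops.cadence._softmax_f32_f32(x, 1, None)  # equivalent explicit form
    assert y.shape == x.shape and y.dtype == torch.float32
```
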
@@ -2652,12 +2651,13 @@ def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_meta(
 
 @register_fake("cadence::_softmax_f32_f32")
 def softmax_f32_f32_meta(
-    self: torch.Tensor,
+    input_tensor: torch.Tensor,
     dim: int,
-    dtype: torch.dtype,
     half_to_float: Optional[bool] = None,
 ) -> torch.Tensor:
-    return self.new_empty(self.size(), dtype=self.dtype)
+    assert input_tensor.dtype == torch.float32, "input_tensor must be float32"
+    assert half_to_float is None, "half_to_float is not supported"
+    return input_tensor.new_empty(input_tensor.size(), dtype=torch.float32)
 
 
 @register_fake("cadence::quantized_softmax")
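
The old fake kernel took an extra `dtype` parameter that does not exist in the op schema, so a schema-conformant three-argument call would have bound `half_to_float` to `dtype`; it also echoed the input dtype instead of enforcing the op's float32 output contract. The corrected signature mirrors the schema, and the asserts make the f32-only contract explicit. A quick check, calling the meta function directly with a meta-device tensor:

```python
import torch

# Meta tensors carry only shape/dtype, which is all the fake kernel needs.
x = torch.empty(4, 16, dtype=torch.float32, device="meta")
out = softmax_f32_f32_meta(x, dim=-1)
assert out.shape == x.shape and out.dtype == torch.float32
```
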