diff --git a/tf2onnx/onnx_opset/signal.py b/tf2onnx/onnx_opset/signal.py
index c65d960d6..a1a174bae 100644
--- a/tf2onnx/onnx_opset/signal.py
+++ b/tf2onnx/onnx_opset/signal.py
@@ -40,6 +40,61 @@ class CommonFFTOp:
         onnx_pb.TensorProto.COMPLEX128,
     ]
 
+    @classmethod
+    def any_version_dft(cls, opset, ctx, node, is_rfft=True, **kwargs):
+        """
+        Implementation using ONNX DFT operator (opset 17+).
+        Much simpler than the manual implementation.
+        """
+        # Get input tensor and optional fft_length
+        input_tensor = node.input[0]
+        fft_length_tensor = node.input[1] if len(node.input) > 1 else None
+
+        # Get input properties
+        input_dtype = ctx.get_dtype(input_tensor)
+
+        # Validate input
+        utils.make_sure(input_dtype in CommonFFTOp.supported_dtypes,
+                        "Unsupported input type for FFT/RFFT.")
+
+        # For real inputs, we need to add a dimension for complex representation
+        # Input shape: [..., signal_length] -> [..., signal_length, 1]
+        ones_const = ctx.make_const(name=utils.make_name('ones_const'),
+                                    np_val=np.array([1], dtype=np.int64))
+        input_shape_node = ctx.make_node('Shape', [input_tensor],
+                                         name=utils.make_name('input_shape'))
+        new_shape = ctx.make_node('Concat', [input_shape_node.output[0], ones_const.name],
+                                  attr={'axis': 0}, name=utils.make_name('new_shape'))
+        input_expanded = ctx.make_node('Reshape', [input_tensor, new_shape.output[0]],
+                                       name=utils.make_name('input_expanded'))
+
+        # Perform DFT on the last dimension (signal dimension)
+        # axis = -2 (signal dimension, -1 is reserved for complex representation)
+        if fft_length_tensor is not None:
+            # Reshape fft_length to ensure it's a scalar (rank 0 tensor) as required by DFT
+            empty_shape = ctx.make_const(name=utils.make_name('empty_shape'),
+                                         np_val=np.array([], dtype=np.int64))
+            fft_length_scalar = ctx.make_node('Reshape',
+                                              [fft_length_tensor, empty_shape.name],
+                                              name=utils.make_name('fft_length_scalar'))
+            dft_inputs = [input_expanded.output[0], fft_length_scalar.output[0]]
+        else:
+            dft_inputs = [input_expanded.output[0]]
+
+        dft_attrs = {'axis': -2}
+        if is_rfft:
+            dft_attrs['onesided'] = 1
+
+        dft_result = ctx.make_node('DFT', dft_inputs,
+                                   attr=dft_attrs,
+                                   name=utils.make_name('dft_result'))
+
+        # Remove the original node and replace with DFT result
+        ctx.remove_node(node.name)
+        ctx.replace_all_inputs(node.output[0], dft_result.output[0])
+
+        return dft_result
+
     @classmethod
     def any_version(cls, const_length, opset, ctx, node, axis=None, fft_length=None,
                     dim=None, onnx_dtype=None, shape=None,
@@ -303,20 +358,24 @@ def DFT_real(x, fft_length=None):
             "MatMul", inputs=[onx_real_imag_part_name, trx],
             name=utils.make_name('CPLX_M_' + node_name + 'rfft'))
-        if not const_length or (axis == 1 or (fft_length < shape_n and axis != 0)):
-            size = fft_length // 2 + 1
-            if opset >= 10:
-                cst_axis = ctx.make_const(
-                    name=utils.make_name('CPLX_csta'), np_val=np.array([-2], dtype=np.int64))
-                cst_length = ctx.make_const(
-                    name=utils.make_name('CPLX_cstl'), np_val=np.array([size], dtype=np.int64))
-                sliced_mult = ctx.make_node(
-                    "Slice", inputs=[mult.output[0], zero.name, cst_length.name, cst_axis.name],
-                    name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
+        if not const_length or (axis == 1 or (fft_length is not None and fft_length < shape_n and axis != 0)):
+            if fft_length is not None:
+                size = fft_length // 2 + 1
+                if opset >= 10:
+                    cst_axis = ctx.make_const(
+                        name=utils.make_name('CPLX_csta'), np_val=np.array([-2], dtype=np.int64))
+                    cst_length = ctx.make_const(
+                        name=utils.make_name('CPLX_cstl'), np_val=np.array([size], dtype=np.int64))
+                    sliced_mult = ctx.make_node(
+                        "Slice", inputs=[mult.output[0], zero.name, cst_length.name, cst_axis.name],
+                        name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
+                else:
+                    sliced_mult = ctx.make_node(
+                        "Slice", inputs=[mult.output[0]], attr=dict(starts=[0], ends=[size], axes=[-2]),
+                        name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
             else:
-                sliced_mult = ctx.make_node(
-                    "Slice", inputs=[mult.output[0]], attr=dict(starts=[0], ends=[size], axes=[-2]),
-                    name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
+                # Dynamic case - use shape inference to determine size
+                sliced_mult = mult
         else:
             sliced_mult = mult
@@ -358,6 +417,11 @@ def version_13(cls, ctx, node, **kwargs):
         # Unsqueeze changed in opset 13.
         return cls.any_version(True, 13, ctx, node, **kwargs)
 
+    @classmethod
+    def version_17(cls, ctx, node, **kwargs):
+        # DFT operator available in opset 17+, use native ONNX DFT implementation
+        return cls.any_version_dft(17, ctx, node, is_rfft=True, **kwargs)
+
 
 @tf_op("FFT")
 class FFTOp(CommonFFTOp):
@@ -371,9 +435,128 @@ def version_1(cls, ctx, node, **kwargs):
     def version_13(cls, ctx, node, **kwargs):
         return cls.any_version(False, 13, ctx, node, **kwargs)
 
+    @classmethod
+    def version_17(cls, ctx, node, **kwargs):
+        # DFT operator available in opset 17+, use native ONNX DFT implementation
+        return cls.any_version_dft(17, ctx, node, is_rfft=False, **kwargs)
+
 
 class CommonFFT2DOp(CommonFFTOp):
 
+    @classmethod
+    def any_version_2d_dft(cls, opset, ctx, node, **kwargs):
+        """
+        Implementation using ONNX DFT operator (opset 17+).
+        This is much simpler and more efficient than the manual implementation.
+        """
+        # Get input tensor and fft_length
+        input_tensor = node.input[0]
+        fft_length_tensor = node.input[1]
+
+        # Cast fft_length to int64 to ensure compatibility
+        fft_length_tensor_int64 = ctx.make_node('Cast', [fft_length_tensor],
+                                                attr={'to': onnx_pb.TensorProto.INT64},
+                                                name=utils.make_name('fft_length_cast'))
+        fft_length_tensor = fft_length_tensor_int64.output[0]
+
+        # Get input properties
+        input_shape = ctx.get_shape(input_tensor)
+        input_dtype = ctx.get_dtype(input_tensor)
+
+        # DFT-based implementation is more flexible and can work with various consumers
+        # No need to restrict consumer types like the manual implementation
+
+        # Validate input
+        utils.make_sure(input_dtype in CommonFFT2DOp.supported_dtypes,
+                        "Unsupported input type for RFFT2D.")
+
+        # For RFFT2D, we need to perform DFT on the last two dimensions
+        # First, we need to add a dimension for complex representation if input is real
+        if input_shape is not None and len(input_shape) > 0:
+            # Add dimension for real/imaginary parts if input is real
+            # Input shape: [..., height, width] -> [..., height, width, 1]
+            ones_const = ctx.make_const(name=utils.make_name('ones_const'),
+                                        np_val=np.array([1], dtype=np.int64))
+            input_shape_node = ctx.make_node('Shape', [input_tensor],
+                                             name=utils.make_name('input_shape'))
+            new_shape = ctx.make_node('Concat', [input_shape_node.output[0], ones_const.name],
+                                      attr={'axis': 0}, name=utils.make_name('new_shape'))
+            input_expanded = ctx.make_node('Reshape', [input_tensor, new_shape.output[0]],
+                                           name=utils.make_name('input_expanded'))
+            current_input = input_expanded.output[0]
+        else:
+            # Dynamic shape case
+            input_shape_node = ctx.make_node('Shape', [input_tensor],
+                                             name=utils.make_name('input_shape'))
+            ones_const = ctx.make_const(name=utils.make_name('ones_const'),
+                                        np_val=np.array([1], dtype=np.int64))
+            new_shape = ctx.make_node('Concat', [input_shape_node.output[0], ones_const.name],
+                                      attr={'axis': 0}, name=utils.make_name('new_shape'))
+            input_expanded = ctx.make_node('Reshape', [input_tensor, new_shape.output[0]],
+                                           name=utils.make_name('input_expanded'))
+            current_input = input_expanded.output[0]
+
+        # Extract fft_length for each dimension (assuming fft_length has shape [2])
+        zero_const = ctx.make_const(name=utils.make_name('zero_const'),
+                                    np_val=np.array([0], dtype=np.int64))
+        one_const = ctx.make_const(name=utils.make_name('one_const'),
+                                   np_val=np.array([1], dtype=np.int64))
+
+        height_fft_length = ctx.make_node('Gather', [fft_length_tensor, zero_const.name],
+                                          attr={'axis': 0}, name=utils.make_name('height_fft_length'))
+        width_fft_length = ctx.make_node('Gather', [fft_length_tensor, one_const.name],
+                                         attr={'axis': 0}, name=utils.make_name('width_fft_length'))
+
+        # Reshape fft_length values to ensure they're scalars (rank 0 tensors) as required by DFT
+        empty_shape = ctx.make_const(name=utils.make_name('empty_shape'),
+                                     np_val=np.array([], dtype=np.int64))
+        height_fft_length_scalar = ctx.make_node('Reshape',
+                                                 [height_fft_length.output[0], empty_shape.name],
+                                                 name=utils.make_name('height_fft_length_scalar'))
+        width_fft_length_scalar = ctx.make_node('Reshape',
+                                                [width_fft_length.output[0], empty_shape.name],
+                                                name=utils.make_name('width_fft_length_scalar'))
+
+        # Perform DFT on the second-to-last dimension (height)
+        # Use onesided=0 for intermediate transforms
+        # axis = -3 (height dimension)
+        dft_height = ctx.make_node('DFT', [current_input, height_fft_length_scalar.output[0]],
+                                   attr={'axis': -3, 'onesided': 0},
+                                   name=utils.make_name('dft_height'))
+
+        # Perform DFT on the last dimension (width)
+        # Use onesided=1 for the second transform to get the one-sided FFT result
+        # axis = -2 (width dimension)
+        dft_width = ctx.make_node('DFT', [dft_height.output[0], width_fft_length_scalar.output[0]],
+                                  attr={'axis': -2, 'onesided': 1},
+                                  name=utils.make_name('dft_width'))
+
+        # The DFT output has shape [..., height, width//2+1, 2] where last dim is [real, imag]
+        # We only need the real part (index 0) of the complex dimension
+        zero_const_slice = ctx.make_const(name=utils.make_name('zero_slice'),
+                                          np_val=np.array([0], dtype=np.int64))
+        one_const_end = ctx.make_const(name=utils.make_name('one_const_end'),
+                                       np_val=np.array([1], dtype=np.int64))
+        minus_one_axis = ctx.make_const(name=utils.make_name('minus_one_axis'),
+                                        np_val=np.array([-1], dtype=np.int64))
+
+        # Slice to get real part: dft_width[..., 0:1]
+        dft_real = ctx.make_node('Slice',
+                                 [dft_width.output[0], zero_const_slice.name, one_const_end.name, minus_one_axis.name],
+                                 name=utils.make_name('dft_real_part'))
+
+        # Squeeze the last dimension to remove the complex dimension
+        dft_result = GraphBuilder(ctx).make_squeeze(
+            {'data': dft_real.output[0], 'axes': [-1]},
+            name=utils.make_name('dft_result_squeezed'),
+            return_node=True)
+
+        # Remove the original node and replace with DFT result
+        ctx.remove_node(node.name)
+        ctx.replace_all_inputs(node.output[0], dft_result.output[0])
+
+        return dft_result
+
     @classmethod
     def any_version_2d(cls, const_length, opset, ctx, node, **kwargs):
         """
@@ -463,8 +646,8 @@ def onnx_rfft_2d_any_test(x, fft_length):
         consumers = ctx.find_output_consumers(node.output[0])
         consumer_types = set(op.type for op in consumers)
         utils.make_sure(
-            consumer_types == {'ComplexAbs'},
-            "Current implementation of RFFT2D only allows ComplexAbs as consumer not %r",
+            consumer_types in [{'ComplexAbs'}, {'Squeeze'}],
+            "Current implementation of RFFT2D only allows ComplexAbs or Squeeze as consumer not %r",
             consumer_types)
 
         oldnode = node
@@ -905,6 +1088,11 @@ def version_13(cls, ctx, node, **kwargs):
         # Unsqueeze changed in opset 13.
         return cls.any_version_2d(True, 13, ctx, node, **kwargs)
 
+    @classmethod
+    def version_17(cls, ctx, node, **kwargs):
+        # DFT operator available in opset 17+, use native ONNX DFT implementation
+        return cls.any_version_2d_dft(17, ctx, node, **kwargs)
+
 
 @tf_op("ComplexAbs")
 class ComplexAbsOp:
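
Note (not part of the patch): the DFT-based conversions above rely on two identities that can be sanity-checked with NumPy alone. First, an RFFT is the full DFT truncated to the first floor(n/2)+1 bins, which is what the ONNX DFT operator returns when onesided=1. Second, 1-D transforms along different axes commute, so RFFT2D can be decomposed into a full DFT over the height axis (onesided=0) followed by a one-sided DFT over the width axis, as any_version_2d_dft does. The sketch below checks both identities; the shapes (3, 8) and (4, 6) are arbitrary, and it exercises only the math the node layout assumes, not onnxruntime's DFT kernel.

    import numpy as np

    rng = np.random.default_rng(0)

    # 1-D identity behind any_version_dft: keeping bins 0..n//2 of the full DFT
    # reproduces np.fft.rfft for real input.
    x = rng.standard_normal((3, 8))
    n = x.shape[-1]
    onesided = np.fft.fft(x, axis=-1)[..., :n // 2 + 1]
    assert np.allclose(onesided, np.fft.rfft(x, axis=-1))

    # 2-D identity behind any_version_2d_dft: a full DFT over the height axis,
    # then a DFT over the width axis truncated to one side, reproduces np.fft.rfft2.
    y = rng.standard_normal((4, 6))
    step1 = np.fft.fft(y, axis=-2)                                  # DFT node with onesided=0
    step2 = np.fft.fft(step1, axis=-1)[..., :y.shape[-1] // 2 + 1]  # DFT node with onesided=1
    assert np.allclose(step2, np.fft.rfft2(y))

The trailing Reshape to [..., 1] in both converters matches the DFT operator's convention of carrying real and imaginary parts in a final dimension of size 1 (real input) or 2 (complex output), which is also why the converted graph slices or squeezes that last dimension afterwards.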