Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 203 additions & 15 deletions tf2onnx/onnx_opset/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,61 @@ class CommonFFTOp:
onnx_pb.TensorProto.COMPLEX128,
]

@classmethod
def any_version_dft(cls, opset, ctx, node, is_rfft=True, **kwargs):
    """
    Replace a 1-D FFT/RFFT node with a single ONNX ``DFT`` node (opset 17+).

    The signal (``node.input[0]``) gets a trailing dimension of size 1
    appended so that the transformed axis becomes ``-2``. When the node
    has a second input it is reshaped to a rank-0 tensor and passed to
    ``DFT`` as the transform length.

    :param opset: target opset; not read by this implementation
    :param ctx: graph the new nodes are added to
    :param node: node being converted; it is removed from *ctx*
    :param is_rfft: when true, ``DFT`` is created with ``onesided=1``
    :param kwargs: accepted for signature compatibility, unused
    :return: the created ``DFT`` node

    NOTE(review): the appended dimension is always 1, i.e. the input is
    treated as real-valued even when ``is_rfft`` is false -- confirm this
    is what the complex FFT path needs.
    """
    signal = node.input[0]
    length_input = node.input[1] if len(node.input) > 1 else None

    utils.make_sure(ctx.get_dtype(signal) in CommonFFTOp.supported_dtypes,
                    "Unsupported input type for FFT/RFFT.")

    # [..., signal_length] -> [..., signal_length, 1]: target shape is the
    # runtime shape of the signal with a literal 1 concatenated at the end.
    trailing_one = ctx.make_const(name=utils.make_name('ones_const'),
                                  np_val=np.array([1], dtype=np.int64))
    runtime_shape = ctx.make_node('Shape', [signal],
                                  name=utils.make_name('input_shape'))
    target_shape = ctx.make_node('Concat',
                                 [runtime_shape.output[0], trailing_one.name],
                                 attr={'axis': 0},
                                 name=utils.make_name('new_shape'))
    expanded = ctx.make_node('Reshape', [signal, target_shape.output[0]],
                             name=utils.make_name('input_expanded'))

    dft_inputs = [expanded.output[0]]
    if length_input is not None:
        # Reshaping with an empty shape turns the length into a rank-0 tensor.
        scalar_shape = ctx.make_const(name=utils.make_name('empty_shape'),
                                      np_val=np.array([], dtype=np.int64))
        length_scalar = ctx.make_node('Reshape',
                                      [length_input, scalar_shape.name],
                                      name=utils.make_name('fft_length_scalar'))
        dft_inputs.append(length_scalar.output[0])

    # Axis -2 is the signal axis; -1 is the dimension appended above.
    dft_attrs = {'axis': -2, 'onesided': 1} if is_rfft else {'axis': -2}
    dft_node = ctx.make_node('DFT', dft_inputs,
                             attr=dft_attrs,
                             name=utils.make_name('dft_result'))

    # Drop the original node and point its consumers at the DFT output.
    ctx.remove_node(node.name)
    ctx.replace_all_inputs(node.output[0], dft_node.output[0])

    return dft_node

@classmethod
def any_version(cls, const_length, opset, ctx, node, axis=None,
fft_length=None, dim=None, onnx_dtype=None, shape=None,
Expand Down Expand Up @@ -303,20 +358,24 @@ def DFT_real(x, fft_length=None):
"MatMul", inputs=[onx_real_imag_part_name, trx],
name=utils.make_name('CPLX_M_' + node_name + 'rfft'))

if not const_length or (axis == 1 or (fft_length < shape_n and axis != 0)):
size = fft_length // 2 + 1
if opset >= 10:
cst_axis = ctx.make_const(
name=utils.make_name('CPLX_csta'), np_val=np.array([-2], dtype=np.int64))
cst_length = ctx.make_const(
name=utils.make_name('CPLX_cstl'), np_val=np.array([size], dtype=np.int64))
sliced_mult = ctx.make_node(
"Slice", inputs=[mult.output[0], zero.name, cst_length.name, cst_axis.name],
name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
if not const_length or (axis == 1 or (fft_length is not None and fft_length < shape_n and axis != 0)):
if fft_length is not None:
size = fft_length // 2 + 1
if opset >= 10:
cst_axis = ctx.make_const(
name=utils.make_name('CPLX_csta'), np_val=np.array([-2], dtype=np.int64))
cst_length = ctx.make_const(
name=utils.make_name('CPLX_cstl'), np_val=np.array([size], dtype=np.int64))
sliced_mult = ctx.make_node(
"Slice", inputs=[mult.output[0], zero.name, cst_length.name, cst_axis.name],
name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
else:
sliced_mult = ctx.make_node(
"Slice", inputs=[mult.output[0]], attr=dict(starts=[0], ends=[size], axes=[-2]),
name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
else:
sliced_mult = ctx.make_node(
"Slice", inputs=[mult.output[0]], attr=dict(starts=[0], ends=[size], axes=[-2]),
name=utils.make_name('CPLX_S2_' + node_name + 'rfft'))
# Dynamic case - use shape inference to determine size
sliced_mult = mult
else:
sliced_mult = mult

Expand Down Expand Up @@ -358,6 +417,11 @@ def version_13(cls, ctx, node, **kwargs):
# Unsqueeze changed in opset 13.
return cls.any_version(True, 13, ctx, node, **kwargs)

@classmethod
def version_17(cls, ctx, node, **kwargs):
    """Opset 17 conversion: delegate to ``any_version_dft`` with ``is_rfft=True``."""
    # The ONNX DFT operator is available from opset 17 onwards.
    dft_opset = 17
    return cls.any_version_dft(dft_opset, ctx, node, is_rfft=True, **kwargs)


@tf_op("FFT")
class FFTOp(CommonFFTOp):
Expand All @@ -371,9 +435,128 @@ def version_1(cls, ctx, node, **kwargs):
def version_13(cls, ctx, node, **kwargs):
return cls.any_version(False, 13, ctx, node, **kwargs)

@classmethod
def version_17(cls, ctx, node, **kwargs):
    """Opset 17 conversion: delegate to ``any_version_dft`` with ``is_rfft=False``."""
    # The ONNX DFT operator is available from opset 17 onwards.
    dft_opset = 17
    return cls.any_version_dft(dft_opset, ctx, node, is_rfft=False, **kwargs)


class CommonFFT2DOp(CommonFFTOp):

@classmethod
def any_version_2d_dft(cls, opset, ctx, node, **kwargs):
    """
    Replace a 2-D real FFT node with two chained ONNX ``DFT`` nodes (opset 17+).

    The signal (``node.input[0]``) gets a trailing dimension of size 1
    appended, the width axis (``-2``) is transformed with ``onesided=1``
    and the height axis (``-3``) is then transformed with ``onesided=0``.
    Finally index 0 of the last dimension is sliced out and squeezed.

    The one-sided transform has to run first: the ONNX ``DFT`` operator
    only allows ``onesided=1`` on real-valued input, and the output of
    the first ``DFT`` is already complex (last dimension 2). Because the
    2-D DFT is separable, running width before height yields the same
    values as the opposite order.

    :param opset: target opset; not read by this implementation
    :param ctx: graph the new nodes are added to
    :param node: node being converted; it is removed from *ctx*.
        ``node.input[1]`` is assumed to hold two entries
        ``[height_length, width_length]``.
    :param kwargs: accepted for signature compatibility, unused
    :return: the final ``Squeeze`` node that replaces *node*

    NOTE(review): only index 0 (the real part) of the DFT output is kept;
    the imaginary part is discarded. Confirm this matches what the
    consumers of the original node expect.
    """
    input_tensor = node.input[0]

    # Validate first so that a failure leaves no orphan nodes in the graph.
    utils.make_sure(ctx.get_dtype(input_tensor) in CommonFFT2DOp.supported_dtypes,
                    "Unsupported input type for RFFT2D.")

    # Cast fft_length to int64 so it can be indexed and reshaped uniformly.
    fft_length = ctx.make_node('Cast', [node.input[1]],
                               attr={'to': onnx_pb.TensorProto.INT64},
                               name=utils.make_name('fft_length_cast')).output[0]

    # Shared 1-element int64 constants (used as indices, slice bounds and
    # as the extra trailing dimension).
    zero_const = ctx.make_const(name=utils.make_name('zero_const'),
                                np_val=np.array([0], dtype=np.int64))
    one_const = ctx.make_const(name=utils.make_name('one_const'),
                               np_val=np.array([1], dtype=np.int64))
    minus_one_axis = ctx.make_const(name=utils.make_name('minus_one_axis'),
                                    np_val=np.array([-1], dtype=np.int64))
    empty_shape = ctx.make_const(name=utils.make_name('empty_shape'),
                                 np_val=np.array([], dtype=np.int64))

    # [..., height, width] -> [..., height, width, 1]; the target shape is
    # computed at runtime, so this works for static and dynamic shapes alike.
    input_shape_node = ctx.make_node('Shape', [input_tensor],
                                     name=utils.make_name('input_shape'))
    new_shape = ctx.make_node('Concat', [input_shape_node.output[0], one_const.name],
                              attr={'axis': 0}, name=utils.make_name('new_shape'))
    input_expanded = ctx.make_node('Reshape', [input_tensor, new_shape.output[0]],
                                   name=utils.make_name('input_expanded'))

    def scalar_length(index_const, label):
        """Pick one entry of fft_length and reshape it to a rank-0 tensor."""
        picked = ctx.make_node('Gather', [fft_length, index_const.name],
                               attr={'axis': 0}, name=utils.make_name(label))
        return ctx.make_node('Reshape', [picked.output[0], empty_shape.name],
                             name=utils.make_name(label + '_scalar')).output[0]

    height_length = scalar_length(zero_const, 'height_fft_length')
    width_length = scalar_length(one_const, 'width_fft_length')

    # One-sided transform over the width axis while the data is still real.
    dft_width = ctx.make_node('DFT', [input_expanded.output[0], width_length],
                              attr={'axis': -2, 'onesided': 1},
                              name=utils.make_name('dft_width'))

    # Full transform over the height axis on the now-complex data.
    dft_height = ctx.make_node('DFT', [dft_width.output[0], height_length],
                               attr={'axis': -3, 'onesided': 0},
                               name=utils.make_name('dft_height'))

    # Keep index 0 of the last dimension: dft[..., 0:1].
    dft_real = ctx.make_node('Slice',
                             [dft_height.output[0], zero_const.name,
                              one_const.name, minus_one_axis.name],
                             name=utils.make_name('dft_real_part'))

    # Drop the now size-1 last dimension.
    dft_result = GraphBuilder(ctx).make_squeeze(
        {'data': dft_real.output[0], 'axes': [-1]},
        name=utils.make_name('dft_result_squeezed'),
        return_node=True)

    # Drop the original node and point its consumers at the new result.
    ctx.remove_node(node.name)
    ctx.replace_all_inputs(node.output[0], dft_result.output[0])

    return dft_result

@classmethod
def any_version_2d(cls, const_length, opset, ctx, node, **kwargs):
"""
Expand Down Expand Up @@ -463,8 +646,8 @@ def onnx_rfft_2d_any_test(x, fft_length):
consumers = ctx.find_output_consumers(node.output[0])
consumer_types = set(op.type for op in consumers)
utils.make_sure(
consumer_types == {'ComplexAbs'},
"Current implementation of RFFT2D only allows ComplexAbs as consumer not %r",
consumer_types in [{'ComplexAbs'}, {'Squeeze'}],
"Current implementation of RFFT2D only allows ComplexAbs or Squeeze as consumer not %r",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is problematic

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You would need to add more unit tests to check that this is working. The current code assumes that the complex values produced by the FFT are converted into real numbers soon afterwards.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah - this change is incorrect

consumer_types)

oldnode = node
Expand Down Expand Up @@ -905,6 +1088,11 @@ def version_13(cls, ctx, node, **kwargs):
# Unsqueeze changed in opset 13.
return cls.any_version_2d(True, 13, ctx, node, **kwargs)

@classmethod
def version_17(cls, ctx, node, **kwargs):
    """Opset 17 conversion: delegate to ``any_version_2d_dft``."""
    # The ONNX DFT operator is available from opset 17 onwards.
    dft_opset = 17
    return cls.any_version_2d_dft(dft_opset, ctx, node, **kwargs)


@tf_op("ComplexAbs")
class ComplexAbsOp:
Expand Down
Loading