11 changes: 11 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -605,6 +605,17 @@ struct StftAttrs : public tvm::AttrsNode<StftAttrs> {
}
}; // struct StftAttrs

/*! \brief Attributes used in DFT operator */
struct DFTAttrs : public tvm::AttrsNode<DFTAttrs> {

Member:

Please document that this is 1D DFT along the last axis.

We cannot use this op to support 2D DFT (PyTorch fft2 op), correct?

Contributor Author:

> Please document that this is 1D DFT along the last axis.

The fact that the calculation occurs along the last axis is already documented in the functions where the algorithm is implemented:

  • python/tvm/topi/signal.py
  • python/tvm/topi/cuda/signal.py
  • python/tvm/relay/op/transform.py

> We cannot use this op to support 2D DFT (PyTorch fft2 op), correct?

We can use the 1D DFT (which computes along the last axis) to calculate an n-D DFT (and a 2D DFT in particular). This can be done with the following pseudocode:

def swap_axes(a, axis1, axis2):
    permutation = list(range(len(shape(a))))
    permutation[axis1] = axis2
    permutation[axis2] = axis1
    return transpose(a, permutation)

def dftn(input_tensor, axes):
    output = input_tensor
    for axis in reversed(axes):
        output = swap_axes(output, axis, -1)
        output = DFT(output)
        output = swap_axes(output, axis, -1)
    return output

axes is a vector containing the axes along which the n-D Fourier transform should be performed (for 2D, this vector must have length 2).

The same approach is used in NumPy.
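
A runnable NumPy version of the same idea, for illustration (a sketch that uses np.fft.fft as the 1D DFT):

import numpy as np

def swap_axes(a, axis1, axis2):
    permutation = list(range(a.ndim))
    permutation[axis1], permutation[axis2] = permutation[axis2], permutation[axis1]
    return np.transpose(a, permutation)

def dftn(input_tensor, axes):
    # Apply the 1D DFT along each requested axis by moving it to the
    # last position, transforming, and moving it back.
    output = input_tensor
    for axis in reversed(axes):
        output = swap_axes(output, axis, -1)
        output = np.fft.fft(output)  # 1D DFT along the last axis
        output = swap_axes(output, axis, -1)
    return output

x = np.random.rand(3, 4, 5)
assert np.allclose(dftn(x, axes=[1, 2]), np.fft.fft2(x, axes=(1, 2)))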

  Bool inverse = Bool(false);

  TVM_DECLARE_ATTRS(DFTAttrs, "relay.attrs.DFTAttrs") {
    TVM_ATTR_FIELD(inverse)
        .describe("Whether to perform the inverse discrete Fourier transform")
        .set_default(Bool(false));
  }
};  // struct DFTAttrs

struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
  bool upper;

111 changes: 111 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -4877,6 +4877,116 @@ def _impl_v1(cls, inputs, attr, params):
        return mm_out


class DFT(OnnxOpConverter):
    """Operator converter for discrete Fourier transform (DFT)."""

    @classmethod
    def _impl_v17(cls, inputs, attr, params):
        # ************************* Read attrs *************************
        axis = attr.get("axis")
        inverse = attr.get("inverse")
        onesided = attr.get("onesided")

        # ************************* Read inputs ************************
        input_tensor = inputs[0]
        dft_length = inputs[1]

        # ************************* Parse inputs ***********************
        t1 = ["float16", "float32", "float64"]
        t2 = ["int32", "int64"]

        # input
        assert infer_type(input_tensor).checked_type.dtype in t1
        input_shape = infer_shape(input_tensor)
        assert len(input_shape) >= 3
        if axis < 0:
            axis = len(input_shape) + axis
        assert 1 <= axis <= len(input_shape) - 1, "axis is out of bounds"

        # dft_length
        if dft_length is None:
            dft_length = input_shape[axis]
        else:
            dft_length_dtype = infer_type(dft_length).checked_type.dtype
            assert dft_length_dtype in t2
            dft_length = int(infer_value(dft_length, params).numpy())

        # ************************
        input_tensor = cls._maybe_crop_or_pad(input_tensor, axis, dft_length)

        swap_axis = -1
        re_input_tensor, im_input_tensor = cls._split_real_and_imag_parts(input_tensor)

        re_input_tensor = cls._swap_axes(re_input_tensor, axis, swap_axis)
        im_input_tensor = cls._swap_axes(im_input_tensor, axis, swap_axis)

        re_input_tensor, im_input_tensor = _op.dft(re_input_tensor, im_input_tensor, inverse)

        re_input_tensor = cls._swap_axes(re_input_tensor, axis, swap_axis)
        im_input_tensor = cls._swap_axes(im_input_tensor, axis, swap_axis)

        if onesided:
            re_input_tensor = cls._crop_onesided(re_input_tensor, axis)
            im_input_tensor = cls._crop_onesided(im_input_tensor, axis)

        return cls._merge_real_and_imag_parts(re_input_tensor, im_input_tensor)

    @classmethod
    def _crop_axis(cls, tensor, axis, new_dim):
        shape = infer_shape(tensor)
        slices = [slice(0, a, 1) for a in shape]
        slices[axis] = slice(0, new_dim, 1)
        return _op.strided_slice(
            tensor,
            begin=[s.start for s in slices],
            end=[s.stop for s in slices],
            strides=[s.step for s in slices],
            axes=list(range(len(shape))),
        )

    @classmethod
    def _maybe_crop_or_pad(cls, input_tensor, axis, n_fft):
        shape = infer_shape(input_tensor)
        if shape[axis] != n_fft:
            if shape[axis] > n_fft:
                return cls._crop_axis(input_tensor, axis, n_fft)
            else:
                pad_width = [(0, 0)] * len(shape)
                pad_width[axis] = (0, n_fft - shape[axis])
                return _op.nn.pad(input_tensor, pad_width)
        return input_tensor

    @classmethod
    def _swap_axes(cls, tensor, axis1, axis2):
        permutation = list(range(len(infer_shape(tensor))))
        permutation[axis1] = axis2
        permutation[axis2] = axis1
        return _op.transpose(tensor, permutation)

    @classmethod
    def _split_real_and_imag_parts(cls, tensor):
        shape = infer_shape(tensor)
        dtype = infer_type(tensor).checked_type.dtype
        if shape[-1] == 1:
            re = tensor
            im = _op.const(np.zeros(shape), dtype=dtype)
        else:
            re, im = _op.split(tensor, 2, -1)

        return _op.squeeze(re, -1), _op.squeeze(im, -1)

    @classmethod
    def _merge_real_and_imag_parts(cls, re, im):
        re = _op.expand_dims(re, axis=-1)
        im = _op.expand_dims(im, axis=-1)
        return _op.concatenate([re, im], axis=-1)

    @classmethod
    def _crop_onesided(cls, tensor, axis):
        shape = infer_shape(tensor)
        return cls._crop_axis(tensor, axis, shape[axis] // 2 + 1)
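
For context, here is a NumPy model of the tensor layout this converter handles (real and imaginary parts stacked on the last axis, per the ONNX DFT spec); the names below are illustrative only:

import numpy as np

# Real signal in ONNX DFT layout: [batch, n, 1]; a complex signal would be [batch, n, 2].
signal = np.random.rand(1, 8, 1).astype("float32")

# What the converter computes, modeled in NumPy: split -> 1D DFT -> merge.
re, im = signal[..., 0], np.zeros((1, 8), dtype="float32")
spectrum = np.fft.fft(re + 1j * im, axis=-1)
output = np.stack([spectrum.real, spectrum.imag], axis=-1)  # [1, 8, 2]

# With onesided=1, only the first n // 2 + 1 frequency bins are kept.
onesided = output[:, : 8 // 2 + 1, :]  # [1, 5, 2]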


class NonMaxSuppression(OnnxOpConverter):
"""Operator converter for NonMaxSuppression."""

@@ -6696,6 +6806,7 @@ def _get_convert_map(opset):
"Scan": Scan.get_converter(opset),
# ML
"LinearRegressor": LinearRegressor.get_converter(opset),
"DFT": DFT.get_converter(opset),
# Sequence operators
"SequenceConstruct": SequenceConstruct.get_converter(opset),
"SequenceEmpty": SequenceEmpty.get_converter(opset),
14 changes: 14 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -191,6 +191,20 @@ def stft_shape_func(attrs, inputs, _):
    ]


# DFT
@_reg.register_compute("dft")
def compute_dft(attrs, inputs, _):
    """Compute definition of DFT"""
    return topi.dft(
        inputs[0],
        inputs[1],
        attrs.inverse,
    )


_reg.register_strategy("dft", strategy.dft_strategy)


# trilu
_reg.register_strategy("trilu", strategy.trilu_strategy)

11 changes: 11 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -1412,3 +1412,14 @@ def stft_strategy_cuda(attrs, inputs, out_type, target):
name="stft.cuda",
)
return strategy


@dft_strategy.register(["cuda", "gpu"])
def dft_strategy_cuda(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dft(topi.cuda.dft),
wrap_topi_schedule(topi.generic.schedule_extern),
name="dft.cuda",
)
return strategy
26 changes: 26 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1468,6 +1468,32 @@ def _compute_stft(attrs, inputs, output_type):
    return _compute_stft


# dft
@override_native_generic_func("dft_strategy")
def dft_strategy(attrs, outs, out_type, target):
    """DFT generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_dft(topi.dft),
        wrap_topi_schedule(topi.generic.schedule_extern),
        name="dft.generic",
    )
    return strategy


def wrap_compute_dft(topi_compute):
    """Wrap DFT compute"""

    def _compute_dft(attrs, inputs, _):
        return topi_compute(
            inputs[0],
            inputs[1],
            attrs.inverse,
        )

    return _compute_dft


# trilu
@override_native_generic_func("trilu_strategy")
def trilu_strategy(attrs, outs, out_type, target):
27 changes: 27 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1987,6 +1987,33 @@ def stft(
    return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)


def dft(re_data, im_data, inverse=False):
    """

Member:

We should avoid allocating the imaginary part when possible. In practice, I think we only need real-to-complex DFT and complex-to-real iDFT. In both cases, we can remove one of the input or output tensors.

So rather than defining one big DFT op that does everything, how about making a more fine-grained definition for each possible use case?

Contributor Author:

I will try to do this soon.

Member:

Also, if we define a real-to-complex DFT, we can exploit the "onesided" optimization.
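
For reference, the "onesided" optimization exploits Hermitian symmetry: for a real input of length N, output bin N-k is the complex conjugate of bin k, so only the first N // 2 + 1 bins need to be computed and stored. A small NumPy illustration:

import numpy as np

x = np.random.rand(8)   # real signal, N = 8
full = np.fft.fft(x)    # complex-to-complex DFT: 8 bins
half = np.fft.rfft(x)   # real-to-complex DFT: 8 // 2 + 1 = 5 bins

assert np.allclose(full[:5], half)
assert np.allclose(full[5:], np.conj(full[1:4])[::-1])  # mirrored tail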

    Computes the discrete Fourier transform of the input (calculated along the last axis).
    This gives the frequency components of the signal.

    Parameters
    ----------
    re_data : relay.Expr
        N-D tensor, real part of the input signal.

    im_data : relay.Expr
        N-D tensor, imaginary part of the input signal.
        If the signal is real, then the values of this tensor are zeros.

    inverse : bool
        Whether to perform the inverse discrete Fourier transform.

    Returns
    -------
    re_output : relay.Expr
        The Fourier transform of the input (real part).

    im_output : relay.Expr
        The Fourier transform of the input (imaginary part).
    """
    return TupleWrapper(_make.dft(re_data, im_data, inverse), 2)
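
A minimal usage sketch (hypothetical, not from the PR; it assumes the op is reached via relay.op.transform.dft as defined above):

from tvm import relay

# Forward 1D DFT over a batch of 8-sample signals.
re = relay.var("re", shape=(1, 8), dtype="float32")
im = relay.var("im", shape=(1, 8), dtype="float32")
re_out, im_out = relay.op.transform.dft(re, im, inverse=False)
func = relay.Function([re, im], relay.Tuple([re_out, im_out]))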


def trilu(data, k, upper=True):
"""Given a 2-D matrix or batches of 2-D matrices, returns the
upper or lower triangular part of the tensor.
2 changes: 1 addition & 1 deletion python/tvm/topi/__init__.py
@@ -47,7 +47,7 @@
from .einsum import *
from .unique import *
from .searchsorted import *
-from .stft import *
+from .signal import *
from . import generic
from . import nn
from . import x86
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/__init__.py
@@ -61,4 +61,4 @@
from .transform import *
from .unique import *
from .searchsorted import *
-from .stft import *
+from .signal import *
96 changes: 96 additions & 0 deletions python/tvm/topi/cuda/stft.py → python/tvm/topi/cuda/signal.py
@@ -133,3 +133,99 @@ def gen_ir(
name="stft_cuda",
tag="stft_cuda",
)


def dft(
    re_data: te.Tensor,
    im_data: te.Tensor,
    inverse: tir.IntImm,
):
    """
    Computes the discrete Fourier transform of the input (calculated along the last axis).
    This gives the frequency components of the signal.

    Parameters
    ----------
    re_data : te.Tensor
        N-D tensor, real part of the input signal.

    im_data : te.Tensor
        N-D tensor, imaginary part of the input signal.
        If the signal is real, then the values of this tensor are zeros.

    inverse : tir.IntImm
        Whether to perform the inverse discrete Fourier transform.

    Returns
    -------
    re_output : te.Tensor
        The Fourier transform of the input (real part).

    im_output : te.Tensor
        The Fourier transform of the input (imaginary part).
    """

    def gen_ir(
        re_data_buf,
        im_data_buf,
        re_output_buf,
        im_output_buf,
    ):
        ib = tir.ir_builder.create()
        re_data_ptr = ib.buffer_ptr(re_data_buf)
        im_data_ptr = ib.buffer_ptr(im_data_buf)
        re_output_ptr = ib.buffer_ptr(re_output_buf)
        im_output_ptr = ib.buffer_ptr(im_output_buf)

        shape = re_data.shape
        n_fft = shape[len(shape) - 1]
        base_range = 1
        for i in range(len(shape) - 1):
            base_range *= shape[i]

        sign = -1 if inverse else 1
        factor = 1.0 / n_fft if inverse else 1.0

        max_threads = _get_max_threads(base_range)
        with ib.new_scope():
            nthread_tx = max_threads
            nthread_bx = ceil_div(base_range, max_threads)
            tx = te.thread_axis("threadIdx.x")
            bx = te.thread_axis("blockIdx.x")
            ib.scope_attr(tx, "thread_extent", nthread_tx)
            ib.scope_attr(bx, "thread_extent", nthread_bx)

            tid = bx * max_threads + tx
            with ib.if_scope(tid < base_range):
                base_idx = tid * n_fft
                with ib.for_range(0, n_fft) as n:
                    n_idx = base_idx + n
                    re_output_ptr[n_idx] = tir.Cast(re_output_ptr.dtype, 0)
                    im_output_ptr[n_idx] = tir.Cast(im_output_ptr.dtype, 0)
                    _w = sign * -2 * pi * n / n_fft
                    with ib.for_range(0, n_fft) as k:
                        k_idx = base_idx + k
                        w = _w * k
                        cos_w = tir.Cast(re_output_ptr.dtype, tir.cos(w))
                        sin_w = tir.Cast(re_output_ptr.dtype, tir.sin(w))
                        re_output_ptr[n_idx] += (
                            re_data_ptr[k_idx] * cos_w - im_data_ptr[k_idx] * sin_w
                        )
                        im_output_ptr[n_idx] += (
                            re_data_ptr[k_idx] * sin_w + im_data_ptr[k_idx] * cos_w
                        )

                    re_output_ptr[n_idx] *= tir.Cast(re_output_ptr.dtype, factor)
                    im_output_ptr[n_idx] *= tir.Cast(im_output_ptr.dtype, factor)

        return ib.get()

    output_shape = [re_data.shape] * 2

    return te.extern(
        shape=output_shape,
        inputs=[re_data, im_data],
        fcompute=lambda ins, outs: gen_ir(ins[0], ins[1], outs[0], outs[1]),
        dtype=[re_data.dtype, im_data.dtype],
        name="dft_cuda",
        tag="dft_cuda",
    )
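
As a sanity check of the kernel's arithmetic, here is a NumPy model of the same double loop (an illustration, not part of the PR):

import numpy as np

def naive_dft(re, im, inverse=False):
    # Mirrors the kernel: w = sign * -2*pi*n*k / n_fft, with the 1/N
    # scaling applied only for the inverse transform.
    n_fft = re.shape[-1]
    sign = -1.0 if inverse else 1.0
    factor = 1.0 / n_fft if inverse else 1.0
    n = np.arange(n_fft)[:, None]
    k = np.arange(n_fft)[None, :]
    w = sign * -2.0 * np.pi * n * k / n_fft  # symmetric in n and k
    re_out = re @ np.cos(w) - im @ np.sin(w)
    im_out = re @ np.sin(w) + im @ np.cos(w)
    return re_out * factor, im_out * factor

x = np.random.rand(2, 8)
re_out, im_out = naive_dft(x, np.zeros_like(x))
assert np.allclose(re_out + 1j * im_out, np.fft.fft(x))

# Round trip: the inverse transform recovers the input.
re_inv, im_inv = naive_dft(re_out, im_out, inverse=True)
assert np.allclose(re_inv, x) and np.allclose(im_inv, 0)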