
Commit 24d9afd

[Relax] Fix Relax Operator PReLU (#18179)
1 parent fe36bb9 commit 24d9afd

5 files changed: +147 -5 lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 1 addition & 2 deletions
@@ -1100,8 +1100,7 @@ class PRelu(OnnxOpConverter):
     def _impl_v1(cls, bb, inputs, attr, params):
         x = inputs[0]
         slope = inputs[1]
-        # TODO(tvm-team): Should add a new op for this.
-        return x * slope + relax.op.nn.relu(x) * (relax.const(1.0) - slope)
+        return relax.op.nn.prelu(x, slope)


 class ThresholdedRelu(OnnxOpConverter):
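For reference, the expression the frontend previously emitted is algebraically equivalent to PReLU: for positive inputs the two slope terms cancel to x, and for non-positive inputs only the slope * x term survives. A small NumPy sketch of that equivalence (illustrative only, not part of the commit):

```python
import numpy as np

def prelu_ref(x, slope):
    # PReLU: identity for positive inputs, scaled by slope for non-positive inputs.
    return np.where(x > 0, x, slope * x)

def old_expansion(x, slope):
    # The expression the ONNX frontend emitted before this commit.
    return x * slope + np.maximum(x, 0) * (1.0 - slope)

x = np.random.randn(3, 32, 32).astype("float32")
slope = np.float32(0.25)
np.testing.assert_allclose(prelu_ref(x, slope), old_expansion(x, slope), rtol=1e-6)
```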

src/relax/op/nn/nn.cc

Lines changed: 43 additions & 2 deletions
@@ -120,13 +120,54 @@ TVM_FFI_STATIC_INIT_BLOCK({
   refl::GlobalDef().def("relax.op.nn.prelu", prelu);
 });
 
+StructInfo InferStructInfoPRelu(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  if (data_sinfo->IsUnknownNdim()) {
+    return data_sinfo;
+  }
+  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Prelu requires the input tensor to have float "
+                                                "dtype. However, the given input dtype is "
+                                             << data_sinfo->dtype);
+  }
+  const auto* attrs = call->attrs.as<PReluAttrs>();
+  NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis);
+
+  return data_sinfo;
+}
+
+InferLayoutOutput InferLayoutPRelu(const Call& call,
+                                   const Map<String, Array<String>>& desired_layouts,
+                                   const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* attrs = call->attrs.as<PReluAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+
+  // TODO(Siva): We could handle if the axis is not the sub indexed one.
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+    int ndim = tensor_sinfo->ndim;
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  ObjectPtr<PReluAttrs> new_attrs = make_object<PReluAttrs>(*attrs);
+  new_attrs->axis = FindAxis(layout->layout, attrs->axis);
+
+  LayoutDecision alpha_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  return InferLayoutOutput({layout, alpha_layout}, {layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.nn.prelu")
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("alpha", "Tensor", "The channel-wise learnable slope.")
     .set_attrs_type<PReluAttrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo",
-                                InferStructInfoUnaryArith</*require_float_dtype=*/true>)
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPRelu)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPRelu)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.nn.softmax */
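InferStructInfoPRelu propagates the data tensor's struct info to the call unchanged, after checking for a float dtype and normalizing the axis attribute. A hedged sketch of how that shows up from the Python API, using the BlockBuilder much like the new parser test does (variable names here are illustrative):

```python
from tvm import relax
from tvm.script import relax as R

x = relax.Var("x", R.Tensor((2, 4, 4, 5), "float32"))
alpha = relax.Var("alpha", R.Tensor((1,), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [x, alpha]):
    # The struct info inferred for this call follows the data tensor:
    # R.Tensor((2, 4, 4, 5), "float32").
    gv = bb.emit(relax.op.nn.prelu(x, alpha))
    bb.emit_func_output(gv)

print(bb.get()["main"])
```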

tests/python/relax/test_frontend_onnx.py

Lines changed: 1 addition & 1 deletion
@@ -948,7 +948,7 @@ def test_mish():
 
 
 def test_prelu():
-    verify_binary("PRelu", [3, 32, 32], [3, 32, 32], [3, 32, 32])
+    verify_binary("PRelu", [3, 32, 32], [1], [3, 32, 32])
 
 
 def test_thresholded_relu():
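The test now feeds a slope of shape [1], the broadcastable form commonly seen in ONNX models. For context, a minimal standalone ONNX graph with that pattern could be built as below (an illustrative sketch, independent of how the verify_binary helper constructs its model):

```python
import onnx
from onnx import TensorProto, helper

# A minimal ONNX graph: PRelu whose slope broadcasts from shape [1].
node = helper.make_node("PRelu", inputs=["x", "slope"], outputs=["y"])
graph = helper.make_graph(
    [node],
    "prelu_test",
    inputs=[
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 32, 32]),
        helper.make_tensor_value_info("slope", TensorProto.FLOAT, [1]),
    ],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 32, 32])],
)
model = helper.make_model(graph)
onnx.checker.check_model(model)
```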

tests/python/relax/test_transform_legalize_ops_nn.py

Lines changed: 83 additions & 0 deletions
@@ -1159,6 +1159,89 @@ def leaky_relu(var_rxplaceholder: T.handle, var_compute: T.handle):
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_prelu():
+    # fmt: off
+    @tvm.script.ir_module
+    class PRelu:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.nn.prelu(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((2, 3), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def prelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(1),), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            slope_broadcasted = T.alloc_buffer((T.int64(3),))
+            for c in range(T.int64(3)):
+                with T.block("slope_broadcasted"):
+                    v_c = T.axis.spatial(T.int64(3), c)
+                    T.reads(y[T.int64(0)])
+                    T.writes(slope_broadcasted[v_c])
+                    slope_broadcasted[v_c] = y[T.int64(0)]
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
+    # fmt: on
+
+    mod = LegalizeOps()(PRelu)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_prelu_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class PRelu:
+        @R.function
+        def main(x: R.Tensor(("m", 7), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor(("m", 7), "float32"):
+            m = T.int64()
+            gv: R.Tensor((m, 7), "float32") = R.nn.prelu(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("m", 7), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor(("m", 7), dtype="float32"):
+            m = T.int64()
+            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((m, 7), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def prelu(var_x: T.handle, y: T.Buffer((T.int64(1),), "float32"), var_compute: T.handle):
+            T.func_attr({"tir.noalias": True})
+            m = T.int64()
+            x = T.match_buffer(var_x, (m, T.int64(7)))
+            compute = T.match_buffer(var_compute, (m, T.int64(7)))
+            # with T.block("root"):
+            slope_broadcasted = T.alloc_buffer((T.int64(7),))
+            for c in range(T.int64(7)):
+                with T.block("slope_broadcasted"):
+                    v_c = T.axis.spatial(T.int64(7), c)
+                    T.reads(y[T.int64(0)])
+                    T.writes(slope_broadcasted[v_c])
+                    slope_broadcasted[v_c] = y[T.int64(0)]
+            for i0, i1 in T.grid(m, T.int64(7)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
+    # fmt: on
+
+    mod = LegalizeOps()(PRelu)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_gelu():
     # fmt: off
     @tvm.script.ir_module
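The legalized prim_func first broadcasts the single-element slope across the last axis and then applies an element-wise select. A NumPy reference of the same computation (an illustrative sketch mirroring the T.Select expression above, not code from the commit):

```python
import numpy as np

def prelu_legalized_ref(x, y):
    # y has shape (1,); broadcast it across the last axis, like the
    # "slope_broadcasted" buffer in the expected prim_func.
    slope_broadcasted = np.broadcast_to(y, (x.shape[-1],))
    # Elementwise select, matching T.Select(0 < x, x, x * slope).
    return np.where(x > 0.0, x, x * slope_broadcasted)

x = np.random.randn(2, 3).astype("float32")
y = np.array([0.25], dtype="float32")
out = prelu_legalized_ref(x, y)
assert out.shape == (2, 3)
```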

tests/python/relax/test_tvmscript_parser_op_nn.py

Lines changed: 19 additions & 0 deletions
@@ -364,5 +364,24 @@ def foo(
     _check(foo, bb.get()["foo"])
 
 
+def test_prelu():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 4, 4, 5), "float32"),
+        alpha: R.Tensor((1,), "float32"),
+    ) -> R.Tensor((2, 4, 4, 5), "float32"):
+        gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.prelu(x, alpha)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 4, 4, 5), "float32"))
+    alpha = relax.Var("alpha", R.Tensor((1,), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, alpha]):
+        gv = bb.emit(relax.op.nn.prelu(x, alpha))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()