Skip to content

Commit 48793f3

Browse files
author
Xingyu Zhou
authored
Add ONNX LinearRegressor operator support (#10477)
1 parent 1f60529 commit 48793f3

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3444,6 +3444,35 @@ def body_fn(*loop_inputs):
34443444
return outputs
34453445

34463446

3447+
class LinearRegressor(OnnxOpConverter):
3448+
"""Operator converter for LinearRegressor."""
3449+
3450+
@classmethod
3451+
def _impl_v1(cls, inputs, attr, params):
3452+
data = inputs[0]
3453+
coefficients = attr.get("coefficients", 0)
3454+
data_shape = infer_shape(data)
3455+
targets = attr.get("targets", 1)
3456+
coefficients = _expr.const(list(coefficients), dtype="float32")
3457+
coefficients_shape = infer_shape(coefficients)
3458+
3459+
coefficients = _op.reshape(coefficients, (targets, coefficients_shape[0] // targets))
3460+
if coefficients_shape[0] // targets < data_shape[-1]:
3461+
data = _op.split(data, [coefficients_shape[0] // targets], -1)[0]
3462+
3463+
mm_out = _op.nn.dense(data, coefficients)
3464+
3465+
if "intercepts" in attr:
3466+
intercepts = attr.get("intercepts", 0)
3467+
intercepts = _expr.const(list(intercepts), dtype="float32")
3468+
3469+
if targets == 1:
3470+
return _op.nn.bias_add(mm_out, intercepts, axis=-1)
3471+
return get_relay_op("add")(mm_out, intercepts)
3472+
3473+
return mm_out
3474+
3475+
34473476
class NonMaxSuppression(OnnxOpConverter):
34483477
"""Operator converter for NonMaxSuppression."""
34493478

@@ -4770,6 +4799,8 @@ def _get_convert_map(opset):
47704799
"Adam": Adam.get_converter(opset),
47714800
"Momentum": Momentum.get_converter(opset),
47724801
"Scan": Scan.get_converter(opset),
4802+
# ML
4803+
"LinearRegressor": LinearRegressor.get_converter(opset),
47734804
}
47744805

47754806

tests/python/frontend/onnx/test_forward.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6275,6 +6275,49 @@ def verify_scan(
62756275
verify_scan(input_shapes, output_shapes, 3, [-4, -1, -2], [1] * 3, [-3, -2], [1] * 2, 9)
62766276

62776277

6278+
@tvm.testing.parametrize_targets
6279+
def test_LinearRegressor(target, dev):
6280+
def verify_LinearRegressor(a_shape, c_shape, i_shape, targets=1, batch=1):
6281+
a_array = np.random.uniform(size=a_shape).astype("float32")
6282+
out_shape = (batch, targets)
6283+
6284+
coefficients = np.random.uniform(size=c_shape).astype("float32")
6285+
intercepts = np.random.uniform(size=i_shape).astype("float32")
6286+
6287+
mul_node = helper.make_node(
6288+
"LinearRegressor",
6289+
["a"],
6290+
["out"],
6291+
coefficients=coefficients,
6292+
intercepts=intercepts,
6293+
targets=targets,
6294+
domain="ai.onnx.ml",
6295+
)
6296+
6297+
graph = helper.make_graph(
6298+
[mul_node],
6299+
"LinearRegressor_test",
6300+
inputs=[
6301+
helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
6302+
],
6303+
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)],
6304+
)
6305+
model = helper.make_model(
6306+
graph,
6307+
producer_name="LinearRegressor_test",
6308+
opset_imports=[
6309+
onnx.helper.make_opsetid("ai.onnx.ml", 1),
6310+
],
6311+
)
6312+
verify_with_ort_with_inputs(model, [a_array], target=target, dev=dev)
6313+
6314+
verify_LinearRegressor((1, 3), (3), (1))
6315+
verify_LinearRegressor((2, 10), (10), (1), batch=2)
6316+
verify_LinearRegressor((1, 3), (30), (10), targets=10)
6317+
verify_LinearRegressor((10, 3), (30), (10), targets=10, batch=10)
6318+
verify_LinearRegressor((1, 4), (3), (1))
6319+
6320+
62786321
if __name__ == "__main__":
62796322
test_flatten()
62806323
test_reshape()
@@ -6371,3 +6414,4 @@ def verify_scan(
63716414
test_random_uniform_like()
63726415
test_random_normal()
63736416
test_random_normal_like()
6417+
test_LinearRegressor()

0 commit comments

Comments
 (0)