Skip to content

Commit abb901f

Browse files
author
Siyuan Feng
authored
[Relax] Support left_shift and right_shift op (#17448)
Introduced left_shift and right_shift op in Relax with ONNX frontend support.
1 parent a5d04a5 commit abb901f

File tree

10 files changed

+184
-8
lines changed

10 files changed

+184
-8
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ class BinaryBase(OnnxOpConverter):
244244
relax_op: Callable = None
245245

246246
@classmethod
247-
def _impl_v1(cls, bb, inputs, attr, params):
247+
def base_impl(cls, bb, inputs, attr, params):
248+
"""Base implementation for binary operations."""
248249
if cls.numpy_op is None or cls.relax_op is None:
249250
raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
250251
if all([isinstance(inp, relax.Constant) for inp in inputs]):
@@ -274,83 +275,131 @@ class Add(BinaryBase):
274275
numpy_op = _np.add
275276
relax_op = relax.op.add
276277

278+
@classmethod
279+
def _impl_v1(cls, bb, inputs, attr, params):
280+
return cls.base_impl(bb, inputs, attr, params)
281+
277282

278283
class Sub(BinaryBase):
279284
"""Converts an onnx Sub node into an equivalent Relax expression."""
280285

281286
numpy_op = _np.subtract
282287
relax_op = relax.op.subtract
283288

289+
@classmethod
290+
def _impl_v1(cls, bb, inputs, attr, params):
291+
return cls.base_impl(bb, inputs, attr, params)
292+
284293

285294
class Mul(BinaryBase):
286295
"""Converts an onnx Mul node into an equivalent Relax expression."""
287296

288297
numpy_op = _np.multiply
289298
relax_op = relax.op.multiply
290299

300+
@classmethod
301+
def _impl_v1(cls, bb, inputs, attr, params):
302+
return cls.base_impl(bb, inputs, attr, params)
303+
291304

292305
class Div(BinaryBase):
293306
"""Converts an onnx Div node into an equivalent Relax expression."""
294307

295308
numpy_op = _np.divide
296309
relax_op = relax.op.divide
297310

311+
@classmethod
312+
def _impl_v1(cls, bb, inputs, attr, params):
313+
return cls.base_impl(bb, inputs, attr, params)
314+
298315

299316
class Pow(BinaryBase):
300317
"""Converts an onnx Pow node into an equivalent Relax expression."""
301318

302319
numpy_op = _np.power
303320
relax_op = relax.op.power
304321

322+
@classmethod
323+
def _impl_v1(cls, bb, inputs, attr, params):
324+
return cls.base_impl(bb, inputs, attr, params)
325+
305326

306327
class And(BinaryBase):
307328
"""Converts an onnx And node into an equivalent Relax expression."""
308329

309330
numpy_op = _np.logical_and
310331
relax_op = relax.op.logical_and
311332

333+
@classmethod
334+
def _impl_v1(cls, bb, inputs, attr, params):
335+
return cls.base_impl(bb, inputs, attr, params)
336+
312337

313338
class Or(BinaryBase):
314339
"""Converts an onnx Or node into an equivalent Relax expression."""
315340

316341
numpy_op = _np.logical_or
317342
relax_op = relax.op.logical_or
318343

344+
@classmethod
345+
def _impl_v1(cls, bb, inputs, attr, params):
346+
return cls.base_impl(bb, inputs, attr, params)
347+
319348

320349
class Xor(BinaryBase):
321350
"""Converts an onnx Xor node into an equivalent Relax expression."""
322351

323352
numpy_op = _np.logical_xor
324353
relax_op = relax.op.logical_xor
325354

355+
@classmethod
356+
def _impl_v1(cls, bb, inputs, attr, params):
357+
return cls.base_impl(bb, inputs, attr, params)
358+
326359

327360
class Less(BinaryBase):
328361
"""Converts an onnx Less node into an equivalent Relax expression."""
329362

330363
numpy_op = _np.less
331364
relax_op = relax.op.less
332365

366+
@classmethod
367+
def _impl_v1(cls, bb, inputs, attr, params):
368+
return cls.base_impl(bb, inputs, attr, params)
369+
333370

334371
class LessOrEqual(BinaryBase):
335372
"""Converts an onnx LessEqual node into an equivalent Relax expression."""
336373

337374
numpy_op = _np.less_equal
338375
relax_op = relax.op.less_equal
339376

377+
@classmethod
378+
def _impl_v1(cls, bb, inputs, attr, params):
379+
return cls.base_impl(bb, inputs, attr, params)
380+
340381

341382
class Greater(BinaryBase):
342383
"""Converts an onnx Greater node into an equivalent Relax expression."""
343384

344385
numpy_op = _np.greater
345386
relax_op = relax.op.greater
346387

388+
@classmethod
389+
def _impl_v1(cls, bb, inputs, attr, params):
390+
return cls.base_impl(bb, inputs, attr, params)
391+
347392

348393
class GreaterOrEqual(BinaryBase):
349394
"""Converts an onnx GreaterEqual node into an equivalent Relax expression."""
350395

351396
numpy_op = _np.greater_equal
352397
relax_op = relax.op.greater_equal
353398

399+
@classmethod
400+
def _impl_v1(cls, bb, inputs, attr, params):
401+
return cls.base_impl(bb, inputs, attr, params)
402+
354403

355404
class Equal(OnnxOpConverter):
356405
"""Converts an onnx Equal node into an equivalent Relax expression."""
@@ -374,39 +423,78 @@ class BitwiseBase(BinaryBase):
374423
"""Converts an onnx BitwiseBase node into an equivalent Relax expression."""
375424

376425
@classmethod
377-
def base_impl(cls, bb, inputs, attr, params, py_func, relax_op):
426+
def base_impl(cls, bb, inputs, attr, params):
427+
"""Base implementation for bitwise operations."""
378428
valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
379429
for num, inp in enumerate(inputs):
380430
if inp.struct_info.dtype not in valid_types:
381431
raise ValueError(
382432
f"Bitwise operations expect all inputs to have integer types, "
383433
f"got {inp.struct_info.dtype} for input {num}"
384434
)
385-
return BinaryBase.base_impl(bb, inputs, attr, params, py_func, relax_op)
435+
return super().base_impl(bb, inputs, attr, params)
386436

387437

388438
class BitwiseAnd(BitwiseBase):
389439
"""Converts an onnx BitwiseAnd node into an equivalent Relax expression."""
390440

441+
numpy_op = _np.bitwise_and
442+
relax_op = relax.op.bitwise_and
443+
391444
@classmethod
392445
def _impl_v18(cls, bb, inputs, attr, params):
393-
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and)
446+
return cls.base_impl(bb, inputs, attr, params)
394447

395448

396449
class BitwiseOr(BitwiseBase):
397450
"""Converts an onnx BitwiseOr node into an equivalent Relax expression."""
398451

452+
numpy_op = _np.bitwise_or
453+
relax_op = relax.op.bitwise_or
454+
399455
@classmethod
400456
def _impl_v18(cls, bb, inputs, attr, params):
401-
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or)
457+
return cls.base_impl(bb, inputs, attr, params)
402458

403459

404460
class BitwiseXor(BitwiseBase):
405461
"""Converts an onnx BitwiseXor node into an equivalent Relax expression."""
406462

463+
numpy_op = _np.bitwise_xor
464+
relax_op = relax.op.bitwise_xor
465+
407466
@classmethod
408467
def _impl_v18(cls, bb, inputs, attr, params):
409-
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor)
468+
return cls.base_impl(bb, inputs, attr, params)
469+
470+
471+
class BitwiseNot(BitwiseBase):
472+
"""Converts an onnx BitwiseNot node into an equivalent Relax expression."""
473+
474+
numpy_op = _np.bitwise_not
475+
relax_op = relax.op.bitwise_not
476+
477+
@classmethod
478+
def _impl_v18(cls, bb, inputs, attr, params):
479+
return cls.base_impl(bb, inputs, attr, params)
480+
481+
482+
class BitShift(BitwiseBase):
483+
"""Converts an onnx BitShift node into an equivalent Relax expression."""
484+
485+
@classmethod
486+
def _impl_v11(cls, bb, inputs, attr, params):
487+
direction = attr.get("direction", "LEFT").decode("ascii")
488+
if direction == "LEFT":
489+
cls.numpy_op = _np.left_shift
490+
cls.relax_op = relax.op.left_shift
491+
elif direction == "RIGHT":
492+
cls.numpy_op = _np.right_shift
493+
cls.relax_op = relax.op.right_shift
494+
else:
495+
raise ValueError("Unsupported Shift Direction: " + direction)
496+
497+
return cls.base_impl(bb, inputs, attr, params)
410498

411499

412500
class Sigmoid(OnnxOpConverter):
@@ -2654,8 +2742,8 @@ def _get_convert_map():
26542742
"BitwiseAnd": BitwiseAnd,
26552743
"BitwiseOr": BitwiseOr,
26562744
"BitwiseXor": BitwiseXor,
2657-
# "BitwiseNot": BitwiseNot,
2658-
# "BitwiseShift": BitwiseShift,
2745+
"BitwiseNot": BitwiseNot,
2746+
"BitShift": BitShift,
26592747
"And": And,
26602748
"Or": Or,
26612749
"Xor": Xor,

python/tvm/relax/op/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
floor_divide,
5353
greater,
5454
greater_equal,
55+
left_shift,
5556
less,
5657
less_equal,
5758
logical_and,
@@ -62,6 +63,7 @@
6263
multiply,
6364
not_equal,
6465
power,
66+
right_shift,
6567
subtract,
6668
)
6769
from .create import (

python/tvm/relax/op/binary.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,35 @@ def bitwise_xor(x1: Expr, x2: Expr) -> Expr:
386386
The computed result.
387387
"""
388388
return _ffi_api.bitwise_xor(x1, x2)
389+
390+
391+
def left_shift(x1: Expr, x2: Expr) -> Expr:
392+
"""Bitwise Shift Left
393+
Parameters
394+
----------
395+
x1 : relax.Expr
396+
The input tensor to be shifted.
397+
x2 : relax.Expr
398+
The number of positions to shift.
399+
Returns
400+
-------
401+
result : relax.Expr
402+
The computed result.
403+
"""
404+
return _ffi_api.left_shift(x1, x2)
405+
406+
407+
def right_shift(x1: Expr, x2: Expr) -> Expr:
408+
"""Bitwise Shift Right
409+
Parameters
410+
----------
411+
x1 : relax.Expr
412+
The input tensor to be shifted.
413+
x2 : relax.Expr
414+
The number of positions to shift.
415+
Returns
416+
-------
417+
result : relax.Expr
418+
The computed result.
419+
"""
420+
return _ffi_api.right_shift(x1, x2)

python/tvm/relax/transform/legalize_ops/binary.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
6262
register_legalize("relax.bitwise_and", _binary(topi.bitwise_and))
6363
register_legalize("relax.bitwise_or", _binary(topi.bitwise_or))
6464
register_legalize("relax.bitwise_xor", _binary(topi.bitwise_xor))
65+
register_legalize("relax.left_shift", _binary(topi.left_shift))
66+
register_legalize("relax.right_shift", _binary(topi.right_shift))
6567

6668
# logical
6769
register_legalize("relax.logical_and", _binary(topi.logical_and))

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
isinf,
103103
isnan,
104104
layout_transform,
105+
left_shift,
105106
less,
106107
less_equal,
107108
linear,
@@ -133,6 +134,7 @@
133134
quantize,
134135
repeat,
135136
reshape,
137+
right_shift,
136138
round,
137139
rsqrt,
138140
scatter_elements,
@@ -773,6 +775,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
773775
"isinf",
774776
"isnan",
775777
"layout_transform",
778+
"left_shift",
776779
"less",
777780
"less_equal",
778781
"linear",
@@ -809,6 +812,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
809812
"repeat",
810813
"reshape",
811814
"rewriter",
815+
"right_shift",
812816
"tensor_to_shape",
813817
"shape_to_tensor",
814818
"rocm",

src/relax/op/distributed/binary.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(logical_xor);
6868
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_and);
6969
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_or);
7070
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_xor);
71+
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(left_shift);
72+
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(right_shift);
7173

7274
} // namespace distributed
7375
} // namespace relax

src/relax/op/tensor/binary.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_xor);
207207
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_and);
208208
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_or);
209209
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_xor);
210+
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(left_shift);
211+
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(right_shift);
210212

211213
} // namespace relax
212214
} // namespace tvm

src/relax/op/tensor/binary.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,12 @@ Expr bitwise_or(Expr x1, Expr x2);
129129
/*! \brief Broadcasted element-wise bitwise xor */
130130
Expr bitwise_xor(Expr x1, Expr x2);
131131

132+
/*! \brief Broadcasted element-wise bitwise shift left */
133+
Expr left_shift(Expr x1, Expr x2);
134+
135+
/*! \brief Broadcasted element-wise bitwise shift right */
136+
Expr right_shift(Expr x1, Expr x2);
137+
132138
} // namespace relax
133139
} // namespace tvm
134140

0 commit comments

Comments
 (0)