Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 96 additions & 8 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ class BinaryBase(OnnxOpConverter):
relax_op: Callable = None

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
def base_impl(cls, bb, inputs, attr, params):
"""Base implementation for binary operations."""
if cls.numpy_op is None or cls.relax_op is None:
raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
if all([isinstance(inp, relax.Constant) for inp in inputs]):
Expand Down Expand Up @@ -274,83 +275,131 @@ class Add(BinaryBase):
numpy_op = _np.add
relax_op = relax.op.add

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Sub(BinaryBase):
"""Converts an onnx Sub node into an equivalent Relax expression."""

numpy_op = _np.subtract
relax_op = relax.op.subtract

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Mul(BinaryBase):
"""Converts an onnx Mul node into an equivalent Relax expression."""

numpy_op = _np.multiply
relax_op = relax.op.multiply

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Div(BinaryBase):
"""Converts an onnx Div node into an equivalent Relax expression."""

numpy_op = _np.divide
relax_op = relax.op.divide

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Pow(BinaryBase):
"""Converts an onnx Pow node into an equivalent Relax expression."""

numpy_op = _np.power
relax_op = relax.op.power

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class And(BinaryBase):
"""Converts an onnx And node into an equivalent Relax expression."""

numpy_op = _np.logical_and
relax_op = relax.op.logical_and

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Or(BinaryBase):
"""Converts an onnx Or node into an equivalent Relax expression."""

numpy_op = _np.logical_or
relax_op = relax.op.logical_or

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Xor(BinaryBase):
"""Converts an onnx Xor node into an equivalent Relax expression."""

numpy_op = _np.logical_xor
relax_op = relax.op.logical_xor

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Less(BinaryBase):
"""Converts an onnx Less node into an equivalent Relax expression."""

numpy_op = _np.less
relax_op = relax.op.less

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class LessOrEqual(BinaryBase):
"""Converts an onnx LessEqual node into an equivalent Relax expression."""

numpy_op = _np.less_equal
relax_op = relax.op.less_equal

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Greater(BinaryBase):
"""Converts an onnx Greater node into an equivalent Relax expression."""

numpy_op = _np.greater
relax_op = relax.op.greater

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class GreaterOrEqual(BinaryBase):
"""Converts an onnx GreaterEqual node into an equivalent Relax expression."""

numpy_op = _np.greater_equal
relax_op = relax.op.greater_equal

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Equal(OnnxOpConverter):
"""Converts an onnx Equal node into an equivalent Relax expression."""
Expand All @@ -374,39 +423,78 @@ class BitwiseBase(BinaryBase):
"""Converts an onnx BitwiseBase node into an equivalent Relax expression."""

@classmethod
def base_impl(cls, bb, inputs, attr, params, py_func, relax_op):
def base_impl(cls, bb, inputs, attr, params):
"""Base implementation for bitwise operations."""
valid_types = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
for num, inp in enumerate(inputs):
if inp.struct_info.dtype not in valid_types:
raise ValueError(
f"Bitwise operations expect all inputs to have integer types, "
f"got {inp.struct_info.dtype} for input {num}"
)
return BinaryBase.base_impl(bb, inputs, attr, params, py_func, relax_op)
return super().base_impl(bb, inputs, attr, params)


class BitwiseAnd(BitwiseBase):
"""Converts an onnx BitwiseAnd node into an equivalent Relax expression."""

numpy_op = _np.bitwise_and
relax_op = relax.op.bitwise_and

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x & y, relax.op.bitwise_and)
return cls.base_impl(bb, inputs, attr, params)


class BitwiseOr(BitwiseBase):
"""Converts an onnx BitwiseOr node into an equivalent Relax expression."""

numpy_op = _np.bitwise_or
relax_op = relax.op.bitwise_or

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x | y, relax.op.bitwise_or)
return cls.base_impl(bb, inputs, attr, params)


class BitwiseXor(BitwiseBase):
"""Converts an onnx BitwiseXor node into an equivalent Relax expression."""

numpy_op = _np.bitwise_xor
relax_op = relax.op.bitwise_xor

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params, lambda x, y: x ^ y, relax.op.bitwise_xor)
return cls.base_impl(bb, inputs, attr, params)


class BitwiseNot(BitwiseBase):
"""Converts an onnx BitwiseNot node into an equivalent Relax expression."""

numpy_op = _np.bitwise_not
relax_op = relax.op.bitwise_not

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class BitShift(BitwiseBase):
"""Converts an onnx BitShift node into an equivalent Relax expression."""

@classmethod
def _impl_v11(cls, bb, inputs, attr, params):
direction = attr.get("direction", "LEFT").decode("ascii")
if direction == "LEFT":
cls.numpy_op = _np.left_shift
cls.relax_op = relax.op.left_shift
elif direction == "RIGHT":
cls.numpy_op = _np.right_shift
cls.relax_op = relax.op.right_shift
else:
raise ValueError("Unsupported Shift Direction: " + direction)

return cls.base_impl(bb, inputs, attr, params)


class Sigmoid(OnnxOpConverter):
Expand Down Expand Up @@ -2652,8 +2740,8 @@ def _get_convert_map():
"BitwiseAnd": BitwiseAnd,
"BitwiseOr": BitwiseOr,
"BitwiseXor": BitwiseXor,
# "BitwiseNot": BitwiseNot,
# "BitwiseShift": BitwiseShift,
"BitwiseNot": BitwiseNot,
"BitShift": BitShift,
"And": And,
"Or": Or,
"Xor": Xor,
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
floor_divide,
greater,
greater_equal,
left_shift,
less,
less_equal,
logical_and,
Expand All @@ -62,6 +63,7 @@
multiply,
not_equal,
power,
right_shift,
subtract,
)
from .create import (
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/relax/op/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,35 @@ def bitwise_xor(x1: Expr, x2: Expr) -> Expr:
The computed result.
"""
return _ffi_api.bitwise_xor(x1, x2)


def left_shift(x1: Expr, x2: Expr) -> Expr:
"""Bitwise Shift Left
Parameters
----------
x1 : relax.Expr
The input tensor to be shifted.
x2 : relax.Expr
The number of positions to shift.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.left_shift(x1, x2)


def right_shift(x1: Expr, x2: Expr) -> Expr:
"""Bitwise Shift Right
Parameters
----------
x1 : relax.Expr
The input tensor to be shifted.
x2 : relax.Expr
The number of positions to shift.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.right_shift(x1, x2)
2 changes: 2 additions & 0 deletions python/tvm/relax/transform/legalize_ops/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
register_legalize("relax.bitwise_and", _binary(topi.bitwise_and))
register_legalize("relax.bitwise_or", _binary(topi.bitwise_or))
register_legalize("relax.bitwise_xor", _binary(topi.bitwise_xor))
register_legalize("relax.left_shift", _binary(topi.left_shift))
register_legalize("relax.right_shift", _binary(topi.right_shift))

# logical
register_legalize("relax.logical_and", _binary(topi.logical_and))
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
isinf,
isnan,
layout_transform,
left_shift,
less,
less_equal,
linear,
Expand Down Expand Up @@ -133,6 +134,7 @@
quantize,
repeat,
reshape,
right_shift,
round,
rsqrt,
scatter_elements,
Expand Down Expand Up @@ -773,6 +775,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"isinf",
"isnan",
"layout_transform",
"left_shift",
"less",
"less_equal",
"linear",
Expand Down Expand Up @@ -809,6 +812,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"repeat",
"reshape",
"rewriter",
"right_shift",
"tensor_to_shape",
"shape_to_tensor",
"rocm",
Expand Down
2 changes: 2 additions & 0 deletions src/relax/op/distributed/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(logical_xor);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_and);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_or);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(bitwise_xor);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(left_shift);
RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(right_shift);

} // namespace distributed
} // namespace relax
Expand Down
2 changes: 2 additions & 0 deletions src/relax/op/tensor/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_xor);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_and);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_or);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_xor);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(left_shift);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(right_shift);

} // namespace relax
} // namespace tvm
6 changes: 6 additions & 0 deletions src/relax/op/tensor/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ Expr bitwise_or(Expr x1, Expr x2);
/*! \brief Broadcasted element-wise bitwise xor */
Expr bitwise_xor(Expr x1, Expr x2);

/*! \brief Broadcasted element-wise bitwise shift left */
Expr left_shift(Expr x1, Expr x2);

/*! \brief Broadcasted element-wise bitwise shift right */
Expr right_shift(Expr x1, Expr x2);

} // namespace relax
} // namespace tvm

Expand Down
Loading
Loading