Commit 7d38cf2

[Relax][PyTorch] Support several binary ops for ExportedProgram importer (#17689)
* Update exported_program_translator.py
* Update test_frontend_from_exported_program.py
* Update test_frontend_from_exported_program.py

Parent: 7bedfeb · Commit: 7d38cf2

File tree: 2 files changed, +103 −273 lines


python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 17 additions & 0 deletions
@@ -204,16 +204,33 @@ def create_convert_map(
             "eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
             "eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
             "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv),
+            "ge.Scalar": self._binary_op(relax.op.greater_equal, operator.ge),
+            "ge.Tensor": self._binary_op(relax.op.greater_equal, operator.ge),
+            "gt.Scalar": self._binary_op(relax.op.greater, operator.gt),
+            "gt.Tensor": self._binary_op(relax.op.greater, operator.gt),
+            "le.Scalar": self._binary_op(relax.op.less_equal, operator.le),
+            "le.Tensor": self._binary_op(relax.op.less_equal, operator.le),
             "lt.Scalar": self._binary_op(relax.op.less, operator.lt),
             "lt.Tensor": self._binary_op(relax.op.less, operator.lt),
             "matmul.default": self._binary_op(
                 partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
             ),
             "max.other": self._binary_op(relax.op.maximum, max),
+            "min.other": self._binary_op(relax.op.minimum, min),
+            "remainder.Tensor": self._binary_op(relax.op.mod, operator.mod),
+            "remainder.Scalar": self._binary_op(relax.op.mod, operator.mod),
             "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
+            "ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne),
+            "ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne),
             "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
             "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
             "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
+            "__and__.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_),
+            "__and__.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_),
+            "__or__.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
+            "__or__.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
+            "__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor),
+            "__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor),
             # neural network
             "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
