@@ -244,7 +244,8 @@ class BinaryBase(OnnxOpConverter):
244244 relax_op : Callable = None
245245
246246 @classmethod
247- def _impl_v1 (cls , bb , inputs , attr , params ):
247+ def base_impl (cls , bb , inputs , attr , params ):
248+ """Base implementation for binary operations."""
248249 if cls .numpy_op is None or cls .relax_op is None :
249250 raise ValueError ("Numpy and Relax operators must be defined for BinaryBase." )
250251 if all ([isinstance (inp , relax .Constant ) for inp in inputs ]):
@@ -274,83 +275,131 @@ class Add(BinaryBase):
274275 numpy_op = _np .add
275276 relax_op = relax .op .add
276277
278+ @classmethod
279+ def _impl_v1 (cls , bb , inputs , attr , params ):
280+ return cls .base_impl (bb , inputs , attr , params )
281+
277282
278283class Sub (BinaryBase ):
279284 """Converts an onnx Sub node into an equivalent Relax expression."""
280285
281286 numpy_op = _np .subtract
282287 relax_op = relax .op .subtract
283288
289+ @classmethod
290+ def _impl_v1 (cls , bb , inputs , attr , params ):
291+ return cls .base_impl (bb , inputs , attr , params )
292+
284293
285294class Mul (BinaryBase ):
286295 """Converts an onnx Mul node into an equivalent Relax expression."""
287296
288297 numpy_op = _np .multiply
289298 relax_op = relax .op .multiply
290299
300+ @classmethod
301+ def _impl_v1 (cls , bb , inputs , attr , params ):
302+ return cls .base_impl (bb , inputs , attr , params )
303+
291304
292305class Div (BinaryBase ):
293306 """Converts an onnx Div node into an equivalent Relax expression."""
294307
295308 numpy_op = _np .divide
296309 relax_op = relax .op .divide
297310
311+ @classmethod
312+ def _impl_v1 (cls , bb , inputs , attr , params ):
313+ return cls .base_impl (bb , inputs , attr , params )
314+
298315
299316class Pow (BinaryBase ):
300317 """Converts an onnx Pow node into an equivalent Relax expression."""
301318
302319 numpy_op = _np .power
303320 relax_op = relax .op .power
304321
322+ @classmethod
323+ def _impl_v1 (cls , bb , inputs , attr , params ):
324+ return cls .base_impl (bb , inputs , attr , params )
325+
305326
306327class And (BinaryBase ):
307328 """Converts an onnx And node into an equivalent Relax expression."""
308329
309330 numpy_op = _np .logical_and
310331 relax_op = relax .op .logical_and
311332
333+ @classmethod
334+ def _impl_v1 (cls , bb , inputs , attr , params ):
335+ return cls .base_impl (bb , inputs , attr , params )
336+
312337
313338class Or (BinaryBase ):
314339 """Converts an onnx Or node into an equivalent Relax expression."""
315340
316341 numpy_op = _np .logical_or
317342 relax_op = relax .op .logical_or
318343
344+ @classmethod
345+ def _impl_v1 (cls , bb , inputs , attr , params ):
346+ return cls .base_impl (bb , inputs , attr , params )
347+
319348
320349class Xor (BinaryBase ):
321350 """Converts an onnx Xor node into an equivalent Relax expression."""
322351
323352 numpy_op = _np .logical_xor
324353 relax_op = relax .op .logical_xor
325354
355+ @classmethod
356+ def _impl_v1 (cls , bb , inputs , attr , params ):
357+ return cls .base_impl (bb , inputs , attr , params )
358+
326359
327360class Less (BinaryBase ):
328361 """Converts an onnx Less node into an equivalent Relax expression."""
329362
330363 numpy_op = _np .less
331364 relax_op = relax .op .less
332365
366+ @classmethod
367+ def _impl_v1 (cls , bb , inputs , attr , params ):
368+ return cls .base_impl (bb , inputs , attr , params )
369+
333370
334371class LessOrEqual (BinaryBase ):
335372 """Converts an onnx LessEqual node into an equivalent Relax expression."""
336373
337374 numpy_op = _np .less_equal
338375 relax_op = relax .op .less_equal
339376
377+ @classmethod
378+ def _impl_v1 (cls , bb , inputs , attr , params ):
379+ return cls .base_impl (bb , inputs , attr , params )
380+
340381
341382class Greater (BinaryBase ):
342383 """Converts an onnx Greater node into an equivalent Relax expression."""
343384
344385 numpy_op = _np .greater
345386 relax_op = relax .op .greater
346387
388+ @classmethod
389+ def _impl_v1 (cls , bb , inputs , attr , params ):
390+ return cls .base_impl (bb , inputs , attr , params )
391+
347392
348393class GreaterOrEqual (BinaryBase ):
349394 """Converts an onnx GreaterEqual node into an equivalent Relax expression."""
350395
351396 numpy_op = _np .greater_equal
352397 relax_op = relax .op .greater_equal
353398
399+ @classmethod
400+ def _impl_v1 (cls , bb , inputs , attr , params ):
401+ return cls .base_impl (bb , inputs , attr , params )
402+
354403
355404class Equal (OnnxOpConverter ):
356405 """Converts an onnx Equal node into an equivalent Relax expression."""
@@ -374,39 +423,78 @@ class BitwiseBase(BinaryBase):
374423 """Converts an onnx BitwiseBase node into an equivalent Relax expression."""
375424
376425 @classmethod
377- def base_impl (cls , bb , inputs , attr , params , py_func , relax_op ):
426+ def base_impl (cls , bb , inputs , attr , params ):
427+ """Base implementation for bitwise operations."""
378428 valid_types = ["int8" , "int16" , "int32" , "int64" , "uint8" , "uint16" , "uint32" , "uint64" ]
379429 for num , inp in enumerate (inputs ):
380430 if inp .struct_info .dtype not in valid_types :
381431 raise ValueError (
382432 f"Bitwise operations expect all inputs to have integer types, "
383433 f"got { inp .struct_info .dtype } for input { num } "
384434 )
385- return BinaryBase .base_impl (bb , inputs , attr , params , py_func , relax_op )
435+ return super () .base_impl (bb , inputs , attr , params )
386436
387437
388438class BitwiseAnd (BitwiseBase ):
389439 """Converts an onnx BitwiseAnd node into an equivalent Relax expression."""
390440
441+ numpy_op = _np .bitwise_and
442+ relax_op = relax .op .bitwise_and
443+
391444 @classmethod
392445 def _impl_v18 (cls , bb , inputs , attr , params ):
393- return cls .base_impl (bb , inputs , attr , params , lambda x , y : x & y , relax . op . bitwise_and )
446+ return cls .base_impl (bb , inputs , attr , params )
394447
395448
396449class BitwiseOr (BitwiseBase ):
397450 """Converts an onnx BitwiseOr node into an equivalent Relax expression."""
398451
452+ numpy_op = _np .bitwise_or
453+ relax_op = relax .op .bitwise_or
454+
399455 @classmethod
400456 def _impl_v18 (cls , bb , inputs , attr , params ):
401- return cls .base_impl (bb , inputs , attr , params , lambda x , y : x | y , relax . op . bitwise_or )
457+ return cls .base_impl (bb , inputs , attr , params )
402458
403459
404460class BitwiseXor (BitwiseBase ):
405461 """Converts an onnx BitwiseXor node into an equivalent Relax expression."""
406462
463+ numpy_op = _np .bitwise_xor
464+ relax_op = relax .op .bitwise_xor
465+
407466 @classmethod
408467 def _impl_v18 (cls , bb , inputs , attr , params ):
409- return cls .base_impl (bb , inputs , attr , params , lambda x , y : x ^ y , relax .op .bitwise_xor )
468+ return cls .base_impl (bb , inputs , attr , params )
469+
470+
471+ class BitwiseNot (BitwiseBase ):
472+ """Converts an onnx BitwiseNot node into an equivalent Relax expression."""
473+
474+ numpy_op = _np .bitwise_not
475+ relax_op = relax .op .bitwise_not
476+
477+ @classmethod
478+ def _impl_v18 (cls , bb , inputs , attr , params ):
479+ return cls .base_impl (bb , inputs , attr , params )
480+
481+
482+ class BitShift (BitwiseBase ):
483+ """Converts an onnx BitShift node into an equivalent Relax expression."""
484+
485+ @classmethod
486+ def _impl_v11 (cls , bb , inputs , attr , params ):
487+ direction = attr .get ("direction" , "LEFT" ).decode ("ascii" )
488+ if direction == "LEFT" :
489+ cls .numpy_op = _np .left_shift
490+ cls .relax_op = relax .op .left_shift
491+ elif direction == "RIGHT" :
492+ cls .numpy_op = _np .right_shift
493+ cls .relax_op = relax .op .right_shift
494+ else :
495+ raise ValueError ("Unsupported Shift Direction: " + direction )
496+
497+ return cls .base_impl (bb , inputs , attr , params )
410498
411499
412500class Sigmoid (OnnxOpConverter ):
@@ -2654,8 +2742,8 @@ def _get_convert_map():
26542742 "BitwiseAnd" : BitwiseAnd ,
26552743 "BitwiseOr" : BitwiseOr ,
26562744 "BitwiseXor" : BitwiseXor ,
2657- # "BitwiseNot": BitwiseNot,
2658- # "BitwiseShift ": BitwiseShift ,
2745+ "BitwiseNot" : BitwiseNot ,
2746+ "BitShift " : BitShift ,
26592747 "And" : And ,
26602748 "Or" : Or ,
26612749 "Xor" : Xor ,
0 commit comments