@@ -272,7 +272,10 @@ def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
272
272
273
273
@register ("Size" )
274
274
def size (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
275
- shape = node .inputs [0 ].shape
275
+ input = _get_input (node , 0 )
276
+ if input is None :
277
+ return None
278
+ shape = input .shape
276
279
if shape is None :
277
280
return None
278
281
size = 1
@@ -285,8 +288,8 @@ def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
285
288
286
289
@register ("If" )
287
290
def if_op (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
288
- cond = _get_input (node , 0 )
289
- cond = _get_bool_value (cond )
291
+ cond_input = _get_input (node , 0 )
292
+ cond = _get_bool_value (cond_input )
290
293
if cond is not None :
291
294
# cond is a constant-value: inline the branch
292
295
branch = "then_branch" if cond else "else_branch"
@@ -346,9 +349,9 @@ def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValu
346
349
if any (x is None for x in inputs ):
347
350
return None
348
351
new_axis = _get_int_attribute (node , "new_axis" , 0 )
349
- if "axis" not in node .attributes :
352
+ axis = _get_int_attribute (node , "axis" , None )
353
+ if axis is None :
350
354
return None
351
- axis = node .attributes ["axis" ].value
352
355
if input is not None and isinstance (inputs , list ):
353
356
if new_axis == 0 :
354
357
logger .debug ("ConcatFromSequence => Concat: %s" , [x .name for x in inputs ])
@@ -400,6 +403,8 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
400
403
return None
401
404
402
405
axis = _get_int_attribute (node , "axis" , 0 )
406
+ if axis is None :
407
+ return None
403
408
shape = input .shape
404
409
if shape is None :
405
410
return None
@@ -466,7 +471,7 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
466
471
return None
467
472
position_val = position_val .item ()
468
473
try :
469
- result = input_vals [position_val ]
474
+ result = input_vals [position_val ] # type: ignore[index]
470
475
except IndexError :
471
476
return None
472
477
state .set_sym_value (output , result )
@@ -528,7 +533,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
528
533
output_types = onnx .shape_inference .infer_node_outputs (
529
534
schema ,
530
535
ir .serde .serialize_node (node ),
531
- input_types ,
536
+ input_types , # type: ignore[arg-type]
532
537
input_data , # type: ignore[arg-type]
533
538
)
534
539
for output in node .outputs :
0 commit comments