Skip to content

Commit 61511ef

Browse files
committed
More mypy warnings fixed
1 parent 1cb8689 commit 61511ef

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

onnxscript/optimizer/_constant_folding.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,10 @@ def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
272272

273273
@register("Size")
274274
def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
275-
shape = node.inputs[0].shape
275+
input = _get_input(node, 0)
276+
if input is None:
277+
return None
278+
shape = input.shape
276279
if shape is None:
277280
return None
278281
size = 1
@@ -285,8 +288,8 @@ def size(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
285288

286289
@register("If")
287290
def if_op(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
288-
cond = _get_input(node, 0)
289-
cond = _get_bool_value(cond)
291+
cond_input = _get_input(node, 0)
292+
cond = _get_bool_value(cond_input)
290293
if cond is not None:
291294
# cond is a constant-value: inline the branch
292295
branch = "then_branch" if cond else "else_branch"
@@ -346,9 +349,9 @@ def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValu
346349
if any(x is None for x in inputs):
347350
return None
348351
new_axis = _get_int_attribute(node, "new_axis", 0)
349-
if "axis" not in node.attributes:
352+
axis = _get_int_attribute(node, "axis", None)
353+
if axis is None:
350354
return None
351-
axis = node.attributes["axis"].value
352355
if input is not None and isinstance(inputs, list):
353356
if new_axis == 0:
354357
logger.debug("ConcatFromSequence => Concat: %s", [x.name for x in inputs])
@@ -400,6 +403,8 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
400403
return None
401404

402405
axis = _get_int_attribute(node, "axis", 0)
406+
if axis is None:
407+
return None
403408
shape = input.shape
404409
if shape is None:
405410
return None
@@ -466,7 +471,7 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
466471
return None
467472
position_val = position_val.item()
468473
try:
469-
result = input_vals[position_val]
474+
result = input_vals[position_val] # type: ignore[index]
470475
except IndexError:
471476
return None
472477
state.set_sym_value(output, result)
@@ -528,7 +533,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
528533
output_types = onnx.shape_inference.infer_node_outputs(
529534
schema,
530535
ir.serde.serialize_node(node),
531-
input_types,
536+
input_types, # type: ignore[arg-type]
532537
input_data, # type: ignore[arg-type]
533538
)
534539
for output in node.outputs:

0 commit comments

Comments
 (0)