@@ -1589,33 +1589,56 @@ def test_upsample3d_trilinear(target, dev):
15891589 tvm .testing .assert_allclose (out_array , tvm_out , rtol = 1e-5 , atol = 1e-5 )
15901590
15911591
1592+ # TODO: Fix softmax with dynamic input on cuda and enable this test
1593+ @tvm .testing .known_failing_targets ("cuda" )
15921594@tvm .testing .parametrize_targets
15931595def test_softmax (target , dev ):
1594- def verify_softmax (inshape , axis ):
1596+ def verify_softmax (inshape , axis , opset = None , dynamic = False ):
15951597 opname = "Softmax"
1596- indata = np .random .uniform (size = inshape ).astype (np .float32 )
15971598 outshape = inshape
1598- y = helper .make_node (opname , ["in" ], ["out" ])
1599+ node_list = []
1600+ input_node_list = [helper .make_tensor_value_info ("in" , TensorProto .FLOAT , list (inshape ))]
1601+ output_node_list = [helper .make_tensor_value_info ("out" , TensorProto .FLOAT , list (outshape ))]
1602+ input_list = [np .random .uniform (size = inshape ).astype (np .float32 )]
1603+ softmax_inputs = ["in" ]
1604+
1605+ if dynamic :
1606+ input_node_list .append (
1607+ helper .make_tensor_value_info ("shape" , TensorProto .INT64 , [len (inshape )])
1608+ )
1609+ input_list .append (np .asarray (inshape ))
1610+ reshape_node = helper .make_node ("Reshape" , ["in" , "shape" ], ["dynamic_in" ])
1611+ softmax_inputs [0 ] = "dynamic_in"
1612+ node_list += [reshape_node ]
1613+
1614+ y = helper .make_node (opname , softmax_inputs , ["out" ])
15991615 if axis is not None :
16001616 axis_attr = helper .make_attribute ("axis" , axis )
16011617 y .attribute .append (axis_attr )
1618+ node_list .append (y )
16021619
16031620 graph = helper .make_graph (
1604- [ y ] ,
1621+ node_list ,
16051622 opname + "_test" ,
1606- inputs = [ helper . make_tensor_value_info ( "in" , TensorProto . FLOAT , list ( indata . shape ))] ,
1607- outputs = [ helper . make_tensor_value_info ( "out" , TensorProto . FLOAT , list ( outshape ))] ,
1623+ inputs = input_node_list ,
1624+ outputs = output_node_list ,
16081625 )
16091626
16101627 model = helper .make_model (graph , producer_name = opname + "_test" )
1611- verify_with_ort_with_inputs (model , [indata ], target = target , dev = dev )
1628+ verify_with_ort_with_inputs (
1629+ model , input_list , use_vm = True , opset = opset , target = target , dev = dev
1630+ )
16121631
16131632 verify_softmax ((1 , 10 ), None )
16141633 verify_softmax ((1 , 10 ), 1 )
16151634 verify_softmax ((1 , 2 , 3 , 10 ), 0 )
16161635 verify_softmax ((1 , 2 , 3 , 10 ), 2 )
16171636 verify_softmax ((1 , 2 , 3 , 4 , 10 ), 3 )
16181637 verify_softmax ((1 , 2 , 3 , 4 , 10 ), 4 )
1638+ verify_softmax ((1 , 10 ), - 1 , dynamic = True )
1639+ verify_softmax ((1 , 2 , 3 , 10 ), - 1 , dynamic = True )
1640+ verify_softmax ((1 , 10 ), - 1 , opset = 8 , dynamic = True )
1641+ verify_softmax ((1 , 2 , 3 , 10 ), - 1 , opset = 8 , dynamic = True )
16191642
16201643
16211644@tvm .testing .parametrize_targets
0 commit comments