diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 22e6282b43a4..eb91a3c5384a 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -4240,10 +4240,6 @@ def convert_RNN(node, **kwargs): name, input_nodes, attrs = get_inputs(node, kwargs) - mode = str(attrs.get('mode')) - if mode != 'lstm': - raise NotImplementedError('Currently RNN onnx export only supports lstm mode') - bidirectional = str(attrs.get('bidirectional', 'False')) if bidirectional != 'False': raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False') @@ -4270,128 +4266,246 @@ def convert_RNN(node, **kwargs): data = input_nodes[0] param = input_nodes[1] initial_h = input_nodes[2] - initial_c = input_nodes[3] nodes = [] - if num_layers == 2: - create_tensor([0], name+'_0', kwargs['initializer']) - create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) - create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) + mode = str(attrs.get('mode')) + if mode == 'lstm': + initial_c = input_nodes[3] + if num_layers == 2: + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) + create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) + create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer']) + create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer']) + create_tensor([4*4*state_size*state_size], name+'_WR_offset', kwargs['initializer']) - create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer']) - create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer']) + nodes += [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + + # Layer 0 + # get W + make_node('Slice', [param, name+'_0', name+'_4*state_size^2'], [name+'_W0_1d']), + make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02', name+'_W03']), + make_node('Concat', [name+'_W00', name+'_W03', name+'_W01', name+'_W02'], [name+'_W0_'], axis=0), + make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']), + # get R + make_node('Add', [name+'_4*state_size^2', name+'_4*state_size^2'], [name+'_R0_offset']), + make_node('Slice', [param, name+'_4*state_size^2', name+'_R0_offset'], [name+'_R0_1d']), + make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02', name+'_R03']), + make_node('Concat', [name+'_R00', name+'_R03', name+'_R01', name+'_R02'], [name+'_R0_'], axis=0), + make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']), + # get B + make_node('Add', [name+'_WR_offset', name+'_8*state_size'], [name+'_B0_offset']), + make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']), + make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02', name+'_B03', + name+'_B04', name+'_B05', name+'_B06', name+'_B07']), + make_node('Concat', [name+'_B00', name+'_B03', name+'_B01', name+'_B02', + name+'_B04', name+'_B07', name+'_B05', name+'_B06'], [name+'_B0_'], axis=0), + make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']), + # get initial states + make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0), + make_node('Split', [initial_c], [name+'_initial_c0', name+'_initial_c1'], axis=0), + # get seq_len + make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), + make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), + # Layer 0 LSTM + make_node('LSTM', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len', + name+'_initial_h0', name+'_initial_c0'], + [name+'_lstm0_out_', name+'_lstm0_h', name+'_lstm0_c'], hidden_size=state_size), + make_node('Squeeze', [name+'_lstm0_out_'], [name+'_lstm0_out'], axes=[1]), + + # Layer 1 + # get W + make_node('Add', [name+'_R0_offset', name+'_4*state_size^2'], [name+'_W1_offset']), + make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']), + make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12', name+'_W13']), + make_node('Concat', [name+'_W10', name+'_W13', name+'_W11', name+'_W12'], [name+'_W1_'], axis=0), + make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']), + # get R + make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']), + make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12', name+'_R13']), + make_node('Concat', [name+'_R10', name+'_R13', name+'_R11', name+'_R12'], [name+'_R1_'], axis=0), + make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']), + # get B + make_node('Add', [name+'_B0_offset', name+'_8*state_size'], [name+'_B1_offset']), + make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']), + make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12', name+'_B13', + name+'_B14', name+'_B15', name+'_B16', name+'_B17']), + make_node('Concat', [name+'_B10', name+'_B13', name+'_B11', name+'_B12', + name+'_B14', name+'_B17', name+'_B15', name+'_B16'], [name+'_B1_'], axis=0), + make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']), + # Layer 1 LSTM + make_node('LSTM', [name+'_lstm0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len', + name+'_initial_h1', name+'_initial_c1'], + [name+'_lstm1_out_', name+'_lstm1_h', name+'_lstm1_c'], hidden_size=state_size), + make_node('Squeeze', [name+'_lstm1_out_'], [name], axes=[1]), + make_node('Concat', [name+'_lstm0_h', name+'_lstm1_h'], [name+'1'], axis=0), + make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0), + ] + elif num_layers == 1: + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer']) + create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) + create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) + create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer']) + create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer']) - create_tensor([4*4*state_size*state_size], name+'_WR_offset', kwargs['initializer']) + nodes += [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + # get W + make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), + make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), + make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), + make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), + # get R + make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), + make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']), + make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']), + make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0), + make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']), + # get B + make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']), + make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']), + make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', + name+'_B4', name+'_B5', name+'_B6', name+'_B7']), + make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', + name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), + make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']), + # get seq_len + make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), + make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), + # compute LSTM + make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], + [name+'0_', name+'1', name+'2'], hidden_size=state_size), + make_node('Squeeze', [name+'0_'], [name], axes=[1]), + ] + else: + raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2') - nodes += [ - make_node('Shape', [data], [name+'_data_shape']), - make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), - - # Layer 0 - # get W - make_node('Slice', [param, name+'_0', name+'_4*state_size^2'], [name+'_W0_1d']), - make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02', name+'_W03']), - make_node('Concat', [name+'_W00', name+'_W03', name+'_W01', name+'_W02'], [name+'_W0_'], axis=0), - make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']), - # get R - make_node('Add', [name+'_4*state_size^2', name+'_4*state_size^2'], [name+'_R0_offset']), - make_node('Slice', [param, name+'_4*state_size^2', name+'_R0_offset'], [name+'_R0_1d']), - make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02', name+'_R03']), - make_node('Concat', [name+'_R00', name+'_R03', name+'_R01', name+'_R02'], [name+'_R0_'], axis=0), - make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']), - # get B - make_node('Add', [name+'_WR_offset', name+'_8*state_size'], [name+'_B0_offset']), - make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']), - make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02', name+'_B03', - name+'_B04', name+'_B05', name+'_B06', name+'_B07']), - make_node('Concat', [name+'_B00', name+'_B03', name+'_B01', name+'_B02', - name+'_B04', name+'_B07', name+'_B05', name+'_B06'], [name+'_B0_'], axis=0), - make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']), - # get initial states - make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0), - make_node('Split', [initial_c], [name+'_initial_c0', name+'_initial_c1'], axis=0), - # get seq_len - make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), - make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), - # Layer 0 LSTM - make_node('LSTM', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len', - name+'_initial_h0', name+'_initial_c0'], - [name+'_lstm0_out_', name+'_lstm0_h', name+'_lstm0_c'], hidden_size=state_size), - make_node('Squeeze', [name+'_lstm0_out_'], [name+'_lstm0_out'], axes=[1]), - - # Layer 1 - # get W - make_node('Add', [name+'_R0_offset', name+'_4*state_size^2'], [name+'_W1_offset']), - make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']), - make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12', name+'_W13']), - make_node('Concat', [name+'_W10', name+'_W13', name+'_W11', name+'_W12'], [name+'_W1_'], axis=0), - make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']), - # get R - make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']), - make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12', name+'_R13']), - make_node('Concat', [name+'_R10', name+'_R13', name+'_R11', name+'_R12'], [name+'_R1_'], axis=0), - make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']), - # get B - make_node('Add', [name+'_B0_offset', name+'_8*state_size'], [name+'_B1_offset']), - make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']), - make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12', name+'_B13', - name+'_B14', name+'_B15', name+'_B16', name+'_B17']), - make_node('Concat', [name+'_B10', name+'_B13', name+'_B11', name+'_B12', - name+'_B14', name+'_B17', name+'_B15', name+'_B16'], [name+'_B1_'], axis=0), - make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']), - # Layer 1 LSTM - make_node('LSTM', [name+'_lstm0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len', - name+'_initial_h1', name+'_initial_c1'], - [name+'_lstm1_out_', name+'_lstm1_h', name+'_lstm1_c'], hidden_size=state_size), - make_node('Squeeze', [name+'_lstm1_out_'], [name], axes=[1]), - make_node('Concat', [name+'_lstm0_h', name+'_lstm1_h'], [name+'1'], axis=0), - make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0), - ] - elif num_layers == 1: - create_tensor([0], name+'_0', kwargs['initializer']) - create_tensor([1], name+'_1', kwargs['initializer']) - create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer']) - create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) - create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) - create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer']) - create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer']) + elif mode == 'gru': + if num_layers == 2: + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer']) + create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer']) + create_tensor([1, 3*state_size, state_size], name+'_WR_shape', kwargs['initializer']) + create_tensor([1, 6*state_size], name+'_B_shape', kwargs['initializer']) + create_tensor([4*3*state_size*state_size], name+'_WR_offset', kwargs['initializer']) - nodes += [ - make_node('Shape', [data], [name+'_data_shape']), - make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), - # get W - make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), - make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), - make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), - make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), - make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), - make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), - # get R - make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), - make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']), - make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']), - make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0), - make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']), - # get B - make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']), - make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']), - make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', - name+'_B4', name+'_B5', name+'_B6', name+'_B7']), - make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', - name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), - make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']), - # get seq_len - make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), - make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), - # compute LSTM - make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], - [name+'0_', name+'1', name+'2'], hidden_size=state_size), - make_node('Squeeze', [name+'0_'], [name], axes=[1]), - ] - else: - raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1 or 2') + nodes += [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + + # Layer 0 + # get W + make_node('Slice', [param, name+'_0', name+'_3*state_size^2'], [name+'_W0_1d']), + make_node('Split', [name+'_W0_1d'], [name+'_W00', name+'_W01', name+'_W02']), + make_node('Concat', [name+'_W01', name+'_W00', name+'_W02'], [name+'_W0_'], axis=0), + make_node('Reshape', [name+'_W0_', name+'_WR_shape'], [name+'_W0']), + # get R + make_node('Add', [name+'_3*state_size^2', name+'_3*state_size^2'], [name+'_R0_offset']), + make_node('Slice', [param, name+'_3*state_size^2', name+'_R0_offset'], [name+'_R0_1d']), + make_node('Split', [name+'_R0_1d'], [name+'_R00', name+'_R01', name+'_R02']), + make_node('Concat', [name+'_R01', name+'_R00', name+'_R02'], [name+'_R0_'], axis=0), + make_node('Reshape', [name+'_R0_', name+'_WR_shape'], [name+'_R0']), + # get B + make_node('Add', [name+'_WR_offset', name+'_6*state_size'], [name+'_B0_offset']), + make_node('Slice', [param, name+'_WR_offset', name+'_B0_offset'], [name+'_B0_1d']), + make_node('Split', [name+'_B0_1d'], [name+'_B00', name+'_B01', name+'_B02', + name+'_B03', name+'_B04', name+'_B05']), + make_node('Concat', [name+'_B01', name+'_B00', name+'_B02', + name+'_B04', name+'_B03', name+'_B05'], [name+'_B0_'], axis=0), + make_node('Reshape', [name+'_B0_', name+'_B_shape'], [name+'_B0']), + # get initial states + make_node('Split', [initial_h], [name+'_initial_h0', name+'_initial_h1'], axis=0), + # get seq_len + make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), + make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), + # Layer 0 GRU + make_node('GRU', [data, name+'_W0', name+'_R0', name+'_B0', name+'_seq_len', + name+'_initial_h0'], + [name+'_gru0_out_', name+'_gru0_h'], hidden_size=state_size, linear_before_reset=1), + make_node('Squeeze', [name+'_gru0_out_'], [name+'_gru0_out'], axes=[1]), + + # Layer 1 + # get W + make_node('Add', [name+'_R0_offset', name+'_3*state_size^2'], [name+'_W1_offset']), + make_node('Slice', [param, name+'_R0_offset', name+'_W1_offset'], [name+'_W1_1d']), + make_node('Split', [name+'_W1_1d'], [name+'_W10', name+'_W11', name+'_W12']), + make_node('Concat', [name+'_W11', name+'_W10', name+'_W12'], [name+'_W1_'], axis=0), + make_node('Reshape', [name+'_W1_', name+'_WR_shape'], [name+'_W1']), + # get R + make_node('Slice', [param, name+'_W1_offset', name+'_WR_offset'], [name+'_R1_1d']), + make_node('Split', [name+'_R1_1d'], [name+'_R10', name+'_R11', name+'_R12']), + make_node('Concat', [name+'_R11', name+'_R10', name+'_R12'], [name+'_R1_'], axis=0), + make_node('Reshape', [name+'_R1_', name+'_WR_shape'], [name+'_R1']), + # get B + make_node('Add', [name+'_B0_offset', name+'_6*state_size'], [name+'_B1_offset']), + make_node('Slice', [param, name+'_B0_offset', name+'_B1_offset'], [name+'_B1_1d']), + make_node('Split', [name+'_B1_1d'], [name+'_B10', name+'_B11', name+'_B12', + name+'_B13', name+'_B14', name+'_B15']), + make_node('Concat', [name+'_B11', name+'_B10', name+'_B12', + name+'_B14', name+'_B13', name+'_B15'], [name+'_B1_'], axis=0), + make_node('Reshape', [name+'_B1_', name+'_B_shape'], [name+'_B1']), + # Layer 1 GRU + make_node('GRU', [name+'_gru0_out', name+'_W1', name+'_R1', name+'_B1', name+'_seq_len', + name+'_initial_h1'], + [name+'_gru1_out_', name+'_gru1_h'], hidden_size=state_size, linear_before_reset=1), + make_node('Squeeze', [name+'_gru1_out_'], [name], axes=[1]), + make_node('Concat', [name+'_gru0_h', name+'_gru1_h'], [name+'1'], axis=0) + ] + elif num_layers == 1: + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor([3*state_size], name+'_3*state_size', kwargs['initializer']) + create_tensor([6*state_size], name+'_6*state_size', kwargs['initializer']) + create_tensor([3*state_size*state_size], name+'_3*state_size^2', kwargs['initializer']) + create_tensor([1, 3*state_size, state_size], name+'_R_shape', kwargs['initializer']) + create_tensor([1, 6*state_size], name+'_B_shape', kwargs['initializer']) + + nodes += [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + # get W + make_node('Mul', [name+'_3*state_size', name+'_input_size'], [name+'_mul0']), + make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), + make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2']), + make_node('Concat', [name+'_W1', name+'_W0', name+'_W2'], [name+'_W_'], axis=0), + make_node('Concat', [name+'_1', name+'_3*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), + # get R + make_node('Add', [name+'_mul0', name+'_3*state_size^2'], [name+'_add0']), + make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']), + make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2']), + make_node('Concat', [name+'_R1', name+'_R0', name+'_R2'], [name+'_R_'], axis=0), + make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']), + # get B + make_node('Add', [name+'_add0', name+'_6*state_size'], [name+'_add1']), + make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']), + make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', + name+'_B3', name+'_B4', name+'_B5']), + make_node('Concat', [name+'_B1', name+'_B0', name+'_B2', + name+'_B4', name+'_B3', name+'_B5'], [name+'_B_'], axis=0), + make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']), + # get seq_len + make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), + make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), + # compute LSTM + make_node('GRU', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h], + [name+'0_', name+'1'], hidden_size=state_size, linear_before_reset=1), + make_node('Squeeze', [name+'0_'], [name], axes=[1]), + ] + else: + raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1') + + else: + raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode") return nodes @mx_op.register('_rnn_param_concat') diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 520ac407e49b..51170e671a43 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1221,27 +1221,34 @@ def test_onnx_export_sequence_reverse(tmp_path, dtype, params): # onnx LSTM from opset 11 does not support float64 +@pytest.mark.parametrize('mode', ['lstm', 'gru']) @pytest.mark.parametrize('dtype', ['float32']) -@pytest.mark.parametrize('state_size', [32, 40]) -@pytest.mark.parametrize('input_size', [32, 40, 64]) +@pytest.mark.parametrize('state_size', [16, 32]) +@pytest.mark.parametrize('input_size', [16, 32, 64]) @pytest.mark.parametrize('num_layers', [1, 2]) -@pytest.mark.parametrize('batch_size', [1, 3, 5]) +@pytest.mark.parametrize('batch_size', [1, 2, 4]) @pytest.mark.parametrize('seq_length', [16, 32]) -def test_onnx_export_RNN(tmp_path, dtype, state_size, input_size, num_layers, batch_size, seq_length): +def test_onnx_export_RNN(tmp_path, mode, dtype, state_size, input_size, num_layers, batch_size, seq_length): # TODO: The current implementation fails assertion checks for large parm/state_size. - + # for num_layers >= 2, input_size must equal to state_size if num_layers >= 2 and input_size != state_size: return - - M = def_model('RNN', mode='lstm', state_size=state_size, state_outputs=True, num_layers=num_layers, p=0) + factor = 3 + if mode == 'lstm': + factor = 4 + + M = def_model('RNN', mode=mode, state_size=state_size, state_outputs=True, num_layers=num_layers, p=0) x = mx.nd.random.normal(0, 10, (seq_length, batch_size, input_size), dtype=dtype) - param = mx.nd.random.normal(0, 1, [num_layers*4*state_size*input_size + - num_layers*4*state_size*state_size + - num_layers*8*state_size], dtype=dtype) + param = mx.nd.random.normal(0, 1, [num_layers*factor*state_size*input_size + + num_layers*factor*state_size*state_size + + num_layers*2*factor*state_size], dtype=dtype) state = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype) - cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype) - op_export_test('rnn', M, [x, param, state, cell], tmp_path) + if mode == 'lstm': + cell = mx.nd.random.uniform(-1, 1, [num_layers, batch_size, state_size], dtype=dtype) + op_export_test('rnn', M, [x, param, state, cell], tmp_path) + else: + op_export_test('rnn', M, [x, param, state], tmp_path) @pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])