diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index be30520878..96f64bbb8a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3761,6 +3761,192 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("aten::gru.input", trace_only=True) +def aten_gru( + input: TFloat, + hx: TFloat, + params: Sequence[TFloat], + has_biases: bool, + num_layers: int, + dropout: float, + train: bool, + bidirectional: bool, + batch_first: bool, +) -> tuple[TFloat, TFloat]: + """gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)""" + + # Determine number of directions + num_directions = 2 if bidirectional else 1 + + # Get dimensions + if batch_first: + # Convert from [batch, seq, input_size] to [seq, batch, input_size] + input = op.Transpose(input, perm=[1, 0, 2]) + + hidden_size = op.Shape(hx, start=2, end=3) + + # Process each layer + current_input = input + output_h_list = [] + + for layer_idx in range(num_layers): + # Extract hidden state for this layer + layer_start = layer_idx * num_directions + layer_end = (layer_idx + 1) * num_directions + layer_h = op.Slice(hx, layer_start, layer_end, axes=[0]) + + # Extract parameters for this layer + # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction + params_per_direction = 4 if has_biases else 2 + params_per_layer = params_per_direction * num_directions + param_start_idx = layer_idx * params_per_layer + + # Build weight matrices for ONNX GRU + # ONNX expects: W[zrh] shape [num_directions, 3*hidden_size, input_size] + # PyTorch provides: W_ih shape [3*hidden_size, input_size] + W_list = [] + R_list = [] + B_list = [] if has_biases else None + + for dir_idx in range(num_directions): + dir_param_start = param_start_idx + dir_idx * params_per_direction + W_ih = params[ + dir_param_start + ] # [3*hidden_size, input_size] - PyTorch order: [r,z,n] + W_hh = params[ + dir_param_start + 1 + ] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n] + + # Reorder gates from PyTorch [r,z,n] to ONNX [z,r,n] + # Split into individual gates + W_ir = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) + W_iz = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_in = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + W_hr = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) + W_hz = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_hn = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + # Reorder: [z,r,n] + W_ih_reordered = op.Concat( + W_iz, W_ir, W_in, axis=0 + ) # [3*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat( + W_hz, W_hr, W_hn, axis=0 + ) # [3*hidden_size, hidden_size] - ONNX order + + # Add direction dimension + W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 3*hidden_size, input_size] + W_hh_expanded = op.Unsqueeze( + W_hh_reordered, [0] + ) # [1, 3*hidden_size, hidden_size] + + W_list.append(W_ih_expanded) + R_list.append(W_hh_expanded) + + if has_biases: + b_ih = params[dir_param_start + 2] # [3*hidden_size] - PyTorch order: [r,z,n] + b_hh = params[dir_param_start + 3] # [3*hidden_size] - PyTorch order: [r,z,n] + + # Reorder biases from PyTorch [r,z,n] to ONNX [z,r,n] + b_ir = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) + b_iz = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_in = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + b_hr = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) + b_hz = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_hn = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + # Reorder: [z,r,n] + b_ih_reordered = op.Concat( + b_iz, b_ir, b_in, axis=0 + ) # [3*hidden_size] - ONNX order + b_hh_reordered = op.Concat( + b_hz, b_hr, b_hn, axis=0 + ) # [3*hidden_size] - ONNX order + + # ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]] + b_combined = op.Concat( + b_ih_reordered, b_hh_reordered, axis=0 + ) # [6*hidden_size] + b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 6*hidden_size] + B_list.append(b_expanded) + + # Concatenate weights for all directions + W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] + R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] + B = ( + op.Concat(*B_list, axis=0) + if has_biases and len(B_list) > 1 + else (B_list[0] if has_biases else None) + ) + + # Call ONNX GRU operator + direction = "bidirectional" if bidirectional else "forward" + + # Extract hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size] + hidden_size_attr = hx.shape[2] + + if B is not None: + Y, Y_h = op.GRU( + current_input, + W, + R, + B, + initial_h=layer_h, + direction=direction, + hidden_size=hidden_size_attr, + ) + else: + Y, Y_h = op.GRU( + current_input, + W, + R, + initial_h=layer_h, + direction=direction, + hidden_size=hidden_size_attr, + ) + + # Y shape: [seq_length, num_directions, batch_size, hidden_size] + # Reshape to [seq_length, batch_size, num_directions * hidden_size] + Y = op.Transpose( + Y, perm=[0, 2, 1, 3] + ) # [seq_length, batch_size, num_directions, hidden_size] + Y_shape = op.Shape(Y) + new_shape = op.Concat( + op.Slice(Y_shape, [0], [1]), # seq_length + op.Slice(Y_shape, [1], [2]), # batch_size + op.Reshape( + op.Mul( + op.Slice(Y_shape, [2], [3]), # num_directions + op.Slice(Y_shape, [3], [4]), # hidden_size + ), + op.Constant(value_ints=[-1]), + ), + axis=0, + ) + current_input = op.Reshape(Y, new_shape) + + # Apply dropout if not last layer and dropout > 0 + if layer_idx < num_layers - 1 and dropout > 0.0 and train: + current_input, _ = op.Dropout(current_input, dropout, train) + + # Store final hidden state + output_h_list.append(Y_h) + + # Concatenate all layer outputs + final_h = ( + output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + ) + + # Handle batch_first for output + if batch_first: + # Convert from [seq, batch, features] to [batch, seq, features] + current_input = op.Transpose(current_input, perm=[1, 0, 2]) + + return current_input, final_h + + @torch_op(("_operator::getitem", "aten::getitem")) def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor: return op.SequenceAt(self, i) @@ -4991,6 +5177,212 @@ def aten_lstm_mps_backward( raise NotImplementedError() +@torch_op("aten::lstm.input", trace_only=True) +def aten_lstm( + input: TFloat, + hx: Sequence[TFloat], + params: Sequence[TFloat], + has_biases: bool, + num_layers: int, + dropout: float, + train: bool, + bidirectional: bool, + batch_first: bool, +) -> tuple[TFloat, TFloat, TFloat]: + """lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)""" + + # Extract initial hidden and cell states + initial_h = hx[0] # Shape: [num_directions * num_layers, batch_size, hidden_size] + initial_c = hx[1] # Shape: [num_directions * num_layers, batch_size, hidden_size] + + # Determine number of directions + num_directions = 2 if bidirectional else 1 + + # Get dimensions + if batch_first: + # Convert from [batch, seq, input_size] to [seq, batch, input_size] + input = op.Transpose(input, perm=[1, 0, 2]) + + hidden_size = op.Shape(initial_h, start=2, end=3) + + # Process each layer + current_input = input + output_h_list = [] + output_c_list = [] + + for layer_idx in range(num_layers): + # Extract hidden and cell states for this layer + layer_start = layer_idx * num_directions + layer_end = (layer_idx + 1) * num_directions + layer_h = op.Slice(initial_h, layer_start, layer_end, axes=[0]) + layer_c = op.Slice(initial_c, layer_start, layer_end, axes=[0]) + + # Extract parameters for this layer + # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction + params_per_direction = 4 if has_biases else 2 + params_per_layer = params_per_direction * num_directions + param_start_idx = layer_idx * params_per_layer + + # Build weight matrices for ONNX LSTM + # ONNX expects: W[iofc] shape [num_directions, 4*hidden_size, input_size] + # PyTorch provides: W_ih shape [4*hidden_size, input_size] + W_list = [] + R_list = [] + B_list = [] if has_biases else None + + for dir_idx in range(num_directions): + dir_param_start = param_start_idx + dir_idx * params_per_direction + W_ih = params[ + dir_param_start + ] # [4*hidden_size, input_size] - PyTorch order: [i,f,g,o] + W_hh = params[ + dir_param_start + 1 + ] # [4*hidden_size, hidden_size] - PyTorch order: [i,f,g,o] + + # Reorder gates from PyTorch [i,f,g,o] to ONNX [i,o,f,g] + # Split into individual gates + W_ii = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) + W_if = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_ig = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + W_io = op.Slice(W_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + W_hi = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) + W_hf = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_hg = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + W_ho = op.Slice(W_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + # Reorder: [i,o,f,g] + W_ih_reordered = op.Concat( + W_ii, W_io, W_if, W_ig, axis=0 + ) # [4*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat( + W_hi, W_ho, W_hf, W_hg, axis=0 + ) # [4*hidden_size, hidden_size] - ONNX order + + # Add direction dimension + W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 4*hidden_size, input_size] + W_hh_expanded = op.Unsqueeze( + W_hh_reordered, [0] + ) # [1, 4*hidden_size, hidden_size] + + W_list.append(W_ih_expanded) + R_list.append(W_hh_expanded) + + if has_biases: + b_ih = params[ + dir_param_start + 2 + ] # [4*hidden_size] - PyTorch order: [i,f,g,o] + b_hh = params[ + dir_param_start + 3 + ] # [4*hidden_size] - PyTorch order: [i,f,g,o] + + # Reorder biases from PyTorch [i,f,g,o] to ONNX [i,o,f,g] + b_ii = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) + b_if = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_ig = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + b_io = op.Slice(b_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + b_hi = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) + b_hf = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_hg = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + b_ho = op.Slice(b_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + # Reorder: [i,o,f,g] + b_ih_reordered = op.Concat( + b_ii, b_io, b_if, b_ig, axis=0 + ) # [4*hidden_size] - ONNX order + b_hh_reordered = op.Concat( + b_hi, b_ho, b_hf, b_hg, axis=0 + ) # [4*hidden_size] - ONNX order + + # ONNX expects biases concatenated: [Wb[iofc], Rb[iofc]] + b_combined = op.Concat( + b_ih_reordered, b_hh_reordered, axis=0 + ) # [8*hidden_size] + b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 8*hidden_size] + B_list.append(b_expanded) + + # Concatenate weights for all directions + W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] + R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] + B = ( + op.Concat(*B_list, axis=0) + if has_biases and len(B_list) > 1 + else (B_list[0] if has_biases else None) + ) + + # Call ONNX LSTM operator + direction = "bidirectional" if bidirectional else "forward" + + # Extract hidden_size from initial_h shape: [num_layers * num_directions, batch, hidden_size] + hidden_size_attr = initial_h.shape[2] + + if B is not None: + Y, Y_h, Y_c = op.LSTM( + current_input, + W, + R, + B, + initial_h=layer_h, + initial_c=layer_c, + direction=direction, + hidden_size=hidden_size_attr, + ) + else: + Y, Y_h, Y_c = op.LSTM( + current_input, + W, + R, + initial_h=layer_h, + initial_c=layer_c, + direction=direction, + hidden_size=hidden_size_attr, + ) + + # Y shape: [seq_length, num_directions, batch_size, hidden_size] + # Reshape to [seq_length, batch_size, num_directions * hidden_size] + Y = op.Transpose( + Y, perm=[0, 2, 1, 3] + ) # [seq_length, batch_size, num_directions, hidden_size] + Y_shape = op.Shape(Y) + new_shape = op.Concat( + op.Slice(Y_shape, [0], [1]), # seq_length + op.Slice(Y_shape, [1], [2]), # batch_size + op.Reshape( + op.Mul( + op.Slice(Y_shape, [2], [3]), # num_directions + op.Slice(Y_shape, [3], [4]), # hidden_size + ), + op.Constant(value_ints=[-1]), + ), + axis=0, + ) + current_input = op.Reshape(Y, new_shape) + + # Apply dropout if not last layer and dropout > 0 + if layer_idx < num_layers - 1 and dropout > 0.0 and train: + current_input, _ = op.Dropout(current_input, dropout, train) + + # Store final hidden and cell states + output_h_list.append(Y_h) + output_c_list.append(Y_c) + + # Concatenate all layer outputs + final_h = ( + output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + ) + final_c = ( + output_c_list[0] if len(output_c_list) == 1 else op.Concat(*output_c_list, axis=0) + ) + + # Handle batch_first for output + if batch_first: + # Convert from [seq, batch, features] to [batch, seq, features] + current_input = op.Transpose(current_input, perm=[1, 0, 2]) + + return current_input, final_h, final_c + + @torch_op( ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), trace_only=True, diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 24ccaf4b40..f74dda699d 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -302,6 +302,110 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_lstm_unidirectional(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size=10, hidden_size=20, num_layers=1, batch_first=True + ) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_lstm_bidirectional(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size=10, + hidden_size=20, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_lstm_multilayer(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size=10, hidden_size=20, num_layers=3, batch_first=True + ) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_unidirectional(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU( + input_size=10, hidden_size=20, num_layers=1, batch_first=True + ) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_bidirectional(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU( + input_size=10, + hidden_size=20, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_multilayer(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU( + input_size=10, hidden_size=20, num_layers=3, batch_first=True + ) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main()