Conversation

@ombrdr47 (Contributor) commented Nov 2, 2025

Implement aten_lstm and aten_gru operators to enable torch.onnx.export for PyTorch LSTM and GRU layers. This addresses issue #2546.

Key features:

  • Full support for multi-layer RNNs (num_layers > 1)
  • Bidirectional support (forward and backward directions)
  • Handles both biased and non-biased configurations
  • batch_first parameter support with automatic transposition
  • Dropout support between layers (nondeterministic seeded)
  • Proper gate reordering for ONNX compatibility:
    • LSTM: PyTorch [i,f,g,o] -> ONNX [i,o,f,g]
    • GRU: PyTorch [r,z,n] -> ONNX [z,r,n]

Implementation details:

  • Uses ONNX LSTM/GRU operators with proper parameter formatting
  • Handles weight matrix transposition and reshaping
  • Correctly concatenates biases using op.Concat
  • Processes each layer independently with proper state management
  • Returns outputs in PyTorch-compatible format
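
A minimal sketch of the gate reordering described above (illustrative only, not the PR's actual code; the weight layouts follow the documented torch.nn.LSTM/GRU packing):

```python
# Hedged sketch: reorder PyTorch gate blocks into ONNX order.
# torch.nn.LSTM packs weight_ih_l0 as (4*hidden_size, input_size) in [i, f, g, o] order;
# ONNX LSTM expects [i, o, f, c], where c corresponds to PyTorch's g (cell) gate.
import torch

def reorder_lstm_gates(w: torch.Tensor, hidden_size: int) -> torch.Tensor:
    i, f, g, o = torch.split(w, hidden_size, dim=0)  # PyTorch order [i, f, g, o]
    return torch.cat([i, o, f, g], dim=0)            # ONNX order   [i, o, f, c]

def reorder_gru_gates(w: torch.Tensor, hidden_size: int) -> torch.Tensor:
    r, z, n = torch.split(w, hidden_size, dim=0)     # PyTorch order [r, z, n]
    return torch.cat([z, r, n], dim=0)               # ONNX order   [z, r, n]

# Example: single-layer LSTM, hidden_size=3, input_size=2
w_ih = torch.nn.LSTM(input_size=2, hidden_size=3).weight_ih_l0  # shape (12, 2)
print(reorder_lstm_gates(w_ih, 3).shape)  # torch.Size([12, 2])
```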

Closes: #2546

@ombrdr47 (Contributor, Author) commented Nov 2, 2025

@microsoft-github-policy-service agree

@justinchuby (Collaborator) commented:

Thanks for your contribution! Could you help create some tests in https://github.com/microsoft/onnxscript/blob/main/tests/function_libs/torch_lib/e2e_ops_tests.py?

Update to aten::lstm.input and aten::gru.input

Move aten_gru function to appear after aten_ger for alphabetical ordering with other aten_g* functions. Also add hidden_size attribute computation to GRU for consistency with LSTM implementation.
@ombrdr47 (Contributor, Author) commented Nov 3, 2025

Thanks for your contribution! Could you help create some tests in https://github.com/microsoft/onnxscript/blob/main/tests/function_libs/torch_lib/e2e_ops_tests.py?

Hi @justinchuby, I've addressed all comments:
1. Changed the decorators to use the full operator names (aten::lstm.input and aten::gru.input).
2. Moved aten_gru to its alphabetical location with the other aten_g* functions.

All changes are pushed (3 commits total).

Regarding tests:
I attempted to add comprehensive tests to e2e_ops_tests.py covering unidirectional, bidirectional, and multi-layer configurations for both LSTM and GRU. However, I encountered an issue with the hidden_size attribute that I'd like your guidance on.
While the ONNX spec marks hidden_size as optional (required=False), ONNX Runtime requires it and fails with:
info.GetAttr('hidden_size', &int64_value).IsOK() && int64_value > 0 was false

In trace_only mode, tensor shapes return SymbolicValues rather than concrete Python integers. When I compute hidden_size from the weight tensor shape:

W_shape = op.Shape(W)
hidden_size_times_4 = op.Slice(W_shape, [1], [2], axes=[0])
hidden_size_attr = op.Div(hidden_size_times_4, op.Constant(value_ints=[4]))

The result is a tensor (SymbolicValue), not the static Python integer that an ONNX attribute requires.
I'm not sure whether there is a way to compute static attributes in trace mode, or whether I should use a different pattern. Could you provide guidance on the correct approach? I'm happy to add the tests once I understand how to handle this!

Thank you!

@justinchuby (Collaborator) commented:

Thanks!

Can the hidden_size be obtained from any of the function inputs? Is it static? If so you may do shape = some_input.shape, which will give you static values to create an attribute.
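
A minimal sketch of the suggested pattern (illustrative, not the PR's actual code; assumes a trace-only torch_lib-style function where op is the ONNX opset object):

```python
# Hedged sketch of a trace-only function body.  Reading .shape on a traced
# input yields static Python ints when the dimension is known, so the value
# can feed an ONNX attribute directly instead of building Shape/Slice/Div nodes.
def lstm_layer_sketch(op, X, W, R, B, initial_h, initial_c):
    hidden_size = initial_h.shape[2]   # plain int (e.g. 8), not a SymbolicValue
    Y, Y_h, Y_c = op.LSTM(
        X, W, R, B,
        initial_h=initial_h,
        initial_c=initial_c,
        hidden_size=hidden_size,       # attribute now receives a concrete value
        direction="forward",
    )
    return Y, Y_h, Y_c
```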

@justinchuby added the "module: torchlib" label (Related to the torch/aten function lib in development) on Nov 3, 2025
@github-advanced-security bot left a comment:

lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@codecov bot commented Nov 3, 2025

Codecov Report

❌ Patch coverage is 2.77778% with 140 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.11%. Comparing base (93783ee) to head (1f5e764).
⚠️ Report is 1 commit behind head on main.

Files with missing lines: onnxscript/function_libs/torch_lib/ops/core.py (patch 2.77%, 140 lines missing ⚠️)
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2674      +/-   ##
==========================================
- Coverage   70.46%   70.11%   -0.36%     
==========================================
  Files         224      224              
  Lines       26812    26956     +144     
  Branches     2686     2700      +14     
==========================================
+ Hits        18893    18899       +6     
- Misses       6987     7126     +139     
+ Partials      932      931       -1     


Use initial_h.shape[2] for LSTM and hx.shape[2] for GRU to get static
hidden_size values instead of computing from weight matrices. This allows
the attribute to be a Python integer rather than a SymbolicValue.

Add comprehensive tests for LSTM operator covering:
- Unidirectional single-layer
- Bidirectional single-layer
- Multi-layer (3 layers)

All tests pass successfully with the fixed hidden_size attribute.
@ombrdr47 (Contributor, Author) commented Nov 3, 2025

Thanks!

Can the hidden_size be obtained from any of the function inputs? Is it static? If so you may do shape = some_input.shape, which will give you static values to create an attribute.

Thanks! That worked. I've updated the code to use initial_h.shape[2] for LSTM and hx.shape[2] for GRU.

I added 3 LSTM tests and they all pass:

  • test_lstm_unidirectional
  • test_lstm_bidirectional
  • test_lstm_multilayer
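
A hedged sketch of roughly what one of these tests looks like (the exact harness in e2e_ops_tests.py, such as an assert_onnx_program helper on the test base class, is assumed rather than quoted):

```python
import torch

# Sketch of a test method meant to live on the e2e test class.
def test_lstm_unidirectional(self):
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = torch.nn.LSTM(input_size=4, hidden_size=8, num_layers=1)

        def forward(self, x, h0, c0):
            return self.lstm(x, (h0, c0))

    x = torch.randn(6, 2, 4)   # (seq_len, batch, input_size)
    h0 = torch.zeros(1, 2, 8)  # (num_layers * num_directions, batch, hidden_size)
    c0 = torch.zeros(1, 2, 8)
    onnx_program = torch.onnx.export(Model(), (x, h0, c0), dynamo=True)
    self.assert_onnx_program(onnx_program)  # compares ONNX Runtime outputs with PyTorch
```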

However, I'm having issues with GRU. The tests are failing with numerical accuracy errors - the outputs don't match PyTorch. I've double-checked the gate reordering (PyTorch [r,z,n] to ONNX [z,r,n]) and it looks correct, but something is causing the results to be different. Could you help me figure out what's wrong with the GRU implementation?
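
(For context only, and not confirmed as the cause of the mismatch here: beyond gate order, PyTorch's GRU and the ONNX GRU default differ in where the reset gate is applied in the candidate activation. ONNX exposes this via the linear_before_reset attribute; PyTorch's formulation corresponds to linear_before_reset=1, while the ONNX default is 0.)

```latex
\text{PyTorch / ONNX with } \texttt{linear\_before\_reset}=1:\quad
  n_t = \tanh\big(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\big)

\text{ONNX default } (\texttt{linear\_before\_reset}=0):\quad
  \tilde{h}_t = \tanh\big(W_{in} x_t + b_{in} + W_{hn}\,(r_t \odot h_{t-1}) + b_{hn}\big)
```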

All changes are pushed. Thanks for your help!

@justinchuby (Collaborator) commented:

Feel free to commit the tests and I can take a look.

Add comprehensive tests for GRU operator covering:
- Unidirectional single-layer
- Bidirectional single-layer
- Multi-layer (3 layers)

Note: GRU tests currently fail with numerical accuracy issues.
@ombrdr47 (Contributor, Author) commented Nov 3, 2025

Feel free to commit the tests and I can take a look.

Thanks! I've added all 6 tests and pushed them. The LSTM tests all pass, but the GRU tests are failing. The outputs don't match PyTorch within tolerance even though the gate reordering looks correct.
All changes are committed. Thank you!

@justinchuby (Collaborator) commented Nov 3, 2025

@ombrdr47 the tests seem ok according to the CI? You may want to fix the lint errors

@titaiwangms (Contributor) left a review comment:

Suggested to add dynamic tests (or even just promote all tests to be dynamic). Other than that, the PR looks good.

@ombrdr47 (Contributor, Author) commented Nov 4, 2025

Suggested to add dynamic tests (or even just promote all tests to be dynamic). Other than that, the PR looks good.

Hi @titaiwangms,
I attempted to add dynamic_shapes to the tests, but I got export errors:

ConstraintViolationError: Constraints violated (L['x'].size()[1])! You marked L['x'].size()[1] as dynamic but the code specialized it to be constant.

I tested both approaches:

  • dynamic_shapes=({0: "batch", 1: "seq_len"}) - export fails at the torch.export step
  • dynamic_shapes=({0: "batch"},) - export fails at the ONNX conversion step

It looks like the LSTM/GRU layers force these dimensions to be constant during export, even when they are marked as dynamic. Do you have any suggestions on how to work around this?
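
A minimal sketch of the dynamic-shapes attempt (illustrative, not the exact test code from e2e_ops_tests.py; the Dim objects mark input dimensions as dynamic for torch.export / torch.onnx.export):

```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=4, hidden_size=8)

    def forward(self, x):
        return self.lstm(x)

x = torch.randn(6, 2, 4)  # (seq_len, batch, input_size), batch_first=False
dim0 = torch.export.Dim("dim0")
dim1 = torch.export.Dim("dim1")
# Export with both leading dimensions marked dynamic; this is the variant
# that raised ConstraintViolationError in the discussion above.
onnx_program = torch.onnx.export(
    Model(), (x,), dynamic_shapes=({0: dim0, 1: dim1},), dynamo=True
)
```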

@justinchuby (Collaborator) commented Nov 4, 2025

I think it's ok to skip dynamic shapes testing for now if the error comes from torch.export, thanks!

@ombrdr47 (Contributor, Author) commented Nov 4, 2025

I think it's ok to skip dynamic shapes testing for now if the error comes from torch.export, thanks!

Thanks @justinchuby, I've also fixed the lint errors; let me know if you need anything else.

@titaiwangms (Contributor) commented:

I think it's ok to skip dynamic shapes testing for now if the error comes from torch.export, thanks!

I haven't seen this error message for a while. Is the CI using an older torch (before oblivious backed size)?

@titaiwangms (Contributor) left a review comment:

We can merge for now. Thank you!

@justinchuby (Collaborator) commented:

I think it's ok to skip dynamic shapes testing for now if the error comes from torch.export, thanks!

Have not seen this error message for a while. Is the CI using old torch? (before oblivious backed size?)

Right, we do test with 2.7.

@justinchuby enabled auto-merge (squash) on November 4, 2025, 18:21
@justinchuby (Collaborator) commented:

@ombrdr47 Thank you very much!

@ombrdr47 (Contributor, Author) commented Nov 4, 2025

@ombrdr47 Thank you very much!

Thank you @justinchuby and @titaiwangms for your guidance. Happy to be a contributor to the project!

@justinchuby merged commit 1a27df1 into microsoft:main on Nov 4, 2025
32 checks passed
The github-project-automation bot moved this from Todo to Done in ONNX Script Review Board on Nov 4, 2025

Labels

module: torchlib (Related to the torch/aten function lib in development)


Successfully merging this pull request may close these issues.

[torchlib] Implement RNN operators

3 participants