@SangwonSUH SangwonSUH commented Oct 16, 2025

What does this PR do ?

This PR adds type hints to mask_padded_tokens function to enable JIT scripting.

Collection: [common]

Changelog

  • Adding type hints has no direct effect at Python runtime.
    Therefore, introduced a new test case, test_transformer_utils.py, to reproduce the scenario where JIT script conversion previously failed.

    Initial JIT error details:
    tests/collections/common/test_transformer_utils.py::TestMaskPaddedTokens::test_mask_padded_tokens_jit_script_compilation FAILED
    
    =========================================== FAILURES ============================================
    ______________ TestMaskPaddedTokens.test_mask_padded_tokens_jit_script_compilation ______________
    
    self = <tests.collections.common.test_transformer_utils.TestMaskPaddedTokens object at 0x14939b4a0>
    
        @pytest.mark.unit
        def test_mask_padded_tokens_jit_script_compilation(self):
            """Test that mask_padded_tokens works correctly with TorchScript compilation.
        
            This test ensures type hints are properly defined.
            """
        
            class SimpleModule(nn.Module):
                """Module wrapper for testing mask_padded_tokens with TorchScript."""
        
                def __init__(self, pad_id: int = 0):
                    super().__init__()
                    self.pad = pad_id
        
                def forward(self, tokens: torch.Tensor) -> torch.Tensor:
                    mask = mask_padded_tokens(tokens, self.pad)
                    return mask.float()
        
            module = SimpleModule(pad_id=0)
    >       scripted_module = torch.jit.script(module)
                              ^^^^^^^^^^^^^^^^^^^^^^^^
    
    tests/collections/common/test_transformer_utils.py:44: 
    _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
    ../../miniconda3/envs/nemo/lib/python3.12/site-packages/torch/jit/_script.py:1443: in script
        ret = _script_impl(
    ../../miniconda3/envs/nemo/lib/python3.12/site-packages/torch/jit/_script.py:1152: in _script_impl
        return torch.jit._recursive.create_script_module(
    ../../miniconda3/envs/nemo/lib/python3.12/site-packages/torch/jit/_recursive.py:556: in create_script_module
        return create_script_module_impl(nn_module, concrete_type, stubs_fn)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ../../miniconda3/envs/nemo/lib/python3.12/site-packages/torch/jit/_recursive.py:629: in create_script_module_impl
        create_methods_and_properties_from_stubs(
    _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
    
    concrete_type = <torch.ConcreteModuleType object at 0x1493c84b0>
    method_stubs = [ScriptMethodStub(resolution_callback=<function createResolutionCallbackFromEnv.<locals>.<lambda> at 0x149ac8680>, def... TestMaskPaddedTokens.test_mask_padded_tokens_jit_script_compilation.<locals>.SimpleModule.forward of SimpleModule()>)]
    property_stubs = []
    
        def create_methods_and_properties_from_stubs(
            concrete_type, method_stubs, property_stubs
        ):
            method_defs = [m.def_ for m in method_stubs]
            method_rcbs = [m.resolution_callback for m in method_stubs]
            method_defaults = [get_default_args(m.original_method) for m in method_stubs]
        
            property_defs = [p.def_ for p in property_stubs]
            property_rcbs = [p.resolution_callback for p in property_stubs]
        
    >       concrete_type._create_methods_and_properties(
                property_defs, property_rcbs, method_defs, method_rcbs, method_defaults
            )
    E       RuntimeError: 
    E       
    E       mask_padded_tokens(Tensor tokens, Tensor pad_id) -> Tensor:
    E       Expected a value of type 'Tensor (inferred)' for argument 'pad_id' but instead found type 'int'.
    E       Inferred 'pad_id' to be of type 'Tensor' because it was not annotated with an explicit type.
    E       :
    E         File "/Users/Sangwon/git/NeMo/tests/collections/common/test_transformer_utils.py", line 40
    E                   def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    E                       mask = mask_padded_tokens(tokens, self.pad)
    E                              ~~~~~~~~~~~~~~~~~~ <--- HERE
    E                       return mask.float()
    
    ../../miniconda3/envs/nemo/lib/python3.12/site-packages/torch/jit/_recursive.py:465: RuntimeError
  • Added explicit type hints to mask_padded_tokens to resolve the JIT conversion error.
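For reference, a minimal sketch of what the annotated signature looks like and why it fixes the error above. The element-wise comparison body is an assumption for illustration; the real implementation lives in NeMo's common transformer utils:

```python
import torch


def mask_padded_tokens(tokens: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Return a boolean mask that is True at non-padding positions.

    With pad_id annotated as int, TorchScript no longer infers it as
    Tensor, so torch.jit.script can compile callers that pass a plain
    Python int (such as self.pad in the test module above).
    """
    mask = tokens != pad_id
    return mask


# The annotated function now scripts cleanly:
scripted = torch.jit.script(mask_padded_tokens)
```

Before the change, TorchScript inferred `pad_id` as `Tensor` (its default for unannotated parameters), which is exactly the `Expected a value of type 'Tensor (inferred)'` error reproduced by the new test.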

Usage

# how to run test
pytest tests/collections/common/test_transformer_utils.py

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open a "Draft" PR.

Who can review?

mask_padded_tokens is used in transformer_generators.py in both the ASR and NLP collections.
Requesting review from ASR and NLP members: @titu1994, @redoctopus, @jbalam-nv, @okuchaiev, @MaximumEntropy, @ericharper, @ekmb, @yzhang123, @VahidooX, or @vladgets

Additional Information

  • This change would help enable a larger effort to JIT-script the GreedySequenceGenerator from nemo.collections.asr.modules.transformer.transformer_generators.
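The failure mode the new test guards against can be reproduced with any unannotated helper: TorchScript infers unannotated parameters as Tensor, so a scripted function rejects a plain Python int. A minimal sketch (not NeMo code) mirroring the error in the traceback above:

```python
import torch


def unannotated_mask(tokens, pad_id):
    # Without annotations, TorchScript infers pad_id as Tensor.
    return tokens != pad_id


scripted = torch.jit.script(unannotated_mask)

try:
    # Passing a Python int where a Tensor is expected fails, mirroring
    # the "Expected a value of type 'Tensor (inferred)'" error.
    scripted(torch.tensor([1, 0]), 0)
    raised = False
except RuntimeError:
    raised = True
```

Calling the same scripted function with a 0-dim tensor (e.g. `torch.tensor(0)`) succeeds, which is why the missing annotation went unnoticed until a module passed an int attribute.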

@SangwonSUH SangwonSUH changed the title Add type hints to mask_padded_tokens Enable JIT Scripting for mask_padded_tokens Oct 16, 2025
@SangwonSUH SangwonSUH marked this pull request as ready for review October 16, 2025 16:14
@SangwonSUH SangwonSUH marked this pull request as draft October 17, 2025 00:32
@SangwonSUH SangwonSUH force-pushed the add_typehint_mask_padded_tokens branch from f27eff4 to ea0f130 Compare October 17, 2025 11:19
@SangwonSUH SangwonSUH marked this pull request as ready for review October 17, 2025 11:20
@SangwonSUH SangwonSUH changed the title Enable JIT Scripting for mask_padded_tokens Add type hints to mask_padded_tokens Oct 18, 2025