Litgpt benchmark #4320

Merged: jjsjann123 merged 13 commits into main from litgpt_benchmark on Apr 30, 2025

Conversation

@jjsjann123 (Collaborator)

Fixes #4253

github-actions bot commented Apr 25, 2025

Review updated until commit 183b4f9

Description

  • Added LitGPT benchmark configurations

  • Implemented LitGPT model setup in rope_ops.py

  • Updated conftest.py with new resize marker

  • Fixed import in cross_entropy_loss.py


Changes walkthrough 📝

Relevant files

Enhancement

  • conftest.py: Add resize marker (benchmarks/python/conftest.py)
      Added resize marker to pytest configuration (+4/-0)
  • model_configs.py: Add LitGPT configuration (benchmarks/python/model_configs.py)
      Added litgpt_cfg function to load LitGPT configurations; moved AutoConfig import inside functions (+16/-2)
  • rope_ops.py: Add LitGPT setup in rope_ops (benchmarks/python/rope_ops.py)
      Implemented Litgpt function for LitGPT model setup; added LitGPT configurations to rope_setup (+128/-0)
  • test_rope.py: Update test_rope with LitGPT (benchmarks/python/test_rope.py)
      Added LitGPT variations to test parameters; marked LitGPT tests with resize marker (+10/-0)

Bug fix

  • cross_entropy_loss.py: Fix import in cross_entropy_loss (benchmarks/python/cross_entropy_loss.py)
      Corrected import to transformers.models.mistral instead of phi3 (+1/-1)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Import Error

The import statement for MistralPreTrainedModel is incorrect. It should be imported from the correct module.

from transformers.models.mistral import MistralPreTrainedModel

Code Complexity

The Litgpt function is quite long and complex. Consider breaking it down into smaller, more manageable functions.

    def Litgpt(seq_length, model_name):
        class LitgptRope(torch.nn.Module):
            def __init__(self, config) -> None:
                from litgpt.model import apply_rope
    
                self.fused_apply_rotary_pos_emb_cached = None
    
                super().__init__()
                self.config = config
                self.apply_rope = apply_rope
    
            def forward(
                self,
                qkv: torch.Tensor,
                cos: torch.Tensor,
                sin: torch.Tensor,
            ) -> torch.Tensor:
                B, T, _ = qkv.shape  # batch size, sequence length
    
                # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
                q_per_kv = self.config.n_head // self.config.n_query_groups
                total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
                qkv = qkv.view(
                    B, T, self.config.n_query_groups, total_qkv, self.config.head_size
                )
                qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)
    
                # split batched computation into three
                q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
    
                # maybe repeat k and v if for the non multi-head attention cases
                # training: flash attention requires it
                # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
                if (
                    self.config.n_query_groups != self.config.n_head
                    and self.config.n_query_groups != 1
                ):
                    k = k.expand(
                        B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
                    )
                    v = v.expand(
                        B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
                    )
    
                q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
                k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
                v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)
    
                q_roped = self.apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
                k_roped = self.apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
                q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
                k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
                return q, k, v
    
        cfg = configs["litgpt"](model_name)
        # overwrite seq_length
        cfg.seq_len = seq_length
    
        def inputs():
            qkv = torch.randn(
                cfg.batch_size,
                cfg.seq_len,
                (cfg.n_head + 2 * cfg.n_query_groups) * cfg.head_size,
                device="cuda",
                dtype=torch.bfloat16,
                requires_grad=True,
            )
            cos = torch.randn(
                1,
                cfg.seq_len,
                cfg.rope_n_elem,
                device="cuda",
                dtype=torch.bfloat16,
                requires_grad=False,
            )
            sin = torch.randn(
                1,
                cfg.seq_len,
                cfg.rope_n_elem,
                device="cuda",
                dtype=torch.bfloat16,
                requires_grad=False,
            )
            return qkv, cos, sin
    
        def grads():
            grad = torch.randn(
                cfg.batch_size,
                cfg.n_head,
                cfg.seq_len,
                cfg.head_size,
                device="cuda",
                dtype=torch.bfloat16,
                requires_grad=False,
            )
            return grad
    
        # Manual IOBytes computes the total bandwidth for thunder backward trace.
        def iobytes():
            n_elements = 0
            # adding size of qkv.grad
            n_elements += (
                cfg.batch_size
                * cfg.seq_len
                * (cfg.n_head + 2 * cfg.n_query_groups)
                * cfg.head_size
            )
            # adding size of sin, cos (saved from forward)
            n_elements += 2 * cfg.seq_len * cfg.rope_n_elem
            # adding size of q, k, v (saved from forward)
            n_elements += 3 * cfg.batch_size * cfg.seq_len * cfg.n_head * cfg.head_size
        # total io sizes
            return n_elements * torch.bfloat16.itemsize
    
        return LitgptRope(cfg).cuda().bfloat16(), inputs, grads, iobytes
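As a sanity check on the IOBytes arithmetic in `iobytes()` above, the same formula can be exercised standalone with a toy config. The numbers below are hypothetical illustration values, not a real LitGPT configuration:

```python
from dataclasses import dataclass


@dataclass
class ToyCfg:
    """Hypothetical config mirroring the fields iobytes() reads."""
    batch_size: int = 1
    seq_len: int = 4096
    n_head: int = 32
    n_query_groups: int = 8
    head_size: int = 128
    rope_n_elem: int = 128


def bwd_iobytes(cfg, itemsize=2):  # bfloat16 is 2 bytes per element
    # size of qkv.grad
    n = cfg.batch_size * cfg.seq_len * (cfg.n_head + 2 * cfg.n_query_groups) * cfg.head_size
    # size of sin, cos (saved from forward)
    n += 2 * cfg.seq_len * cfg.rope_n_elem
    # size of q, k, v (saved from forward)
    n += 3 * cfg.batch_size * cfg.seq_len * cfg.n_head * cfg.head_size
    return n * itemsize


# For this toy config: 25_165_824 + 1_048_576 + 50_331_648 elements, times 2 bytes
print(bwd_iobytes(ToyCfg()))  # 153092096
```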
Test Coverage

Ensure that the new test cases cover a variety of scenarios and edge cases for the litgpt models.

    "litgpt-gemma-2-9b",
    "litgpt-mistral-7b",
    "litgpt-meta-llama-3-8B",
    "litgpt-phi3.5-mini",
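The walkthrough notes that the LitGPT tests are marked with the resize marker. One common way to attach a marker per parametrization is `pytest.param` with `marks=`; the sketch below assumes that style (the variation names come from the snippet above, the marker name from the conftest change):

```python
import pytest

# Hypothetical sketch: each LitGPT variation carries the resize marker,
# so `pytest -m resize` (or `-m "not resize"`) can select or skip them.
litgpt_variations = [
    pytest.param(name, marks=pytest.mark.resize)
    for name in (
        "litgpt-gemma-2-9b",
        "litgpt-mistral-7b",
        "litgpt-meta-llama-3-8B",
        "litgpt-phi3.5-mini",
    )
]
```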

@jjsjann123 (Collaborator Author)

    hmmm. the number again doesn't match the original benchmark. I need to take another look at that.

    The added benchmark

    Name (time in us)                                                                           Mean
    --------------------------------------------------------------------------------------------------------
    test_rope_fwd_benchmark[seq_length=4096-executor='torchcompile'-variation='litgpt']      90.1059 (1.0)
    test_rope_fwd_benchmark[seq_length=4096-executor='thunder'-variation='litgpt']          112.1219 (1.24)
    test_rope_bwd_benchmark[seq_length=4096-executor='torchcompile'-variation='litgpt']     140.5902 (1.56)
    test_rope_bwd_benchmark[seq_length=4096-executor='thunder'-variation='litgpt']          329.3374 (3.66)
    --------------------------------------------------------------------------------------------------------
    

    vs reference benchmark

                    Executor                     Model     DType  Batch  Seq-Len  Fwd-Krnls  Fwd-K-Time(ms)  Bwd-Krnls  Bwd-K-Time(ms)
    1          torch.compile  Meta-Llama-3-8B-Instruct  bfloat16      1     4096          3           0.090          3           0.129
    3        Thunder-nvFuser  Meta-Llama-3-8B-Instruct  bfloat16      1     4096          3           0.098          6           0.289
    
    

@jjsjann123 (Collaborator Author)

    kernel looks the same. The difference in measured time is coming from:

    1. not clearing L2 cache
    2. not clearing grad on inputs

Since those come from the reference implementation, I'm not going to update that.
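The two measurement effects mentioned above can be sketched as a helper that resets state between timed iterations. This is a hedged illustration assuming PyTorch, not the reference implementation's code; the function name and the 256 MiB flush size are hypothetical:

```python
import torch


def clear_measurement_state(inputs, l2_flush_bytes=256 * 1024 * 1024):
    """Reset per-iteration state so timings are not skewed (illustrative sketch)."""
    # 1. Clear accumulated gradients: .backward() accumulates into .grad,
    #    so stale gradients change the work done by subsequent iterations.
    for t in inputs:
        if isinstance(t, torch.Tensor) and t.grad is not None:
            t.grad = None
    # 2. Flush the GPU L2 cache by overwriting a buffer larger than the
    #    cache, so the next timed kernel cannot hit still-cached data.
    if torch.cuda.is_available():
        torch.empty(l2_flush_bytes, dtype=torch.int8, device="cuda").zero_()
```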

    @jjsjann123 jjsjann123 requested review from Priya2698 and naoyam April 28, 2025 17:27
    @jjsjann123 jjsjann123 marked this pull request as ready for review April 28, 2025 17:27
@jjsjann123 (Collaborator Author)

    !test

    @@ -0,0 +1 @@
    litgpt[all]
Collaborator

    Can we make this requirement local?
    We have several benchmark files that can run okay without this module. So this can be a hassle when trying to run unrelated benchmarks.

    One way could be to use @pytest.mark.skipif to check for the presence of this module in relevant benchmarks.
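The `skipif` suggestion above can be sketched without importing the optional module. The helper below is hypothetical (not code from this PR); the pytest usage in the comment assumes the standard `pytest.mark.skipif` API:

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if `name` is importable, without actually importing it."""
    return importlib.util.find_spec(name) is not None


# Hypothetical usage inside a benchmark file:
#
#   import pytest
#
#   requires_litgpt = pytest.mark.skipif(
#       not module_available("litgpt"), reason="litgpt is not installed"
#   )
#
#   @requires_litgpt
#   def test_rope_fwd_benchmark(...):
#       ...
```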

Collaborator Author

Sounds fair. I guess I should have stuck to what the other benchmarks do and relied on the module installed in the container. At least we do have litgpt in our CI containers. I'll remove this file.

naoyam (Collaborator) commented Apr 28, 2025

    Can you add a marker? #4290
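The marker ends up registered in `benchmarks/python/conftest.py` per the walkthrough. A minimal sketch of registering a custom marker (the description string here is hypothetical, assuming the standard `pytest_configure` hook):

```python
# conftest.py sketch: register the "resize" marker so pytest does not warn
# about an unknown marker and `pytest -m resize` works for selection.
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "resize: benchmarks exercising resize ops (description is illustrative)",
    )
```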

    @jjsjann123 jjsjann123 requested a review from Priya2698 April 29, 2025 00:16
naoyam (Collaborator) commented Apr 29, 2025

I don't have any more comments. Thanks @jjsjann123 for adding these benchmarks. I'll let @Priya2698 give a final stamp.

@Priya2698 (Collaborator) left a comment

    Please remove requirements.txt or make it a local check. LGTM otherwise.

@jjsjann123 (Collaborator Author)

    !test --pybench

    "litgpt-gemma-2-9b",
    "litgpt-mistral-7b",
    "litgpt-meta-llama-3-8B",
    "litgpt-phi3.5-mini",
Collaborator Author

    @xwang233 do we need to manually add new entries in dashboard?

@xwang233 (Collaborator) commented Apr 29, 2025

It might not show up in the PR benchmark results (perhaps it will work automatically, let's see), but it will show up in nightly benchmark results once merged.

@jjsjann123 (Collaborator Author)

    hmmm... seeing an import error.

    00:19:01 FAILED benchmarks/python/test_cross_entropy_loss.py::test_cross_entropy_fwd_benchmark[executor='thunder-torchcompile'-variation='hf_mistral_nemo'] - ImportError: cannot import name 'MistralPreTrainedModel' from 'transformers.models.phi3' (/usr/local/lib/python3.12/dist-packages/transformers/models/phi3/__init__.py)
    

    Let me investigate.

@jjsjann123 (Collaborator Author)

errr. that's coming from the cross_entropy benchmark. I'll just patch that.
It's a bit unfortunate that, given the numerical mismatch in the python benchmark, real errors are buried in false-negative signals.

    cc'ing @protonu

@jjsjann123 (Collaborator Author)

    !build

    super().__init__("hf_mistral_nemo", dtype)

    def model(self):
    from transformers.models.phi3 import MistralPreTrainedModel
Collaborator Author

    @protonu Just so that you are aware.

    @jjsjann123 jjsjann123 merged commit 8c12206 into main Apr 30, 2025
    16 checks passed
    @jjsjann123 jjsjann123 deleted the litgpt_benchmark branch April 30, 2025 20:43
Development

Successfully merging this pull request may close these issues:

Add Litgpt RoPE into nvFuser python benchmark suite