diff --git a/benchmarks/python/conftest.py b/benchmarks/python/conftest.py index a8301f0d8de..835e23c7a34 100644 --- a/benchmarks/python/conftest.py +++ b/benchmarks/python/conftest.py @@ -97,6 +97,10 @@ def pytest_configure(config): "markers", "inner_persistent: mark tests using inner_persistent scheduler if not being segmented.", ) + config.addinivalue_line( + "markers", + "resize: mark tests using resize scheduler if not being segmented.", + ) def pytest_collection_modifyitems(session, config, items): diff --git a/benchmarks/python/cross_entropy_loss.py b/benchmarks/python/cross_entropy_loss.py index 7478b2bfc81..f9747915704 100644 --- a/benchmarks/python/cross_entropy_loss.py +++ b/benchmarks/python/cross_entropy_loss.py @@ -132,7 +132,7 @@ def __init__(self, dtype): super().__init__("hf_mistral_nemo", dtype) def model(self): - from transformers.models.phi3 import MistralPreTrainedModel + from transformers.models.mistral import MistralPreTrainedModel class MyModel(MistralPreTrainedModel): def __init__(self, config): diff --git a/benchmarks/python/model_configs.py b/benchmarks/python/model_configs.py index 430a7a4ab1c..9a975bac15d 100644 --- a/benchmarks/python/model_configs.py +++ b/benchmarks/python/model_configs.py @@ -3,8 +3,6 @@ # SPDX-License-Identifier: BSD-3-Clause from functools import partial -from transformers import AutoConfig - def llama_hf_cfg(config_str): class Config: @@ -40,6 +38,8 @@ def __init__( def hf_qwen2_cfg(): + from transformers import AutoConfig + config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct") config.batch_size = 1 config.seq_len = 4096 @@ -48,6 +48,8 @@ def hf_qwen2_cfg(): def hf_phi3_cfg(): + from transformers import AutoConfig + config = AutoConfig.from_pretrained("microsoft/Phi-3.5-mini-instruct") config.batch_size = 1 config.seq_len = 8192 @@ -96,10 +98,22 @@ def hf_mistral_nemo_cfg(): return cfg +def litgpt_cfg(model_name): + import litgpt + + cfg = litgpt.Config.from_name(model_name) + cfg.batch_size = 1 + 
cfg.seq_len = 4096 + cfg.name_or_path = model_name + + return cfg + + configs = { "llama_2_7b_hf": partial(llama_hf_cfg, config_str="llama_2_7b_hf"), "llama_3_8B": partial(llama_hf_cfg, config_str="llama_3_8B"), "hf_qwen2": hf_qwen2_cfg, "hf_phi3": hf_phi3_cfg, "hf_mistral_nemo": hf_mistral_nemo_cfg, + "litgpt": litgpt_cfg, } diff --git a/benchmarks/python/rope_ops.py b/benchmarks/python/rope_ops.py index a76914373fd..6a706ad9e00 100644 --- a/benchmarks/python/rope_ops.py +++ b/benchmarks/python/rope_ops.py @@ -772,6 +772,123 @@ def iobytes(): return MistralNemoRope(cfg).cuda().bfloat16(), inputs, grads, iobytes +def Litgpt(seq_length, model_name): + class LitgptRope(torch.nn.Module): + def __init__(self, config) -> None: + from litgpt.model import apply_rope + + self.fused_apply_rotary_pos_emb_cached = None + + super().__init__() + self.config = config + self.apply_rope = apply_rope + + def forward( + self, + qkv: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + B, T, _ = qkv.shape # batch size, sequence length + + # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) + q_per_kv = self.config.n_head // self.config.n_query_groups + total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value + qkv = qkv.view( + B, T, self.config.n_query_groups, total_qkv, self.config.head_size + ) + qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) + + # split batched computation into three + q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) + + # maybe repeat k and v if for the non multi-head attention cases + # training: flash attention requires it + # inference: multi-query would require a full kv cache so avoid it to limit its memory usage + if ( + self.config.n_query_groups != self.config.n_head + and self.config.n_query_groups != 1 + ): + k = k.expand( + B, self.config.n_query_groups, q_per_kv, T, self.config.head_size + ) + v = v.expand( + B, 
self.config.n_query_groups, q_per_kv, T, self.config.head_size + ) + + q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) + k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) + v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) + + q_roped = self.apply_rope(q[..., : self.config.rope_n_elem], cos, sin) + k_roped = self.apply_rope(k[..., : self.config.rope_n_elem], cos, sin) + q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) + k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) + return q, k, v + + cfg = configs["litgpt"](model_name) + # overwrite seq_length + cfg.seq_len = seq_length + + def inputs(): + qkv = torch.randn( + cfg.batch_size, + cfg.seq_len, + (cfg.n_head + 2 * cfg.n_query_groups) * cfg.head_size, + device="cuda", + dtype=torch.bfloat16, + requires_grad=True, + ) + cos = torch.randn( + 1, + cfg.seq_len, + cfg.rope_n_elem, + device="cuda", + dtype=torch.bfloat16, + requires_grad=False, + ) + sin = torch.randn( + 1, + cfg.seq_len, + cfg.rope_n_elem, + device="cuda", + dtype=torch.bfloat16, + requires_grad=False, + ) + return qkv, cos, sin + + def grads(): + grad = torch.randn( + cfg.batch_size, + cfg.n_head, + cfg.seq_len, + cfg.head_size, + device="cuda", + dtype=torch.bfloat16, + requires_grad=False, + ) + return grad + + # Manual IOBytes computes the total bandwidth for thunder backward trace. 
+ def iobytes(): + n_elements = 0 + # adding size of qkv.grad + n_elements += ( + cfg.batch_size + * cfg.seq_len + * (cfg.n_head + 2 * cfg.n_query_groups) + * cfg.head_size + ) + # adding size of sin, cos (saved from forward) + n_elements += 2 * cfg.seq_len * cfg.rope_n_elem + # adding size of q, k, v (saved from forward) + n_elements += 3 * cfg.batch_size * cfg.seq_len * cfg.n_head * cfg.head_size + # total io sizes + return n_elements * torch.bfloat16.itemsize + + return LitgptRope(cfg).cuda().bfloat16(), inputs, grads, iobytes + + # The setup returns a function that would setup benchmark by returning: # fwd_model, inputs_fn, grads_fn, iobytes_fn rope_setup = { @@ -780,4 +897,15 @@ def iobytes(): "hf_qwen2": hf_qwen2, "hf_phi3": hf_phi3, "hf_mistral_nemo": hf_mistral_nemo, + "litgpt-gemma-2-9b": partial(Litgpt, model_name="google/gemma-2-9b-it"), + "litgpt-mistral-7b": partial( + Litgpt, model_name="mistralai/Mistral-7B-Instruct-v0.3" + ), + "litgpt-meta-llama-3-8B": partial( + Litgpt, model_name="meta-llama/Meta-Llama-3-8B-Instruct" + ), + "litgpt-phi3.5-mini": partial( + Litgpt, + model_name="microsoft/Phi-3.5-mini-instruct", + ), } diff --git a/benchmarks/python/test_rope.py b/benchmarks/python/test_rope.py index dd076f8f9f3..3ada39abd1c 100644 --- a/benchmarks/python/test_rope.py +++ b/benchmarks/python/test_rope.py @@ -15,12 +15,17 @@ "hf_qwen2", "hf_phi3", "hf_mistral_nemo", + "litgpt-gemma-2-9b", + "litgpt-mistral-7b", + "litgpt-meta-llama-3-8B", + "litgpt-phi3.5-mini", ], ) @pytest.mark.parametrize( "executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"] ) @pytest.mark.parametrize("seq_length", SEQ_LENGTHS) +@pytest.mark.resize def test_rope_fwd_benchmark( benchmark, variation: str, @@ -52,12 +57,17 @@ def fwd_call(inp): "hf_qwen2", "hf_phi3", "hf_mistral_nemo", + "litgpt-gemma-2-9b", + "litgpt-mistral-7b", + "litgpt-meta-llama-3-8B", + "litgpt-phi3.5-mini", ], ) @pytest.mark.parametrize( "executor", ["eager", "torchcompile", "thunder",
"thunder-torchcompile"] ) @pytest.mark.parametrize("seq_length", SEQ_LENGTHS) +@pytest.mark.resize def test_rope_bwd_benchmark( benchmark, variation: str,