Merged
4 changes: 4 additions & 0 deletions benchmarks/python/conftest.py
@@ -97,6 +97,10 @@ def pytest_configure(config):
"markers",
"inner_persistent: mark tests using inner_persistent scheduler if not being segmented.",
)
config.addinivalue_line(
"markers",
"resize: mark tests using resize scheduler if not being segmented.",
)


def pytest_collection_modifyitems(session, config, items):
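(Aside, not part of the diff: registering the new "resize" marker via addinivalue_line is what lets pytest filter these benchmarks by marker without unknown-marker warnings. A minimal selection sketch, assuming the repo layout above; the path is illustrative.)

# Sketch: run only the resize-marked benchmarks; equivalent to
#   pytest -m resize benchmarks/python/test_rope.py
import pytest

if __name__ == "__main__":
    pytest.main(["-m", "resize", "benchmarks/python/test_rope.py"])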
2 changes: 1 addition & 1 deletion benchmarks/python/cross_entropy_loss.py
@@ -132,7 +132,7 @@ def __init__(self, dtype):
super().__init__("hf_mistral_nemo", dtype)

def model(self):
from transformers.models.phi3 import MistralPreTrainedModel
Review comment (Collaborator Author): @protonu Just so that you are aware.

from transformers.models.mistral import MistralPreTrainedModel

class MyModel(MistralPreTrainedModel):
def __init__(self, config):
18 changes: 16 additions & 2 deletions benchmarks/python/model_configs.py
@@ -3,8 +3,6 @@
# SPDX-License-Identifier: BSD-3-Clause
from functools import partial

from transformers import AutoConfig


def llama_hf_cfg(config_str):
class Config:
@@ -40,6 +38,8 @@ def __init__(


def hf_qwen2_cfg():
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
config.batch_size = 1
config.seq_len = 4096
@@ -48,6 +48,8 @@ def hf_qwen2_cfg():


def hf_phi3_cfg():
from transformers import AutoConfig

config = AutoConfig.from_pretrained("microsoft/Phi-3.5-mini-instruct")
config.batch_size = 1
config.seq_len = 8192
@@ -96,10 +98,22 @@ def hf_mistral_nemo_cfg():
return cfg


def litgpt_cfg(model_name):
import litgpt

cfg = litgpt.Config.from_name(model_name)
cfg.batch_size = 1
cfg.seq_len = 4096
cfg.name_or_path = model_name

return cfg


configs = {
"llama_2_7b_hf": partial(llama_hf_cfg, config_str="llama_2_7b_hf"),
"llama_3_8B": partial(llama_hf_cfg, config_str="llama_3_8B"),
"hf_qwen2": hf_qwen2_cfg,
"hf_phi3": hf_phi3_cfg,
"hf_mistral_nemo": hf_mistral_nemo_cfg,
"litgpt": litgpt_cfg,
}
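(Aside, not part of the diff: unlike the other entries, the new "litgpt" factory takes a model name argument. A minimal usage sketch, assuming litgpt is installed and that benchmarks/python is on sys.path; the model name is just one of those used later in rope_ops.py.)

# Sketch: building configs through the registry above.
from model_configs import configs  # assumes benchmarks/python is on sys.path

litgpt_cfg = configs["litgpt"]("microsoft/Phi-3.5-mini-instruct")
print(litgpt_cfg.n_head, litgpt_cfg.n_query_groups, litgpt_cfg.head_size)

# The HF-based entries take no arguments (they fetch a config via AutoConfig):
qwen_cfg = configs["hf_qwen2"]()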
128 changes: 128 additions & 0 deletions benchmarks/python/rope_ops.py
@@ -772,6 +772,123 @@ def iobytes():
return MistralNemoRope(cfg).cuda().bfloat16(), inputs, grads, iobytes


def Litgpt(seq_length, model_name):
class LitgptRope(torch.nn.Module):
def __init__(self, config) -> None:
from litgpt.model import apply_rope

self.fused_apply_rotary_pos_emb_cached = None

super().__init__()
self.config = config
self.apply_rope = apply_rope

def forward(
self,
qkv: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
B, T, _ = qkv.shape # batch size, sequence length

# assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
q_per_kv = self.config.n_head // self.config.n_query_groups
total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
qkv = qkv.view(
B, T, self.config.n_query_groups, total_qkv, self.config.head_size
)
qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)

# split batched computation into three
q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

# maybe repeat k and v for the non-multi-head-attention cases
# training: flash attention requires it
# inference: multi-query would require a full kv cache so avoid it to limit its memory usage
if (
self.config.n_query_groups != self.config.n_head
and self.config.n_query_groups != 1
):
k = k.expand(
B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
)
v = v.expand(
B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
)

q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)

q_roped = self.apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
k_roped = self.apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
return q, k, v

cfg = configs["litgpt"](model_name)
# overwrite seq_length
cfg.seq_len = seq_length

def inputs():
qkv = torch.randn(
cfg.batch_size,
cfg.seq_len,
(cfg.n_head + 2 * cfg.n_query_groups) * cfg.head_size,
device="cuda",
dtype=torch.bfloat16,
requires_grad=True,
)
cos = torch.randn(
1,
cfg.seq_len,
cfg.rope_n_elem,
device="cuda",
dtype=torch.bfloat16,
requires_grad=False,
)
sin = torch.randn(
1,
cfg.seq_len,
cfg.rope_n_elem,
device="cuda",
dtype=torch.bfloat16,
requires_grad=False,
)
return qkv, cos, sin

def grads():
grad = torch.randn(
cfg.batch_size,
cfg.n_head,
cfg.seq_len,
cfg.head_size,
device="cuda",
dtype=torch.bfloat16,
requires_grad=False,
)
return grad

# Manual IOBytes computes the total IO bytes for the thunder backward trace.
def iobytes():
n_elements = 0
# adding size of qkv.grad
n_elements += (
cfg.batch_size
* cfg.seq_len
* (cfg.n_head + 2 * cfg.n_query_groups)
* cfg.head_size
)
# adding size of sin, cos (saved from forward)
n_elements += 2 * cfg.seq_len * cfg.rope_n_elem
# adding size of q, k, v (saved from forward)
n_elements += 3 * cfg.batch_size * cfg.seq_len * cfg.n_head * cfg.head_size
# total io size
return n_elements * torch.bfloat16.itemsize

return LitgptRope(cfg).cuda().bfloat16(), inputs, grads, iobytes
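(Aside, not part of the diff: the manual iobytes() formula above can be sanity-checked with placeholder shapes. The numbers below are illustrative assumptions, not the parameters of any particular litgpt checkpoint.)

# Illustrative arithmetic for the iobytes() formula, with made-up shapes.
B, T = 1, 4096                  # batch size, sequence length (assumed)
n_head, n_query_groups = 32, 8  # attention heads / KV groups (assumed)
head_size, rope_n_elem = 128, 128

n_elements = B * T * (n_head + 2 * n_query_groups) * head_size  # qkv.grad
n_elements += 2 * T * rope_n_elem                               # sin and cos
n_elements += 3 * B * T * n_head * head_size                    # q, k, v
print(n_elements * 2)  # bfloat16 -> 2 bytes/element; ~153 MB for these shapes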


# The setup returns a function that would setup benchmark by returning:
# fwd_model, inputs_fn, grads_fn, iobytes_fn
rope_setup = {
@@ -780,4 +897,15 @@ def iobytes():
"hf_qwen2": hf_qwen2,
"hf_phi3": hf_phi3,
"hf_mistral_nemo": hf_mistral_nemo,
"litgpt-gemma-2-9b": partial(Litgpt, model_name="google/gemma-2-9b-it"),
"litgpt-mistral-7b": partial(
Litgpt, model_name="mistralai/Mistral-7B-Instruct-v0.3"
),
"litgpt-meta-llama-3-8B": partial(
Litgpt, model_name="meta-llama/Meta-Llama-3-8B-Instruct"
),
"litgpt-phi3.5-mini": partial(
Litgpt,
model_name="microsoft/Phi-3.5-mini-instruct",
),
}
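(Aside, not part of the diff: to make the shape of the registry concrete, here is a rough sketch of consuming one of the new litgpt entries directly; the executor plumbing in test_rope.py is omitted. Assumes a CUDA device, litgpt installed, and benchmarks/python on sys.path.)

# Sketch: each setup entry returns (fwd_model, inputs_fn, grads_fn, iobytes_fn).
from rope_ops import rope_setup

seq_length = 4096
model, inputs_fn, grads_fn, iobytes_fn = rope_setup["litgpt-phi3.5-mini"](seq_length)

qkv, cos, sin = inputs_fn()
q, k, v = model(qkv, cos, sin)   # forward
q.backward(grads_fn())           # backward through the q path only, as an example
print("manual backward IO bytes:", iobytes_fn())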
10 changes: 10 additions & 0 deletions benchmarks/python/test_rope.py
@@ -15,12 +15,17 @@
"hf_qwen2",
"hf_phi3",
"hf_mistral_nemo",
"litgpt-gemma-2-9b",
"litgpt-mistral-7b",
"litgpt-meta-llama-3-8B",
"litgpt-phi3.5-mini",
],
)
@pytest.mark.parametrize(
"executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"]
)
@pytest.mark.parametrize("seq_length", SEQ_LENGTHS)
@pytest.mark.resize
def test_rope_fwd_benchmark(
benchmark,
variation: str,
@@ -52,12 +57,17 @@ def fwd_call(inp):
"hf_qwen2",
"hf_phi3",
"hf_mistral_nemo",
"litgpt-gemma-2-9b",
"litgpt-mistral-7b",
"litgpt-meta-llama-3-8B",
"litgpt-phi3.5-mini",
Review comment (Collaborator Author): @xwang233 do we need to manually add new entries in the dashboard?

Reply (xwang233, Collaborator, Apr 29, 2025): It might not show up in the PR benchmark results; perhaps it will work automatically (let's see). It will show up in the nightly benchmark results once merged.
],
)
@pytest.mark.parametrize(
"executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"]
)
@pytest.mark.parametrize("seq_length", SEQ_LENGTHS)
@pytest.mark.resize
def test_rope_bwd_benchmark(
benchmark,
variation: str,