diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py
index 383a3c83b84a..9e7441bd39e5 100644
--- a/tests/kernels/core/test_pos_encoding.py
+++ b/tests/kernels/core/test_pos_encoding.py
@@ -18,6 +18,7 @@
 BATCH_SIZES = [5]  # Arbitrary values for testing
 SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]
+PADDING_LEN = 64
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
@@ -32,7 +33,7 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
 # For testing sliced tensors
 def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
                              head_size: int) -> tuple[int, ...]:
-    return (batch_size, seq_len, num_heads, head_size + 64)
+    return (batch_size, seq_len, num_heads, head_size + PADDING_LEN)
 
 
 def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
@@ -72,6 +73,9 @@ def test_rotary_embedding(
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
+    raw_input_head_size = head_size
+    if tensor_shape_fn is _get_padded_tensor_shape:
+        head_size += PADDING_LEN
     if rotary_dim is None:
         rotary_dim = head_size
 
@@ -83,7 +87,8 @@ def test_rotary_embedding(
     rope = rope.to(dtype=dtype, device=torch.get_default_device())
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
+    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads,
+                                  raw_input_head_size)
     query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query) if use_key else None
 
@@ -139,6 +144,10 @@ test_batched_rotary_embedding(
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+
+    raw_input_head_size = head_size
+    if tensor_shape_fn is _get_padded_tensor_shape:
+        head_size += PADDING_LEN
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
@@ -148,7 +157,8 @@ def test_batched_rotary_embedding(
     rope = rope.to(dtype=dtype, device=torch.get_default_device())
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
+    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads,
+                                  raw_input_head_size)
     query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query) if use_key else None
 