[Ascend] perf: optimize rope embedding with triton kernel for huge performance gain #5918
Changes from all commits
@@ -187,6 +187,7 @@ def _rope_forward_oot(
         self.cos_sin_cache = self.cos_sin_cache.to(query.device)
     if self.cos_sin_cache.dtype != query.dtype:
         self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
+    cos, sin = get_cos_and_sin_slice()
Collaborator:
There might be problems with putting this index_select here. Please contact @Angazenn to make sure of this.

Collaborator:
fine
     # adopt custom kernel path for rotary_embedding
     if _custom_rotary_embedding_enabled(
             query, is_neox_style, self.head_size) and get_ascend_device_type(
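For context on the review thread above: get_cos_and_sin_slice() itself is not shown in this diff, so the sketch below is purely hypothetical. It only illustrates the kind of position-indexed lookup (an index_select on a precomputed cos/sin cache) that the comment appears to refer to; the actual helper in this PR may differ.

```python
import torch

# Hypothetical illustration only -- not the actual get_cos_and_sin_slice()
# from this PR. A per-token cos/sin slice is commonly produced with
# index_select along the position dimension of a precomputed cache.
def get_cos_and_sin_slice_example(cos_sin_cache: torch.Tensor,
                                  positions: torch.Tensor):
    # Assumed layout: cos_sin_cache is [max_position, rotary_dim] with the
    # cos half and sin half concatenated on the last dimension.
    selected = cos_sin_cache.index_select(0, positions.view(-1))
    cos, sin = selected.chunk(2, dim=-1)
    return cos, sin
```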
@@ -204,7 +205,6 @@ def _rope_forward_oot(
             raise NotImplementedError(
                 "Batched rotary embedding is currently not supported on NPU.")
         else:
-            cos, sin = get_cos_and_sin_slice()
             if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
                     -1] == 128 and cos is not None and sin is not None:
                 # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
@@ -217,28 +217,43 @@ def _rope_forward_oot(
                 query, key = torch_npu.npu_apply_rotary_pos_emb(
                     query, key, cos, sin)
             elif self.rotary_dim < self.head_size:
-                num_tokens = query.shape[0]
-                query = query.view(num_tokens, -1, self.head_size)
-                key = key.view(num_tokens, -1, self.head_size)
-                q_rot = query[..., :self.rotary_dim]
-                q_pass = query[..., self.rotary_dim:]
-                k_rot = key[..., :self.rotary_dim]
-                k_pass = key[..., self.rotary_dim:]
-                q_rot = q_rot.contiguous().view(num_tokens, -1)
-                k_rot = k_rot.contiguous().view(num_tokens, -1)
-                torch_npu._npu_rotary_embedding(
-                    positions,
-                    q_rot,
-                    k_rot,
-                    self.head_size,
-                    self.cos_sin_cache,
-                    is_neox_style,
-                )
-                q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
-                k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
-                q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
-                k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
-                return q, k
+                if HAS_TRITON:
+                    cos = cos.view(-1, self.rotary_dim)
+                    sin = sin.view(-1, self.rotary_dim)
+                    q = query.contiguous().view(query.shape[0], -1,
+                                                self.head_size)
+                    k = key.contiguous().view(key.shape[0], -1, self.head_size)
+                    query, key = torch.ops.vllm.rope_forward_triton(q,
+                                                                    k,
+                                                                    cos,
+                                                                    sin,
+                                                                    rope_dim=self.rotary_dim,
+                                                                    is_neox_style=True)
+                    return query.view(query_shape), key.view(key_shape)
+                else:
+                    num_tokens = query.shape[0]
+                    query = query.view(num_tokens, -1, self.head_size)
+                    key = key.view(num_tokens, -1, self.head_size)
+                    q_rot = query[..., :self.rotary_dim]
+                    q_pass = query[..., self.rotary_dim:]
+                    k_rot = key[..., :self.rotary_dim]
+                    k_pass = key[..., self.rotary_dim:]
+                    q_rot = q_rot.contiguous().view(num_tokens, -1)
+                    k_rot = k_rot.contiguous().view(num_tokens, -1)
+                    torch_npu._npu_rotary_embedding(
+                        positions,
+                        q_rot,
+                        k_rot,
+                        self.head_size,
+                        self.cos_sin_cache,
+                        is_neox_style,
+                    )
+                    q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
+                    k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
+                    q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
+                    k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
+                    return q, k
             else:
                 # TODO: Remove the contiguous in the future.
                 query = query.contiguous().view(query.shape[0], -1)
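To make the math behind the new Triton path easier to follow: the op applies neox-style rotary embedding to only the first rotary_dim channels of each head and passes the remaining channels through unchanged. The snippet below is a plain-PyTorch reference of that computation, not the Triton kernel from this PR; it also assumes cos and sin carry rope_dim // 2 values per token, whereas the diff views them as (-1, rotary_dim), so the actual cache layout may differ.

```python
import torch

def rope_neox_partial_reference(x: torch.Tensor, cos: torch.Tensor,
                                sin: torch.Tensor,
                                rope_dim: int) -> torch.Tensor:
    """Eager reference for neox-style partial rotary embedding.

    x:   [num_tokens, num_heads, head_size]
    cos: [num_tokens, rope_dim // 2]  (assumed half-dim layout)
    sin: [num_tokens, rope_dim // 2]
    Only x[..., :rope_dim] is rotated; x[..., rope_dim:] passes through,
    mirroring the rotary_dim < head_size branch in the diff above.
    """
    x_rot, x_pass = x[..., :rope_dim], x[..., rope_dim:]
    # Neox style: rotate the first half of the rotary slice against the second.
    x1, x2 = x_rot[..., :rope_dim // 2], x_rot[..., rope_dim // 2:]
    c, s = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
    rotated = torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)

# Applying it to both query and key gives the result the fused op is expected
# to match (up to numerical tolerance):
# q_out = rope_neox_partial_reference(q, cos, sin, rope_dim)
# k_out = rope_neox_partial_reference(k, cos, sin, rope_dim)
```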
The function rope_forward_triton is used here as the op_func for the custom operator, but it is not defined or imported within this file. This will likely result in a NameError when this module is imported. To fix this, you should import it from vllm_ascend.ops.triton.rope at the top of the file.
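A minimal sketch of the fix suggested above, assuming the module path named in the comment; the import makes rope_forward_triton resolvable where the custom operator is registered:

```python
# Per the review comment: import the Triton implementation at the top of the
# file so it can be passed as op_func when the custom operator is registered
# (and so torch.ops.vllm.rope_forward_triton resolves at call time).
from vllm_ascend.ops.triton.rope import rope_forward_triton
```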