[Spyre-Next] [Feature] Wrap RoPE layer on Spyre#881
dilipgb wants to merge 9 commits into torch-spyre:main from
Conversation
👋 Hi! Thank you for contributing to vLLM support on Spyre. We also recommend installing prek and configuring it to check your code before every local commit.
bohnstingl
left a comment
Thank you @dilipgb. I made a first pass through the PR and left some comments.
In general, I think you need to merge in the latest main branch, which would remove the changes from pyproject.toml and uv.lock. At least the changes there are not related and should be removed.
Also, could you please adopt the new call chain from #872?
```
- No dtype promotion (torch-spyre limitation)
- rope_scaling not yet implemented
- Expect numerical differences from upstream vLLM
```
Ad 1) Where is the dtype promotion happening in upstream? I see that there is enable_fp32_compute, but I believe for Granite it is False?
Ad 2) Can we have an assert / raise an exception to ensure that this code path is reached, i.e., that `scaling_type == "default"` and `"mrope_section" not in rope_parameters` and (`"use_fope" not in rope_parameters` or `not rope_parameters["use_fope"]`)?
Ad 3) Is the dtype promotion the source of the numerical differences, or is there anything apart from that?
- dtype promotion in upstream is optional and is enabled only when enable_fp32_compute is set. We could add another condition to check for this. But Granite's dtype is fp16 and we will use the same without upcasting; that is the thought process here.
- Addressed.
- Yes, since the trig functions and other intermediate operations in upstream vLLM are upcast for better precision.
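To make the precision point concrete, here is a small sketch (mine, not code from the PR) showing how float16 loses angle resolution at large position*frequency products, which is why upstream upcasts before the trig functions:

```python
import math

import torch

# At a position*inv_freq product near 2048, float16 spacing is 1.0, so the
# angle itself is rounded before cos/sin are even applied.
angle = 2047.3
a16 = torch.tensor(angle, dtype=torch.float16)
a32 = torch.tensor(angle, dtype=torch.float32)

print(float(a16))  # the nearest representable float16 value
err = abs(math.cos(float(a16)) - math.cos(float(a32)))
print(f"cos error from fp16 angle rounding: {err:.3f}")
```

The error here comes purely from representing the angle in fp16, independent of any kernel differences.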
```python
# Use float16 directly - no dynamic dimensions (Spyre constraint)
compute_dtype = torch.float16

# Compute inverse frequencies: base^(-2i/rotary_dim)
# Using negative exponent for numerical stability
exponents = -torch.arange(0, self.rotary_dim, 2, dtype=compute_dtype) / self.rotary_dim
inv_freq = torch.pow(self.base, exponents)

# Create position indices [0, 1, 2, ..., max_position_embeddings-1]
t = torch.arange(self.max_position_embeddings, dtype=compute_dtype)

# Compute frequencies for each position: pos * inv_freq
# Shape: [max_position_embeddings, rotary_dim // 2]
freqs = torch.outer(t, inv_freq)

# Duplicate frequencies for interleaved pattern
# Shape: [max_position_embeddings, rotary_dim]
emb = torch.cat([freqs, freqs], dim=-1)
```
Can we add a comment noting that these ops currently happen on CPU?
torch-spyre has had some more ops added lately, and torch.cat should now work on Spyre. So we might want to try converting some of these operations to run on Spyre.
I tried torch.arange and torch.outer, which are not yet implemented on Spyre. torch.cat, though implemented on Spyre, does not help here: the emb calculation would again fall back on CPU, and we would need to move data back and forth between CPU and card multiple times just to run torch.cat on device.
torch.outer may indeed not be supported. However, torch.cat and torch.arange should be.
Although torch.cat and torch.arange might have CPU fallbacks, I think we should still try to use them with the Spyre device, because once those operations are supported through torch-spyre, they will just work in vllm-spyre.
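The fallback idea could be sketched as below. The "spyre" device string is assumed from this discussion and is not available in a plain PyTorch install, so the except branch runs in this environment:

```python
import torch

def arange_on(device, *args, **kwargs):
    """Try to allocate directly on the target device, fall back to CPU.

    Sketch only: once the op is supported through torch-spyre, the try
    branch succeeds and the tensor lands on device with no code change.
    """
    try:
        return torch.arange(*args, device=device, **kwargs)
    except (RuntimeError, ValueError):
        return torch.arange(*args, **kwargs)  # CPU fallback

t = arange_on("spyre", 8, dtype=torch.float32)
print(t.device)
```

The same wrapper pattern applies to `torch.cat`; the point is that the call sites already name the target device.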
```python
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
```
Same here, tensor slicing shouldn't currently work with the tensors on spyre? Can you confirm that the tensors are indeed on spyre?
```python
# Retrieve cos/sin for the given positions
# positions shape: [batch_size, seq_len] or [total_tokens]
cos = cos_cache[positions]  # [..., rotary_dim]
sin = sin_cache[positions]  # [..., rotary_dim]
```
I am surprised this is actually working when cos_cache and sin_cache are on spyre?
@dilipgb I am still puzzled by how this can work for you. I tested it locally and the tensor slicing fails, as this is not yet supported in torch-spyre. At least not in eager mode. Can you confirm that?
Therefore, I think we should restructure this function a bit overall and do the slicing in _forward_spyre_impl on CPU, pass-in the two halves, apply the individual RoPE on them, return them and then combine them together again in _forward_spyre_impl on CPU.
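A rough sketch of that restructuring (helper names like `apply_rope` and `forward_cpu_wrapper` are illustrative, not the PR's actual API): indexing and slicing stay on CPU, only the elementwise RoPE math is handed to the (potentially compiled) kernel, and the halves are recombined on CPU afterwards.

```python
import torch

def apply_rope(x_rot, cos, sin):
    # Elementwise math only - the part that can run on device
    x1, x2 = torch.split(x_rot, x_rot.shape[-1] // 2, dim=-1)
    rotated = torch.cat([-x2, x1], dim=-1)
    return x_rot * cos + rotated * sin

def forward_cpu_wrapper(query, positions, cos_cache, sin_cache, rotary_dim):
    # CPU: advanced indexing / slicing (not yet supported on Spyre)
    cos, sin = cos_cache[positions], sin_cache[positions]
    q_rot, q_pass = query[..., :rotary_dim], query[..., rotary_dim:]
    # The device kernel would be invoked here; this sketch stays on CPU
    q_rot = apply_rope(q_rot, cos, sin)
    # CPU: recombine the rotated and pass-through halves
    return torch.cat([q_rot, q_pass], dim=-1)

q = torch.randn(4, 16)
cos_cache = torch.randn(32, 8).clamp(-1, 1)
sin_cache = torch.randn(32, 8).clamp(-1, 1)
out = forward_cpu_wrapper(q, torch.arange(4), cos_cache, sin_cache, 8)
```

The pass-through half is untouched, so correctness of the split/recombine step is easy to verify independently of the rotation math.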
Signed-off-by: Dilip Gowda Bhagavan <dilip.bhagavan@ibm.com>
This PR bumps the lower bound of the foundation-model-stack dependency from 1.7.0 to 1.8.0, which includes Llama bug fixes for torch 2.10.

Signed-off-by: Daniel Schenker <daniel.schenker@ibm.com>
Signed-off-by: Dilip Gowda Bhagavan <dilip.bhagavan@ibm.com>
bohnstingl
left a comment
@dilipgb Can you please have a look at my comments and also merge in the latest main?
```python
# Transfer cos/sin cache to Spyre device if not already there
# if self.cos_cache.device != self._target_device:
#     self.cos_cache = convert(self.cos_cache, self._target_device, self._target_dtype)
#     self.sin_cache = convert(self.sin_cache, self._target_device, self._target_dtype)
```
```python
        Rotated tensor [..., rotary_dim]
    """
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)
```
I think we can rework this to be supported on Spyre. In particular:

- Let's replace tensor.chunk with torch.split, which is supported, e.g. `x1, x2 = torch.split(x, [d, d], dim=1)`, where `d` is passed in from the caller.
- torch.cat should be supported in torch-spyre, see https://github.com/torch-spyre/torch-spyre/blob/eaf2f76026880b071cd8dffcf473685e8223c8aa/tests/inductor/test_inductor_ops.py#L1799-L1803
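A quick equivalence check of the two formulations (a sketch, not PR code) confirms the `torch.split` variant is a drop-in replacement:

```python
import torch

def rotate_half_chunk(x):
    # Original formulation using tensor.chunk
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def rotate_half_split(x, d):
    # Suggested formulation; d = x.shape[-1] // 2 is passed by the caller
    x1, x2 = torch.split(x, [d, d], dim=-1)
    return torch.cat([-x2, x1], dim=-1)

x = torch.randn(3, 8)
same = torch.equal(rotate_half_chunk(x), rotate_half_split(x, 4))
print(same)
```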
Signed-off-by: Dilip Gowda Bhagavan <110233170+dilipgb@users.noreply.github.com>
bohnstingl
left a comment
Some minor details, noted here so that I don't forget.
As discussed offline, it would be good to run as many operations as possible on Spyre. For example, torch.cat should be supported now.
Indexing operations, such as slicing, still need to stay on CPU for the moment.
```python
Tq, q_hidden = query.shape
Tk, k_hidden = key.shape

assert Tq == Tk, f"Query/Key sequence mismatch: {Tq} != {Tk}"
```
Why do we have this constraint? Is this also part of the upstream implementation? In upstream I see https://github.com/vllm-project/vllm/blob/a5b17fba8ff3fcc076d73ba749a0819e0ec25f06/vllm/model_executor/layers/rotary_embedding/base.py#L140-L180
```python
# Compile the forward kernel
self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre)
self._layer_name = register_layer(self, "spyre_rotary_embedding")
```
We recently introduced additional logging. Thus, please include something like:

```python
logger.debug_once(
    "SpyreRotaryEmbedding: Dispatch: enabled=%s, Forward method=%s, Compiled=%s",
    self.enabled(),
    self._forward_method.__name__,
    self.maybe_compiled_forward_spyre is not self.forward_spyre,
)
```
```python
assert cos_q.shape == query_rot.shape, f"{cos_q.shape} != {query.shape}"
assert sin_q.shape == query_rot.shape
```
Can we have more descriptive error messages here?
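For instance, a small helper along these lines (the name is illustrative, not from the PR) would make the failure self-describing and also fix the `query.shape` vs. `query_rot.shape` mismatch in the first message:

```python
import torch

def check_same_shape(a, b, a_name, b_name):
    # Report both tensor names and both shapes on mismatch
    assert a.shape == b.shape, (
        f"{a_name} shape {tuple(a.shape)} does not match "
        f"{b_name} shape {tuple(b.shape)}"
    )

check_same_shape(torch.zeros(2, 8), torch.zeros(2, 8), "cos_q", "query_rot")
```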
Description
Adds SpyreRotaryEmbedding, a Spyre-optimized out-of-tree (OOT) replacement for vLLM's RotaryEmbedding, following the custom-op pattern from #842 (same as SpyreRMSNorm and SpyreSiluAndMul).
Related Issues
Fixes #820
Test Plan
Ran a couple of tests to confirm it is working on Spyre. Also looking into the upstream test for rotary embeddings, which is in progress.
examples/torch_spyre_inference.py
vllm_spyre_next/examples/Offline_demo.py
Checklist

- [ ] My code follows the project's code style (run `bash format.sh`)
- [ ] My commits include a `Signed-off-by:` line (DCO compliance)