[Spyre-Next] [Feature] rms_norm rework with vLLM IR#877
bohnstingl wants to merge 22 commits into torch-spyre:main from
Conversation
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
👋 Hi! Thank you for contributing to vLLM support on Spyre. We also recommend installing prek and configuring it to check your code before every local commit.
@bohnstingl Shouldn't these methods run inside the compiled graph? Is the current limitation that we do the CPU->Spyre transfer inside the method?
bringlein left a comment
Thanks for the effort @bohnstingl ! Good starting point to leverage vLLM IR.
Spyre-specific implementation details:
- Epsilon as tensor: scalar broadcast limited, expand via torch.full()
- No dtype promotion: torch-spyre limitation, stays in input dtype
- variance_size: parameter currently not used
"""
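For reference, the math these workarounds wrap can be sketched in plain Python. This is a hypothetical torch-free stand-in for illustration only, not the PR's `spyre_rms_norm` (which operates on tensors and additionally expands `eps` via `torch.full()` and skips the usual fp32 promotion); the `eps` default and weight values here are made up:

```python
import math

def rms_norm_ref(x, weight, eps=1e-6):
    """Reference RMSNorm: y_i = x_i / sqrt(mean(x^2) + eps) * w_i.

    Plain-Python illustration of the computation; the real spyre
    kernel works on tensors, expands eps to a tensor, and stays in
    the input dtype (no fp32 promotion).
    """
    variance = sum(v * v for v in x) / len(x)  # mean of squares
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

print(rms_norm_ref([3.0, 4.0], [1.0, 1.0]))
```

With unit weights the output has a root-mean-square of 1, which is a quick sanity check for any provider implementation.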
I understand and agree with these constraints and I think we should proceed with this PR as is.
However, if we want to reduce the custom code for spyre even further, couldn't we reuse the upstream RMS norm by automating the first two steps?
- Converting epsilon to a tensor could be a custom vLLM compile pass (or done in torch-spyre?).
- The same for removing the dtype cast?
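A graph pass of that kind could look roughly like the following toy sketch. The list-of-dicts graph format and the function name are invented for illustration; a real pass would operate on vLLM's FX graph, not this structure:

```python
def tensorize_epsilon_pass(graph):
    """Toy compile pass: rewrite rms_norm nodes whose eps is a Python
    scalar so that eps instead refers to a tensor-producing 'full' node,
    mirroring the torch.full() workaround. The graph here is a plain
    list of op dicts, not vLLM's or FX's real IR."""
    new_graph = []
    for node in graph:
        if node["op"] == "rms_norm" and isinstance(node.get("eps"), float):
            # Materialize eps as an explicit graph node first...
            new_graph.append({"op": "full", "out": "eps_t",
                              "value": node["eps"]})
            # ...then rewire rms_norm to consume it (copy, don't mutate).
            node = {**node, "eps": "eps_t"}
        new_graph.append(node)
    return new_graph

g = [{"op": "rms_norm", "x": "x0", "eps": 1e-6}]
print(tensorize_epsilon_pass(g))
```

The same shape of pass could strip the upstream dtype cast, which is the second step mentioned above.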
We could consider this once we fully utilize torch.compile for our provider; see my comment below. Currently our provider function is not traced, so no graph is created and we can't run any custom passes.
@romitjain The problem is that we need to perform the …
I updated the implementation and simplified the code infrastructure quite a bit. Currently we only need to rewrite the provider function and no longer require any of the CustomOp wrapping.
# torch.ops.vllm_ir.rms_norm (a registered custom op). This creates
# an opaque boundary that Dynamo captures without tracing inside, so
# the op is not lowered by Inductor. The provider therefore runs
# eagerly at each forward call.
compilation_config.ir_enable_torch_wrap = True
We need to use `ir_enable_torch_wrap` at the moment because of our `.to(device="spyre")` calls, and because of an upstream issue in torch-spyre that still needs to be investigated. So for the moment our provider function won't be compiled and runs in eager mode.
Padding is no longer needed since torch-spyre/torch-spyre#878 landed in torch-spyre. Therefore we no longer need the D2H and H2D transfers in the IR, and together with #878 we can avoid using …
I updated this PR and further simplified our code. At this stage, our provider differs in only three ways:
@tdoublep @bringlein please have a look
Merged in the latest …
Description
This PR reworks the custom wrapping of torch-spyre ops to use the upstream vLLM IR system.
Key changes
- **New Spyre IR provider** (`custom_ops/kernels/rms_norm.py`): A standalone function `spyre_rms_norm` that implements RMSNorm for spyre, registered with the vLLM IR system via `ir.ops.rms_norm.register_impl("spyre", ...)`, with spyre-specific workarounds.
- **Simplified OOT wrapper** (`custom_ops/rms_norm.py`): `SpyreRMSNorm._forward_spyre_impl()` now calls `ir.ops.rms_norm(...)` directly and is significantly simplified. The wrapper is reduced to: device sandwich (CPU → Spyre → CPU), batch padding, and residual handling. Note: this is only temporarily needed and can be removed with a spyre-specific model runner as well as with the upstream rework to a `PluggableLayer`.
- **Platform changes** (`platform.py`): A new `get_default_ir_op_priority()` classmethod sets `rms_norm=["spyre", "native"]`, telling the IR system to prefer our spyre provider and fall back to the pure-PyTorch native implementation only when needed. `check_and_update_config()` forces `custom_ops=["all"]` and `ir_enable_torch_wrap=False`.
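The priority mechanism behind `get_default_ir_op_priority()` can be sketched as a toy registry. The `register_impl`/`dispatch` names mirror the shape of the vLLM IR surface used in this PR, but the code below is a self-contained illustration, not vLLM's implementation:

```python
class OpRegistry:
    """Toy sketch of priority-based provider dispatch, assuming the
    general shape of vLLM IR's register_impl mechanism; not the actual
    vLLM code."""
    def __init__(self):
        self.impls = {}

    def register_impl(self, name, fn, available=lambda: True):
        # Each provider registers a callable plus an availability check.
        self.impls[name] = (fn, available)

    def dispatch(self, priority, *args):
        # Walk the priority list (e.g. ["spyre", "native"]) and call
        # the first provider that is registered and available.
        for name in priority:
            fn, available = self.impls.get(name, (None, None))
            if fn is not None and available():
                return fn(*args)
        raise RuntimeError("no provider available")

ops = OpRegistry()
ops.register_impl("native", lambda x: ("native", x))
ops.register_impl("spyre", lambda x: ("spyre", x))  # pretend spyre exists
print(ops.dispatch(["spyre", "native"], 42))
```

If the spyre provider reports itself unavailable, dispatch falls through to the native implementation, which is the fallback behavior the priority list `["spyre", "native"]` encodes.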
The dispatch chain is:
Main benefits
- **Reuses the upstream vLLM kernel dispatch system.** Instead of our own standalone kernel function called directly from the OOT wrapper (`maybe_compile`), the spyre kernel is discovered and selected through the same priority-based dispatch that CUDA, ROCm, and XPU use.
- **Simpler OOT wrapper.** The `SpyreRMSNorm` class no longer needs `_dynamic_arg_dims`, `maybe_compile`, etc. Moreover, the OOT wrapper can be further simplified when moving to a `PluggableLayer` or with a spyre-specific model runner, keeping the inputs and outputs on spyre, so no device transfer is needed.
- **Decoupled provider.** The spyre rms_norm kernel (`kernels/rms_norm.py`) is a standalone function with no dependency on the OOT class. It can be tested in isolation via `ir.ops.rms_norm()` and evolves independently of the wrapping logic.
- The OOT wrapper can only be fully removed with the upstream move to a `PluggableLayer` and the residual path handling. In general, this approach is bottlenecked by the availability of the vLLM IR for other operations, such as SiluAndMul, etc., but this is expected to evolve quickly.
- A torch wrap (`torch.ops.vllm.spyre_rmsnorm`) is still required. Dynamo cannot yet trace into the OOT wrapper, because it contains D2H and H2D tensor transfers and a padding operation that is not yet supported in torch-spyre. However, as mentioned above, these points can be addressed in the future and the overhead removed.

cc @tdoublep @bringlein
Related Issues
#733
Test Plan
The change is not user-facing, so in principle all current tests should pass. However, there are some `rms_norm`-specific tests that currently use mocking. If we decide to keep them rather than reuse the tests from upstream vLLM, we might need to adjust them. For example: https://github.com/vllm-project/vllm-spyre/blob/2d816a744f479025462bd97521c823c67b403173/vllm_spyre_next/tests/test_rms_norm.py#L97
Checklist
- [ ] The code has been formatted (`bash format.sh`)
- [ ] All commits include a `Signed-off-by:` line (DCO compliance)