
[Spyre-Next] E2E test with optional offloading to spyre layers #900

Merged
bohnstingl merged 12 commits into torch-spyre:main from romitjain:tests/e2e-layer-wise
Apr 10, 2026

Conversation

@romitjain
Collaborator

@romitjain romitjain commented Apr 7, 2026

Description

I have updated examples/torch_spyre_inference.py to support the following arguments:

  1. enforce_eager: run in either compile mode or eager mode.
  2. custom_ops: dispatch our custom ops in the forward pass. This can be used to offload individual layers to Spyre and test e2e inference with them. For example, even if the individual layer tests pass for SpyreRMSNorm, e2e inference might still diverge because small numerical differences pile up across multiple layers.

With this script, we can test e2e inference of any custom layer we implement, in both eager mode and compile mode.

A few relevant resources:

  1. How custom_ops are enabled for different enforce_eager modes: https://docs.vllm.ai/en/stable/api/vllm/config/compilation/#vllm.config.compilation.CompilationConfig.custom_ops
  2. How dispatch is decided for CustomOps: https://docs.vllm.ai/en/latest/design/custom_op/#how-customop-works-in-vllm
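As a rough illustration (a hypothetical sketch, not the actual vLLM implementation; see the links above for the real logic), the `--custom_ops` list resolves per-op roughly like this, with explicit `+Op`/`-Op` entries overriding the `all`/`none` base policy:

```python
def op_enabled(op_name: str, custom_ops: list[str]) -> bool:
    # Sketch of vLLM-style custom-op resolution: explicit "+Op" / "-Op"
    # entries win over the "all" / "none" base policy.
    if f"+{op_name}" in custom_ops:
        return True
    if f"-{op_name}" in custom_ops:
        return False
    # Without "all", the op falls back to the native (non-custom) path.
    return "all" in custom_ops

# e.g. --custom_ops none +RMSNorm enables only RMSNorm
print(op_enabled("RMSNorm", ["none", "+RMSNorm"]))     # True
print(op_enabled("SiluAndMul", ["none", "+RMSNorm"]))  # False
```

This is why `--custom_ops none +RMSNorm` exercises only the SpyreRMSNorm offload while every other layer stays on the CPU path.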

Related Issues

python examples/torch_spyre_inference.py -n 1 --custom_ops none +RMSNorm # Passes, compile mode for SpyreRMSNorm
python examples/torch_spyre_inference.py -n 1 --enforce_eager --custom_ops none +RMSNorm # Fails, see #794, eager mode for SpyreRMSNorm

Fixes:

Test Plan

I ran the following script with different custom ops. More details are in the internal Slack thread here.

python examples/torch_spyre_inference.py --custom_ops none # Pure vLLM CPU mode with compile; this is also the default with enforce_eager=False in vLLM
python examples/torch_spyre_inference.py --custom_ops none +RMSNorm # vLLM CPU compile mode with the compiled SpyreRMSNorm layer offloaded to Spyre
python examples/torch_spyre_inference.py --custom_ops all # vLLM CPU compile mode with all implemented custom ops offloaded to Spyre

python examples/torch_spyre_inference.py --enforce_eager --custom_ops none # Pure vLLM CPU mode in eager mode
python examples/torch_spyre_inference.py --enforce_eager --custom_ops none +RMSNorm # vLLM CPU eager mode with the SpyreRMSNorm layer offloaded to Spyre and run eagerly
python examples/torch_spyre_inference.py --enforce_eager --custom_ops all # vLLM CPU eager mode with all implemented custom ops offloaded to Spyre; this is also the default with enforce_eager=True

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

Signed-off-by: romit <romit@ibm.com>
@github-actions

github-actions bot commented Apr 7, 2026

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: Make sure that your code passes all the linting checks, otherwise your PR won't be able to be merged. To do so, run ./format.sh.
Now you are good to go 🚀.

We also recommend installing prek and configuring it to check your code before every local commit.

Signed-off-by: romit <romit@ibm.com>
@romitjain romitjain marked this pull request as ready for review April 7, 2026 08:06
@bohnstingl bohnstingl self-requested a review April 7, 2026 08:08
Collaborator

@bohnstingl bohnstingl left a comment


@romitjain thanks for the PR. I noticed that the enforce_eager flag has also been added in #853. In principle I am fine with adding it either here or in the other PR.

I will try out the compilation_configs today and circle back.

type=str,
nargs="*",
default=["none"],
help="Custom ops to enable (e.g., --custom_ops none +RMSNorm +SiluAndMul)",
Collaborator

The enforce_eager flag is also being added in the identical way actually in #853, see https://github.com/jvlunteren/vllm-spyre/blob/1749821c1f0f345bc7047e79f8eb351eb9d86f46/vllm_spyre_next/examples/torch_spyre_inference.py#L31.

Maybe we can focus this PR on the evaluation of the custom_ops?

Collaborator

Actually no, let's introduce the enforce_eager feature in this PR, like it is now.

Collaborator

@bohnstingl bohnstingl left a comment

Overall this looks good to me. I left some minor change requests.

"--custom-ops",
type=str,
nargs="*",
default=["none"],
Collaborator

Could we set the default value to [], so default=[]? This way, if the user doesn't set this variable, all CustomOps are enabled by default and the user doesn't have to explicitly add it via +RMSNorm, ... WDYT?

Collaborator Author

Yes, makes sense.

Also, just to add: if default=[], then enforce_eager=False will actually disable all the ops, so in that case we still have to pass --custom_ops all.

See this

Collaborator

Maybe we can account for that by setting the default to None. Then we do:

if args.custom_ops is None:
    if args.enforce_eager:
        args.custom_ops = ["all"]
    else:
        args.custom_ops = []

This way we ensure that all CustomOps are enabled in both modes when the user does not explicitly pass the --custom-ops parameter.
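Put together, a minimal self-contained sketch of this defaulting scheme (argument names follow this PR's script; the exact final code may differ):

```python
import argparse

def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--enforce_eager", action="store_true",
                        help="Run in eager mode instead of compile mode")
    parser.add_argument("--custom_ops", type=str, nargs="*", default=None,
                        help="Custom ops to enable (e.g., --custom_ops none +RMSNorm)")
    args = parser.parse_args(argv)
    if args.custom_ops is None:
        # Eager mode needs an explicit "all" to keep custom ops enabled;
        # in compile mode an empty list already defaults to all CustomOps.
        args.custom_ops = ["all"] if args.enforce_eager else []
    return args

print(parse_args([]).custom_ops)                   # []
print(parse_args(["--enforce_eager"]).custom_ops)  # ['all']
```

Note that argparse only treats tokens starting with `-` as options, so `+RMSNorm` passes through `nargs="*"` as a plain value.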

- "all": Run all supported ops on Spyre (default)
- "none": Run entirely on CPU
- "+LayerName": Selectively enable specific layers on Spyre
(e.g., --custom_ops none +RMSNorm +SiluAndMul)
Collaborator

The none wouldn't be necessary, I think?

Comment on lines +79 to +84
logger.warning_once(
    "SpyreRMSNorm dispatch: enabled=%s, _forward_method=%s, forward_spyre compiled=%s",
    self.enabled(),
    self._forward_method.__name__,
    self.maybe_compiled_forward_spyre is not self.forward_spyre,
)
Collaborator

Do we want to keep this, or was this more for debugging purposes?

Collaborator Author

I initially added this mostly for debugging purposes, but I remember reading somewhere on Slack (I lost the thread) that we needed a way to figure out whether the layers are actually being run. While testing all the permutations and combinations of flags, I found this really helpful.

I am okay with removing it too, or perhaps we can make it a debug statement?

WDYT?

Collaborator

Making this a debug statement makes sense to me. I will then add this also to #880 for the other wrappers as well.
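For reference, the "emit once, at debug level" idea can be sketched with the standard library alone (vLLM's own logger provides `warning_once`-style helpers, so the real change is presumably just swapping the method; the helper name below is illustrative):

```python
import logging
from functools import cache

logger = logging.getLogger("vllm_spyre")

@cache  # each distinct message is emitted at most once per process
def debug_once(msg: str) -> None:
    logger.debug(msg)

debug_once("SpyreRMSNorm dispatch: eager path selected")
debug_once("SpyreRMSNorm dispatch: eager path selected")  # suppressed by the cache
```

This keeps the dispatch information available under a verbose log level without flooding normal runs.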

@bohnstingl
Collaborator

bot:test-next

Signed-off-by: romit <romit@ibm.com>
Collaborator

@bohnstingl bohnstingl left a comment

The PR looks good to me; just the one minor change and afterwards we can merge.

"SpyreRMSNorm: no dtype promotion is performed, "
"expect numerical differences to upstream vLLM."
)
logger.debug(
Collaborator

Can we make this debug_once?

Collaborator Author

This is done.

Signed-off-by: romit <romit@ibm.com>
@bohnstingl bohnstingl merged commit fc6b54b into torch-spyre:main Apr 10, 2026
13 checks passed
@romitjain romitjain deleted the tests/e2e-layer-wise branch April 10, 2026 05:14