[Spyre-Next] Wrapped Embedding layer for spyre #836

Merged
bohnstingl merged 14 commits into torch-spyre:main from coderfornow:embedding-support-spyre
Mar 31, 2026

Conversation


@coderfornow coderfornow commented Mar 12, 2026

Description

Adds SpyreVocabParallelEmbedding, a Spyre-optimized out-of-tree (OOT) replacement for vLLM's VocabParallelEmbedding, following the custom op pattern from #842 (same as SpyreRMSNorm and SpyreSiluAndMul).
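The OOT substitution pattern described above can be sketched roughly as follows. This is an illustration only: the registry, decorator, and `build_layer` helper are assumptions for the sketch, not vLLM's actual API.

```python
# Sketch of out-of-tree (OOT) substitution: the platform registers a
# Spyre subclass under the upstream layer's name, so construction code
# that resolves the class by name transparently gets the wrapper.
# The registry and helper names here are illustrative assumptions.

LAYER_REGISTRY: dict = {}

def register_oot(name):
    """Decorator: register cls as the OOT replacement for `name`."""
    def deco(cls):
        LAYER_REGISTRY[name] = cls
        return cls
    return deco

class VocabParallelEmbedding:
    """Stand-in for the upstream vLLM layer."""
    backend = "native"

@register_oot("VocabParallelEmbedding")
class SpyreVocabParallelEmbedding(VocabParallelEmbedding):
    """Spyre-optimized wrapper; same interface as the parent."""
    backend = "spyre"

def build_layer(name):
    """Resolve a layer class, preferring a registered OOT override."""
    return LAYER_REGISTRY.get(name, VocabParallelEmbedding)()
```

Because the override subclasses the upstream layer, callers that type-check against `VocabParallelEmbedding` keep working unchanged, which is why registration is transparent to upstream consumers.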

Related Issues

Test Plan

  • Existing CI should pass — OOT registration is transparent to upstream consumers
  • Manual verification that SpyreVocabParallelEmbedding is instantiated in place of VocabParallelEmbedding on Spyre card

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

@github-actions github-actions bot changed the title from "Wrapped Embedding layer for spare" to "[Spyre-Next] Wrapped Embedding layer for spare" Mar 12, 2026
@github-actions

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: make sure your code passes all the linting checks, otherwise your PR can't be merged. To do so, run ./format.sh.
Now you are good to go 🚀.

We also recommend installing prek and configuring it to check your code before every local commit.

@coderfornow coderfornow marked this pull request as draft March 12, 2026 22:18
@joerunde

bot:next-test

@joerunde

@coderfornow looks like something is busted here, tests are failing with:

        params = self.__dict__.get("_parameters")
        if isinstance(value, Parameter):
            if params is None:
>               raise AttributeError(
                    "cannot assign parameters before Module.__init__() call"
                )
E               AttributeError: cannot assign parameters before Module.__init__() call

.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1981: AttributeError
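The traceback comes from the guard in `nn.Module.__setattr__`: a `Parameter` was assigned before `super().__init__()` had created the module's `_parameters` registry. A torch-free sketch of the guard and the fix (`MiniModule` and `Parameter` are stand-ins for `torch.nn.Module` and `torch.nn.Parameter`):

```python
class Parameter:
    pass

class MiniModule:
    def __init__(self):
        # nn.Module.__init__ creates the parameter registry; assigning
        # a Parameter before it exists trips the guard below.
        object.__setattr__(self, "_parameters", {})

    def __setattr__(self, name, value):
        params = self.__dict__.get("_parameters")
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            params[name] = value
        else:
            object.__setattr__(self, name, value)

class Broken(MiniModule):
    def __init__(self):
        self.weight = Parameter()  # raises: registry not created yet
        super().__init__()

class Fixed(MiniModule):
    def __init__(self):
        super().__init__()         # registry exists first
        self.weight = Parameter()  # registered without error
```

The usual cause in a wrapped layer is exactly the `Broken` ordering: doing setup work (or calling a helper that assigns `self.weight`) before delegating to the parent constructor.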

Comment threads (5) on vllm_spyre_next/tests/test_vocab_parallel_embedding.py — Outdated
@GOavi101 GOavi101 force-pushed the embedding-support-spyre branch 4 times, most recently from ea1e192 to 470eb1f, March 16, 2026 08:26
@coderfornow coderfornow force-pushed the embedding-support-spyre branch 2 times, most recently from 72e0a17 to cce91f9, March 16, 2026 11:56
@GOavi101 GOavi101 force-pushed the embedding-support-spyre branch from cce91f9 to af2fcf4, March 16, 2026 12:14
…n and tests

Signed-off-by: coderfornow <ritikdhiranan@icloud.com>
@coderfornow coderfornow force-pushed the embedding-support-spyre branch from 5300594 to f76e565, March 16, 2026 12:25
@GOavi101 GOavi101 force-pushed the embedding-support-spyre branch 3 times, most recently from 2a18c7a to d009c04, March 17, 2026 07:20
@GOavi101

All tests passing for SpyreVocabParallelEmbedding:

SKIP_UPSTREAM_TESTS=1 python -m pytest tests/test_vocab_parallel_embedding.py -v

tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-256-1] PASSED                                      [  2%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-256-32] PASSED                                     [  5%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-256-63] PASSED                                     [  8%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-256-64] PASSED                                     [ 10%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-256-65] PASSED                                     [ 13%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-256-128] PASSED                                    [ 16%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-1024-1] PASSED                                     [ 18%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-1024-32] PASSED                                    [ 21%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-1024-63] PASSED                                    [ 24%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-1024-64] PASSED                                    [ 27%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-1024-65] PASSED                                    [ 29%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[64-1024-128] PASSED                                   [ 32%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-256-1] PASSED                                     [ 35%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-256-32] PASSED                                    [ 37%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-256-63] PASSED                                    [ 40%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-256-64] PASSED                                    [ 43%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-256-65] PASSED                                    [ 45%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-256-128] PASSED                                   [ 48%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-1024-1] PASSED                                    [ 51%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-1024-32] PASSED                                   [ 54%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-1024-63] PASSED                                   [ 56%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-1024-64] PASSED                                   [ 59%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-1024-65] PASSED                                   [ 62%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[128-1024-128] PASSED                                  [ 64%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-256-1] PASSED                                     [ 67%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-256-32] PASSED                                    [ 70%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-256-63] PASSED                                    [ 72%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-256-64] PASSED                                    [ 75%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-256-65] PASSED                                    [ 78%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-256-128] PASSED                                   [ 81%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-1024-1] PASSED                                    [ 83%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-1024-32] PASSED                                   [ 86%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-1024-63] PASSED                                   [ 89%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-1024-64] PASSED                                   [ 91%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-1024-65] PASSED                                   [ 94%]
tests/test_vocab_parallel_embedding.py::test_spyre_vocab_parallel_embedding_matches_reference[512-1024-128] PASSED                                  [ 97%]
tests/test_vocab_parallel_embedding.py::test_vocab_parallel_embedding_oot_dispatch PASSED                                                           [100%]
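The grid of test IDs above suggests a parametrized equivalence test along these lines. This is a sketch, not the actual test file: `spyre_embedding` is a hypothetical stand-in for the wrapped layer, and which position in the `[64-256-1]` IDs maps to which parameter is a guess.

```python
import pytest
import torch
import torch.nn.functional as F

# Sketch: the Spyre wrapper's lookup must match a plain F.embedding
# reference for every (dim, vocab, token-count) combination, including
# token counts straddling a block boundary (63/64/65).

def spyre_embedding(weight, token_ids):
    # Hypothetical stand-in for SpyreVocabParallelEmbedding's forward.
    return F.embedding(token_ids, weight)

@pytest.mark.parametrize("dim", [64, 128, 512])
@pytest.mark.parametrize("vocab", [256, 1024])
@pytest.mark.parametrize("n_tokens", [1, 32, 63, 64, 65, 128])
def test_matches_reference(dim, vocab, n_tokens):
    torch.manual_seed(0)
    weight = torch.randn(vocab, dim)
    ids = torch.randint(0, vocab, (n_tokens,))
    out = spyre_embedding(weight, ids)
    torch.testing.assert_close(out, weight[ids])
```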

@bohnstingl bohnstingl self-requested a review March 17, 2026 08:59

@bohnstingl bohnstingl left a comment


@coderfornow Thank you for the PR. I took a first look at it and made some comments. In addition, I have two general questions:

  1. Are there tests in vLLM upstream for the VocabParallelEmbedding layer? If so, is there a good reason to have separate tests, or can we reuse them from upstream? cc @joerunde
  2. Did you run the simple E2E test (https://github.com/vllm-project/vllm-spyre/blob/main/vllm_spyre_next/tests/test_vllm_spyre_next.py), and do you get the same outputs with and without your wrapper? This should be simple to check, since not much is happening in this layer.

Comment thread on vllm_spyre_next/vllm_spyre_next/platform.py — Outdated
Comment threads (5) on vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py — Outdated
@GOavi101 GOavi101 force-pushed the embedding-support-spyre branch from d009c04 to 903837a, March 17, 2026 10:28
…n and tests

Signed-off-by: coderfornow <ritikdhiranan@icloud.com>
Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
coderfornow and others added 4 commits March 20, 2026 15:13
…move static forward context, and consolidate custom op handling

Signed-off-by: coderfornow <ritikdhiranan@icloud.com>
…move static forward context, and consolidate custom op handling

Signed-off-by: coderfornow <ritikdhiranan@icloud.com>
…bedding-support-spyre

# Conflicts:
#	vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py
@coderfornow coderfornow requested a review from bohnstingl March 20, 2026 12:24
@coderfornow coderfornow linked an issue Mar 20, 2026 that may be closed by this pull request
@bohnstingl bohnstingl marked this pull request as ready for review March 24, 2026 09:44
@coderfornow coderfornow force-pushed the embedding-support-spyre branch 2 times, most recently from d5609a1 to 08c67f4, March 24, 2026 11:38

@bohnstingl bohnstingl left a comment


I left some minor comments for code quality and functionality reuse

Comment on lines +120 to +124
if input_.device.type != "cpu":
    raise NotImplementedError(
        f"Expected input on CPU, got device={input_.device}. "
        "Spyre has no native embedding kernel; input must stay on CPU."
    )

I think this check is not necessary. F.embedding would work with an input tensor on cpu and the dispatch machinery in torch-spyre will take care of doing the correct thing.
However, we should bring the tensor to spyre at this point using the convert function similar to the rms_norm example.
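The suggested shape of the forward, with the device check replaced by a transfer, might look like this sketch. `.to(weight.device)` is used here as an approximation of the convert helper from the rms_norm example; the function name is illustrative.

```python
import torch
import torch.nn.functional as F

def embedding_forward(weight: torch.Tensor, input_: torch.Tensor) -> torch.Tensor:
    # Sketch of the reviewed change: instead of rejecting non-CPU
    # input, move it to the weight's device and let torch-spyre's
    # dispatch (or its CPU fallback) handle aten.embedding.
    # `.to(weight.device)` stands in for the convert helper used in
    # the rms_norm example.
    input_ = input_.to(weight.device)
    return F.embedding(input_, weight)
```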

Comment on lines +102 to +119
"""Embedding execution on CPU.

F.embedding runs on CPU via torch-spyre's spyre__embedding fallback.
torch-spyre has no native embedding kernel; aten.embedding.default falls
back to CPU (spyre__embedding). Moving tensors to Spyre would cause a
DtException (unsupported data format), so no device transfer is performed.

No TP masking or all_reduce is performed (tp_size > 1 is not supported).

Args:
    input_: Token index tensor [num_tokens] on CPU (int64)

Returns:
    Embedding output [num_tokens, embedding_dim] in weight dtype on CPU

Raises:
    NotImplementedError: If input tensor is not on CPU.
"""

I feel that we should update this description a bit. From the vLLM perspective, the embedding is running on spyre, i.e., it should bring the input tensor to spyre and then torch-spyre will perform the operation. We can leave a note that currently torch-spyre has a fallback for the embedding to cpu and we could also leave the link to the corresponding issue: torch-spyre/torch-spyre#420
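A revised docstring along the lines the reviewer suggests might read as follows. The wording below is a proposal for illustration, not the merged text.

```python
# Proposed docstring text, framed from the vLLM perspective
# (embedding "runs on Spyre"), with the CPU fallback as a note.
DOCSTRING = """Embedding execution on Spyre.

From the vLLM perspective the embedding runs on Spyre: the input
tensor is moved to the Spyre device and torch-spyre performs the
lookup. Note: torch-spyre currently has no native embedding kernel
and falls back to CPU for aten.embedding.default; see
torch-spyre/torch-spyre#420.

Args:
    input_: Token index tensor [num_tokens] (int64)

Returns:
    Embedding output [num_tokens, embedding_dim] in weight dtype
"""
```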

Collaborator Author

Agreed

@coderfornow coderfornow force-pushed the embedding-support-spyre branch from 08c67f4 to 401ad79, March 24, 2026 15:56
  Signed-off-by: coderfornow <ritikdhiranan@icloud.com>

Signed-off-by: coderfornow <ritikdhiranan@icloud.com>
…ling, and update Spyre fallback logic for indirect indexing

Signed-off-by: coderfornow <ritikdhiranan@icloud.com>
@coderfornow coderfornow force-pushed the embedding-support-spyre branch from ca29458 to 990ff00, March 24, 2026 16:11

@bohnstingl bohnstingl left a comment


LGTM! Can you confirm that the end-to-end example with a Granite3.3-8B model and only the VocabParallelEmbedding layer wrapper for spyre works and produces meaningful tokens?

@bohnstingl

@joerunde, is the readthedocs CI fail blocking?

@bohnstingl bohnstingl requested a review from joerunde March 24, 2026 23:00
@tjohnson31415

tjohnson31415 commented Mar 25, 2026

is the readthedocs CI fail blocking?

No, not blocking. It has been fixed on main: #860

Joe is OOO, but I can help get this merged once we have the confirmation with Granite 3.3 8b.

Please also update the PR description to match the current content of the PR.

@coderfornow

I've tested with Granite 3.3 8b. Here's the output:

INFO 03-27 10:12:42 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.1 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
Processed prompts: 100%|█| 1/1 [01:17<00:00, 77.64s/it, est. speed input: 0.10 toks/s, output: 0.06

Generated text: '\n\nIBM operates'

@bohnstingl

The PR looks good to me in general. It has been observed though that the current way of wrapping for torch-spyre interferes with the enablement of upstream vLLM tests, see #863. To address this, I've opened a PR (#872) that reworks the forward call chain a bit and uses forward_oot instead of forward_native. Maybe we could hold off the merge a bit and get #872 merged first and then apply the rework directly here as well?

@coderfornow what do you think?
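The `forward_oot`-based call chain from #872 can be sketched roughly as follows. The names mirror vLLM's `CustomOp` convention, but this stand-alone class is an illustration under that assumption, not the project's code.

```python
class CustomOp:
    """Illustrative stand-in for the reworked dispatch in #872."""
    use_oot = True  # set by the platform plugin in the real code

    def forward(self, x):
        # Route through forward_oot when an OOT platform (Spyre) is
        # active; forward_native stays the upstream reference path,
        # so upstream vLLM tests can still exercise it directly.
        if self.use_oot:
            return self.forward_oot(x)
        return self.forward_native(x)

    def forward_native(self, x):
        return ("native", x)

    def forward_oot(self, x):
        # Default: fall back to the native implementation.
        return self.forward_native(x)

class SpyreEmbedding(CustomOp):
    def forward_oot(self, x):
        return ("spyre", x)
```

Keeping the override in `forward_oot` rather than `forward_native` is what avoids the interference with upstream test enablement described in #863: the native path remains intact for comparison.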

@bohnstingl

@coderfornow #872 has landed. Could you please adopt the modified forward call structure? I will then push for a quick merge.

coderfornow and others added 2 commits March 31, 2026 15:00
… forward dispatch, and clarify separate compilation for Spyre-specific kernels

Signed-off-by: coderfornow <ritikdhiranan@icloud.com>
Comment threads (2) on vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py — Outdated
…ng during device transfers

Signed-off-by: coderfornow <ritikdhiranan@icloud.com>

@bohnstingl bohnstingl left a comment


LGTM, thanks for merging the new call chain in.

@coderfornow

[screenshot attachment]

@bohnstingl bohnstingl merged commit a4e0935 into torch-spyre:main Mar 31, 2026
13 checks passed


Development

Successfully merging this pull request may close these issues.

Wrap the embedding layer

6 participants