From dae43ac202dcb3336a81f1cc65a84d08d0523cb8 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Tue, 31 Mar 2026 09:11:11 +0200 Subject: [PATCH 01/10] remove redundant comment Signed-off-by: Yannick Schnider --- vllm_spyre_next/vllm_spyre_next/platform.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_spyre_next/vllm_spyre_next/platform.py b/vllm_spyre_next/vllm_spyre_next/platform.py index 024d225e7..7c000e110 100644 --- a/vllm_spyre_next/vllm_spyre_next/platform.py +++ b/vllm_spyre_next/vllm_spyre_next/platform.py @@ -81,10 +81,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config if parallel_config.worker_cls == "auto": # "auto" defaults to the CPUWorker as we inherit from the CpuPlatform - # from vllm_spyre_next.v1.worker.spyre_worker import TorchSpyreWorker + # Override with TorchSpyreWorker for Spyre-specific functionality worker_class = "vllm_spyre_next.v1.worker.spyre_worker.TorchSpyreWorker" - # if a torch spyre specific worker class is needed it can be loaded with - # worker_class = "vllm_spyre_next.v1.worker.spyre_worker.TorchSpyreWorker" logger.info("Loading worker from: %s", worker_class) parallel_config.worker_cls = worker_class From 9a14e5114c96ae767dd81c3934a6fd1983483bf7 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Tue, 31 Mar 2026 09:11:55 +0200 Subject: [PATCH 02/10] remove old comments about transpose Signed-off-by: Yannick Schnider --- vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py index 4ba164b2d..1026ef930 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py @@ -18,7 +18,6 @@ - Minimum batch size: 64 (due to spyre constraint, automatically padded) - Device dtype: float16 (converted for CPU) - Output dtype: bfloat16 (converted on CPU) - - Algorithm: Transpose-based computation with torch.ops.spyre.full() Limitations: Currently the implementation in `forward_spyre` is similar to the @@ -115,14 +114,12 @@ def forward_spyre( residual: torch.Tensor | None = None, variance_size_override: int | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - """Spyre-optimized RMS norm using transpose-based computation (active implementation). + """Spyre-optimized RMS norm implementation. Based on upstream vLLM's forward_static (vllm/model_executor/layers/layernorm.py) - but adapted for Spyre device with transpose operations and torch.ops.spyre.full(). - Compiled separately via torch.compile in __init__. + but adapted for Spyre device. Compiled separately via torch.compile in __init__. Key differences from upstream: - - Uses transpose(-1, -2) for computation efficiency on Spyre - Creates epsilon tensor via torch.ops.spyre.full() instead of scalar - No dtype promotion support (torch-spyre limitation) """ From c07bf2a33a2033e8c011618b4cb7cf82c135f7e4 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Tue, 31 Mar 2026 09:12:15 +0200 Subject: [PATCH 03/10] fix typo Signed-off-by: Yannick Schnider --- vllm_spyre_next/examples/torch_spyre_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_spyre_next/examples/torch_spyre_inference.py b/vllm_spyre_next/examples/torch_spyre_inference.py index 47524aa28..cf129bad8 100644 --- a/vllm_spyre_next/examples/torch_spyre_inference.py +++ b/vllm_spyre_next/examples/torch_spyre_inference.py @@ -99,7 +99,7 @@ def main(): t0 = time.time() outputs = llm.generate(prompts, sampling_params) print( - "Time elaspsed for %d tokens is %.2f sec" + "Time elapsed for %d tokens is %.2f sec" % (len(outputs[0].outputs[0].token_ids), time.time() - t0) ) print("===============") From 40dd65684e62c7964dc2264d65c18376b71702d9 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Tue, 31 Mar 2026 09:12:31 +0200 Subject: [PATCH 04/10] fix typing Signed-off-by: Yannick Schnider --- vllm_spyre/__init__.py | 2 +- vllm_spyre_next/vllm_spyre_next/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_spyre/__init__.py b/vllm_spyre/__init__.py index 480012d29..ab99eb905 100644 --- a/vllm_spyre/__init__.py +++ b/vllm_spyre/__init__.py @@ -16,7 +16,7 @@ def register(): def _init_logging(): """Setup logging, extending from the vLLM logging config""" - config = dict[str, Any]() + config: dict[str, Any] = {} if VLLM_CONFIGURE_LOGGING: config = {**DEFAULT_LOGGING_CONFIG} diff --git a/vllm_spyre_next/vllm_spyre_next/__init__.py b/vllm_spyre_next/vllm_spyre_next/__init__.py index 5fda4cc0a..d00bca606 100644 --- a/vllm_spyre_next/vllm_spyre_next/__init__.py +++ b/vllm_spyre_next/vllm_spyre_next/__init__.py @@ -23,7 +23,7 @@ def register_ops(): def _init_logging(): """Setup logging, extending from the vLLM logging config""" - config = dict[str, Any]() + config: dict[str, Any] = {} if VLLM_CONFIGURE_LOGGING: config = {**DEFAULT_LOGGING_CONFIG} From a0257c78ff22676a0780d26733c641cbc16e7b2a Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Tue, 31 Mar 2026 19:46:19 +0000 Subject: [PATCH 05/10] Minor docstring updates Signed-off-by: Thomas Ortner --- vllm_spyre_next/vllm_spyre_next/__init__.py | 2 +- vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm_spyre_next/vllm_spyre_next/__init__.py b/vllm_spyre_next/vllm_spyre_next/__init__.py index 5fda4cc0a..ec09b3e7a 100644 --- a/vllm_spyre_next/vllm_spyre_next/__init__.py +++ b/vllm_spyre_next/vllm_spyre_next/__init__.py @@ -29,7 +29,7 @@ def _init_logging(): config = {**DEFAULT_LOGGING_CONFIG} if VLLM_LOGGING_CONFIG_PATH: - # Error checks must be done already in vllm.logger.py + # Error checks must already be done in vllm.logger with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file: config = json.loads(file.read()) diff --git a/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py b/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py index 52c44bbf3..eea6b7fe4 100644 --- a/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py +++ b/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py @@ -160,7 +160,7 @@ def _get_paths_from_yaml() -> str: def _cache_root() -> Path: """ - Cache directory for cloned tests (sticky between runs) + Cache directory for cloned tests (persists across runs) """ # Respect XDG if present, fallback to ~/.cache xdg = os.environ.get("XDG_CACHE_HOME") @@ -171,7 +171,8 @@ def _cache_root() -> Path: def _extract_vllm_commit_from_pyproject() -> str: """ Extract the vLLM git reference from pyproject.toml [tool.uv.sources] section. - Returns None if not found or parseable. + Raises FileNotFoundError if pyproject.toml is missing, or KeyError + if the expected source entry is not found. """ pyproject_path = Path(__file__).parent.parent.parent / "pyproject.toml" if not pyproject_path.exists(): @@ -498,7 +499,7 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item elif allow_entry.mode == "xfail_strict": item.add_marker(pytest.mark.xfail(strict=True)) - # Reorder tests so that tests with "model" in the name run first + # Reorder tests so that tests with "uses_subprocess" marker run first _reorder_tests_by_name(items) @@ -582,7 +583,7 @@ def _spyre_default_vllm_config(monkeypatch): compilation_config=CompilationConfig(custom_ops=["all"]), ) with set_current_vllm_config(config), set_forward_context(None, config): - # Set forward context so custom ops can access no_compile_layers + # Set forward context so custom ops can access the vllm config yield From bc69351f2eaf284b89384524874541bcf74f1096 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Wed, 1 Apr 2026 00:55:10 +0200 Subject: [PATCH 06/10] remove torch.ops.spyre.full() reference Signed-off-by: Yannick Schnider --- vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py index 1026ef930..f32e3d0e6 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py @@ -120,7 +120,7 @@ def forward_spyre( but adapted for Spyre device. Compiled separately via torch.compile in __init__. Key differences from upstream: - - Creates epsilon tensor via torch.ops.spyre.full() instead of scalar + - Creates epsilon tensor via torch.full() instead of scalar - No dtype promotion support (torch-spyre limitation) """ if residual is not None: From 1c5204adc2a6cd61c9fb5a611d40ec05aefb674c Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Wed, 1 Apr 2026 00:56:54 +0200 Subject: [PATCH 07/10] remove unused argument Signed-off-by: Yannick Schnider --- vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py index f32e3d0e6..d02ddb07f 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py @@ -112,7 +112,6 @@ def forward_spyre( hidden_size: int, weight: torch.Tensor | None = None, residual: torch.Tensor | None = None, - variance_size_override: int | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Spyre-optimized RMS norm implementation. From 7e6259092d7751efbea408177825c6fd2bbbe765 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Wed, 1 Apr 2026 13:05:40 +0000 Subject: [PATCH 08/10] Various docstring corrections Signed-off-by: Thomas Ortner --- .../vllm_spyre_next/custom_ops/rms_norm.py | 18 +++++++++--------- .../vllm_spyre_next/custom_ops/silu_and_mul.py | 5 +++-- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py index 4ba164b2d..714cfec91 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py @@ -16,9 +16,10 @@ Spyre Device Constraints: - Minimum batch size: 64 (due to spyre constraint, automatically padded) - - Device dtype: float16 (converted for CPU) - - Output dtype: bfloat16 (converted on CPU) - - Algorithm: Transpose-based computation with torch.ops.spyre.full() + - Computations performed in torch.float16: + Input (dtype defined by model / user) converted to torch.float16 for + operations on spyre and then converted back to original dtype for cpu. + - Epsilon as tensor: Instead of a scalar, a tensor is created via torch.full() Limitations: Currently the implementation in `forward_spyre` is similar to the @@ -118,13 +119,12 @@ def forward_spyre( """Spyre-optimized RMS norm using transpose-based computation (active implementation). Based on upstream vLLM's forward_static (vllm/model_executor/layers/layernorm.py) - but adapted for Spyre device with transpose operations and torch.ops.spyre.full(). + but adapted for Spyre device with epsilon being a tensor with torch.full(). Compiled separately via torch.compile in __init__. Key differences from upstream: - - Uses transpose(-1, -2) for computation efficiency on Spyre - - Creates epsilon tensor via torch.ops.spyre.full() instead of scalar - - No dtype promotion support (torch-spyre limitation) + - Creates epsilon tensor via torch.full() instead of scalar + - No dtype promotion support to torch.float32 (torch-spyre limitation) """ if residual is not None: x = x + residual @@ -159,7 +159,7 @@ def _forward_spyre_impl( 1. Minimum batch size: Pads to 64 if needed 2. Device transfer: CPU -> Spyre convert to float16 3. Kernel execution: Calls compiled maybe_compiled_forward_spyre - 4. Result transfer: Spyre -> CPU, trim padding, convert to bfloat16 + 4. Result transfer: Spyre -> CPU, trim padding, convert to input dtype Limitations: - variance_size_override not implemented (raises NotImplementedError) @@ -169,7 +169,7 @@ def _forward_spyre_impl( residual: Optional residual Returns: - Normalized output [batch_size, hidden_size] in bfloat16 + Normalized output [batch_size, hidden_size] in input dtype """ x_dtype = x.dtype x_device = x.device diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py index 2774e8dfa..a86b22d71 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py @@ -15,8 +15,9 @@ - Separate Compilation: forward_spyre is compiled independently via maybe_compile Spyre Device Constraints: - - Device dtype: float16 (via convert_for_spyre) - - Output dtype: matches input dtype (converted on CPU) + - Computations performed in torch.float16: + Input (dtype defined by model / user) converted to torch.float16 for + operations on spyre and then converted back to original dtype for cpu. Output Shape Note: Unlike RMSNorm (same input/output shape), SiluAndMul halves the last dimension: From bff1d4a411675ac4f3a154b87d7a8358ef8b5f98 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Wed, 8 Apr 2026 22:08:43 +0000 Subject: [PATCH 09/10] Added debug statement to silu and embedding Signed-off-by: Thomas Ortner --- vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py | 7 +++++++ .../vllm_spyre_next/custom_ops/vocab_parallel_embedding.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py index a86b22d71..190547a3e 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py @@ -66,6 +66,13 @@ def __init__(self, *args, **kwargs): self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre) self._layer_name = register_layer(self, "spyre_siluandmul") + + logger.debug_once( + "SpyreSiluAndMul: Dispatch: enabled=%s, Forward method=%s, Compiled=%s", + self.enabled(), + self._forward_method.__name__, + self.maybe_compiled_forward_spyre is not self.forward_spyre, + ) def forward_oot(self, x: torch.Tensor) -> torch.Tensor: """OOT forward pass using custom op to bypass torch.compile. diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py index 183159331..30ce05af8 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py @@ -94,6 +94,13 @@ def __init__(self, *args, **kwargs): self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre) self._layer_name = register_layer(self, "spyre_vocab_parallel_embedding") + + logger.debug_once( + "SpyreVocabParallelEmbedding: Dispatch: enabled=%s, Forward method=%s, Compiled=%s", + self.enabled(), + self._forward_method.__name__, + self.maybe_compiled_forward_spyre is not self.forward_spyre, + ) def forward_oot(self, x: torch.Tensor) -> torch.Tensor: """OOT forward pass using custom op to bypass torch.compile. From c43484a10c95cb6f54ea0b7b407e28496b991569 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Thu, 9 Apr 2026 10:40:59 +0200 Subject: [PATCH 10/10] fix ruff Signed-off-by: Yannick Schnider --- vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py | 2 +- .../vllm_spyre_next/custom_ops/vocab_parallel_embedding.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py index 190547a3e..ffaa8a63b 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py @@ -66,7 +66,7 @@ def __init__(self, *args, **kwargs): self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre) self._layer_name = register_layer(self, "spyre_siluandmul") - + logger.debug_once( "SpyreSiluAndMul: Dispatch: enabled=%s, Forward method=%s, Compiled=%s", self.enabled(), diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py index 30ce05af8..167ab04f3 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py @@ -94,7 +94,7 @@ def __init__(self, *args, **kwargs): self.maybe_compiled_forward_spyre = self.maybe_compile(self.forward_spyre) self._layer_name = register_layer(self, "spyre_vocab_parallel_embedding") - + logger.debug_once( "SpyreVocabParallelEmbedding: Dispatch: enabled=%s, Forward method=%s, Compiled=%s", self.enabled(),