12 changes: 7 additions & 5 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py
@@ -16,8 +16,10 @@

Spyre Device Constraints:
- Minimum batch size: 64 (Spyre constraint; inputs are automatically padded)
- Device dtype: float16 (converted for CPU)
- Output dtype: bfloat16 (converted on CPU)
- Computations performed in torch.float16:
The input (dtype defined by the model/user) is converted to torch.float16
for operations on Spyre and then converted back to its original dtype on CPU.
- Epsilon as tensor: Instead of a scalar, a tensor is created via torch.full()

Limitations:
Currently the implementation in `forward_spyre` is similar to the
@@ -120,7 +122,7 @@ def forward_spyre(

Key differences from upstream:
- Creates epsilon tensor via torch.full() instead of scalar
- No dtype promotion support (torch-spyre limitation)
- No dtype promotion support to torch.float32 (torch-spyre limitation)
"""
if residual is not None:
x = x + residual
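The behavior described above can be illustrated with a minimal CPU-only sketch. This is a hypothetical stand-in for the compiled kernel, not the real implementation: the epsilon is built as a tensor via torch.full() rather than passed as a scalar, the computation stays in the input dtype with no promotion to torch.float32, and the optional residual add mirrors the snippet above.

```python
import torch

def rms_norm_sketch(x, weight, eps=1e-6, residual=None):
    """Illustrative sketch of the forward_spyre flow (not the actual kernel)."""
    # Optional residual add, as in the snippet above
    if residual is not None:
        x = x + residual
    # Epsilon as a tensor via torch.full() instead of a python scalar
    eps_t = torch.full((1,), eps, dtype=x.dtype)
    # Variance computed in the input dtype -- no promotion to float32
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps_t) * weight
    # When a residual is given, return (output, updated residual) like vLLM's RMSNorm
    return (out, x) if residual is not None else out
```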
@@ -155,7 +157,7 @@ def _forward_spyre_impl(
1. Minimum batch size: Pads to 64 if needed
2. Device transfer: CPU -> Spyre, converting to float16
3. Kernel execution: Calls compiled maybe_compiled_forward_spyre
4. Result transfer: Spyre -> CPU, trim padding, convert to bfloat16
4. Result transfer: Spyre -> CPU, trim padding, convert to input dtype

Limitations:
- variance_size_override not implemented (raises NotImplementedError)
@@ -165,7 +167,7 @@
residual: Optional residual

Returns:
Normalized output [batch_size, hidden_size] in bfloat16
Normalized output [batch_size, hidden_size] in input dtype
"""
x_dtype = x.dtype
x_device = x.device
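The four-step flow above (pad, transfer/convert, run, trim and restore dtype) can be sketched on CPU as follows. `run_padded_sketch` and `MIN_BATCH` are illustrative names rather than the real helpers, and the Spyre device transfer is stood in by a plain float16 cast:

```python
import torch

MIN_BATCH = 64  # Spyre minimum batch size (see constraints above)

def run_padded_sketch(x, kernel):
    """CPU-only sketch of the pad -> convert -> run -> trim pipeline."""
    x_dtype = x.dtype        # capture input dtype before any conversion
    batch = x.shape[0]
    if batch < MIN_BATCH:    # 1. pad to the minimum batch size
        pad = torch.zeros(MIN_BATCH - batch, *x.shape[1:], dtype=x.dtype)
        x = torch.cat([x, pad], dim=0)
    x = x.to(torch.float16)  # 2. stand-in for the CPU -> Spyre transfer
    out = kernel(x)          # 3. stand-in for the compiled kernel
    return out[:batch].to(x_dtype)  # 4. trim padding, restore input dtype
```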
5 changes: 3 additions & 2 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py
@@ -15,8 +15,9 @@
- Separate Compilation: forward_spyre is compiled independently via maybe_compile

Spyre Device Constraints:
- Device dtype: float16 (via convert_for_spyre)
- Output dtype: matches input dtype (converted on CPU)
- Computations performed in torch.float16:
The input (dtype defined by the model/user) is converted to torch.float16
for operations on Spyre and then converted back to its original dtype on CPU.

Output Shape Note:
Unlike RMSNorm (same input/output shape), SiluAndMul halves the last dimension:
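The halving can be sketched in a few lines; `silu_and_mul_sketch` is an illustrative name, not the actual op, but it follows the standard SiluAndMul definition of silu(first half) * second half:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_sketch(x):
    """Sketch: output last dim is half the input last dim."""
    d = x.shape[-1] // 2
    # silu applied to the first half, multiplied elementwise by the second half
    return F.silu(x[..., :d]) * x[..., d:]
```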