diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py index d02ddb07f..82bcb434e 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py @@ -16,8 +16,10 @@ Spyre Device Constraints: - Minimum batch size: 64 (due to spyre constraint, automatically padded) - - Device dtype: float16 (converted for CPU) - - Output dtype: bfloat16 (converted on CPU) + - Computations performed in torch.float16: + Input (dtype defined by the model / user) is converted to torch.float16 for + operations on spyre and then converted back to the original dtype on the CPU. + - Epsilon as tensor: Instead of a scalar, a tensor is created via torch.full() Limitations: Currently the implementation in `forward_spyre` is similar to the @@ -120,7 +122,7 @@ def forward_spyre( Key differences from upstream: - Creates epsilon tensor via torch.full() instead of scalar - - No dtype promotion support (torch-spyre limitation) + - No dtype promotion support to torch.float32 (torch-spyre limitation) """ if residual is not None: x = x + residual @@ -155,7 +157,7 @@ def _forward_spyre_impl( 1. Minimum batch size: Pads to 64 if needed 2. Device transfer: CPU -> Spyre convert to float16 3. Kernel execution: Calls compiled maybe_compiled_forward_spyre - 4. Result transfer: Spyre -> CPU, trim padding, convert to bfloat16 + 4. Result transfer: Spyre -> CPU, trim padding, convert to the input dtype Limitations: - variance_size_override not implemented (raises NotImplementedError) @@ -165,7 +167,7 @@ def _forward_spyre_impl( residual: Optional residual Returns: - Normalized output [batch_size, hidden_size] in bfloat16 + Normalized output [batch_size, hidden_size] in the input dtype """ x_dtype = x.dtype x_device = x.device diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py index 2774e8dfa..a86b22d71 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py @@ -15,8 +15,9 @@ - Separate Compilation: forward_spyre is compiled independently via maybe_compile Spyre Device Constraints: - - Device dtype: float16 (via convert_for_spyre) - - Output dtype: matches input dtype (converted on CPU) + - Computations performed in torch.float16: + Input (dtype defined by the model / user) is converted to torch.float16 for + operations on spyre and then converted back to the original dtype on the CPU. Output Shape Note: Unlike RMSNorm (same input/output shape), SiluAndMul halves the last dimension: