Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm_spyre/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def register():

def _init_logging():
"""Setup logging, extending from the vLLM logging config"""
config = dict[str, Any]()
config: dict[str, Any] = {}

if VLLM_CONFIGURE_LOGGING:
config = {**DEFAULT_LOGGING_CONFIG}
Expand Down
2 changes: 1 addition & 1 deletion vllm_spyre_next/examples/torch_spyre_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def main():
t0 = time.time()
outputs = llm.generate(prompts, sampling_params)
print(
"Time elaspsed for %d tokens is %.2f sec"
"Time elapsed for %d tokens is %.2f sec"
% (len(outputs[0].outputs[0].token_ids), time.time() - t0)
)
print("===============")
Expand Down
4 changes: 2 additions & 2 deletions vllm_spyre_next/vllm_spyre_next/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ def register_ops():

def _init_logging():
"""Setup logging, extending from the vLLM logging config"""
config = dict[str, Any]()
config: dict[str, Any] = {}

if VLLM_CONFIGURE_LOGGING:
config = {**DEFAULT_LOGGING_CONFIG}

if VLLM_LOGGING_CONFIG_PATH:
# Error checks must be done already in vllm.logger.py
# Error checks must already be done in vllm.logger
with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
config = json.loads(file.read())

Expand Down
22 changes: 10 additions & 12 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

Spyre Device Constraints:
- Minimum batch size: 64 (due to spyre constraint, automatically padded)
- Device dtype: float16 (converted for CPU)
- Output dtype: bfloat16 (converted on CPU)
- Algorithm: Transpose-based computation with torch.ops.spyre.full()
- Computations are performed in torch.float16:
The input (dtype defined by the model / user) is converted to torch.float16 for
operations on Spyre and then converted back to the original dtype for the CPU.
- Epsilon as tensor: Instead of a scalar, a tensor is created via torch.full()

Limitations:
Currently the implementation in `forward_spyre` is similar to the
Expand Down Expand Up @@ -113,18 +114,15 @@ def forward_spyre(
hidden_size: int,
weight: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
variance_size_override: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Spyre-optimized RMS norm using transpose-based computation (active implementation).
"""Spyre-optimized RMS norm implementation.

Based on upstream vLLM's forward_static (vllm/model_executor/layers/layernorm.py)
but adapted for Spyre device with transpose operations and torch.ops.spyre.full().
Compiled separately via torch.compile in __init__.
but adapted for Spyre device. Compiled separately via torch.compile in __init__.

Key differences from upstream:
- Uses transpose(-1, -2) for computation efficiency on Spyre
- Creates epsilon tensor via torch.ops.spyre.full() instead of scalar
- No dtype promotion support (torch-spyre limitation)
- Creates epsilon tensor via torch.full() instead of scalar
- No dtype promotion support to torch.float32 (torch-spyre limitation)
"""
if residual is not None:
x = x + residual
Expand Down Expand Up @@ -159,7 +157,7 @@ def _forward_spyre_impl(
1. Minimum batch size: Pads to 64 if needed
2. Device transfer: CPU -> Spyre, convert to float16
3. Kernel execution: Calls compiled maybe_compiled_forward_spyre
4. Result transfer: Spyre -> CPU, trim padding, convert to bfloat16
4. Result transfer: Spyre -> CPU, trim padding, convert to input dtype

Limitations:
- variance_size_override not implemented (raises NotImplementedError)
Expand All @@ -169,7 +167,7 @@ def _forward_spyre_impl(
residual: Optional residual

Returns:
Normalized output [batch_size, hidden_size] in bfloat16
Normalized output [batch_size, hidden_size] in input dtype
"""
x_dtype = x.dtype
x_device = x.device
Expand Down
12 changes: 10 additions & 2 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
- Separate Compilation: forward_spyre is compiled independently via maybe_compile

Spyre Device Constraints:
- Device dtype: float16 (via convert_for_spyre)
- Output dtype: matches input dtype (converted on CPU)
- Computations are performed in torch.float16:
The input (dtype defined by the model / user) is converted to torch.float16 for
operations on Spyre and then converted back to the original dtype for the CPU.

Output Shape Note:
Unlike RMSNorm (same input/output shape), SiluAndMul halves the last dimension:
Expand Down Expand Up @@ -66,6 +67,13 @@ def __init__(self, *args, **kwargs):

self._layer_name = register_layer(self, "spyre_siluandmul")

logger.debug_once(
"SpyreSiluAndMul: Dispatch: enabled=%s, Forward method=%s, Compiled=%s",
self.enabled(),
self._forward_method.__name__,
self.maybe_compiled_forward_spyre is not self.forward_spyre,
)

def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
"""OOT forward pass using custom op to bypass torch.compile.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ def __init__(self, *args, **kwargs):

self._layer_name = register_layer(self, "spyre_vocab_parallel_embedding")

logger.debug_once(
"SpyreVocabParallelEmbedding: Dispatch: enabled=%s, Forward method=%s, Compiled=%s",
self.enabled(),
self._forward_method.__name__,
self.maybe_compiled_forward_spyre is not self.forward_spyre,
)

def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
"""OOT forward pass using custom op to bypass torch.compile.

Expand Down
4 changes: 1 addition & 3 deletions vllm_spyre_next/vllm_spyre_next/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
# "auto" defaults to the CPUWorker as we inherit from the CpuPlatform
# from vllm_spyre_next.v1.worker.spyre_worker import TorchSpyreWorker
# Override with TorchSpyreWorker for Spyre-specific functionality
worker_class = "vllm_spyre_next.v1.worker.spyre_worker.TorchSpyreWorker"
# if a torch spyre specific worker class is needed it can be loaded with
# worker_class = "vllm_spyre_next.v1.worker.spyre_worker.TorchSpyreWorker"
logger.info("Loading worker from: %s", worker_class)
parallel_config.worker_cls = worker_class

Expand Down
9 changes: 5 additions & 4 deletions vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _get_paths_from_yaml() -> str:

def _cache_root() -> Path:
"""
Cache directory for cloned tests (sticky between runs)
Cache directory for cloned tests (persists across runs)
"""
# Respect XDG if present, fallback to ~/.cache
xdg = os.environ.get("XDG_CACHE_HOME")
Expand All @@ -171,7 +171,8 @@ def _cache_root() -> Path:
def _extract_vllm_commit_from_pyproject() -> str:
"""
Extract the vLLM git reference from pyproject.toml [tool.uv.sources] section.
Returns None if not found or parseable.
Raises FileNotFoundError if pyproject.toml is missing, or KeyError
if the expected source entry is not found.
"""
pyproject_path = Path(__file__).parent.parent.parent / "pyproject.toml"
if not pyproject_path.exists():
Expand Down Expand Up @@ -498,7 +499,7 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item
elif allow_entry.mode == "xfail_strict":
item.add_marker(pytest.mark.xfail(strict=True))

# Reorder tests so that tests with "model" in the name run first
# Reorder tests so that tests with "uses_subprocess" marker run first
_reorder_tests_by_name(items)


Expand Down Expand Up @@ -582,7 +583,7 @@ def _spyre_default_vllm_config(monkeypatch):
compilation_config=CompilationConfig(custom_ops=["all"]),
)
with set_current_vllm_config(config), set_forward_context(None, config):
# Set forward context so custom ops can access no_compile_layers
# Set forward context so custom ops can access the vllm config
yield


Expand Down
Loading