diff --git a/vllm_spyre/__init__.py b/vllm_spyre/__init__.py index 480012d29..ab99eb905 100644 --- a/vllm_spyre/__init__.py +++ b/vllm_spyre/__init__.py @@ -16,7 +16,7 @@ def register(): def _init_logging(): """Setup logging, extending from the vLLM logging config""" - config = dict[str, Any]() + config: dict[str, Any] = {} if VLLM_CONFIGURE_LOGGING: config = {**DEFAULT_LOGGING_CONFIG} diff --git a/vllm_spyre_next/examples/torch_spyre_inference.py b/vllm_spyre_next/examples/torch_spyre_inference.py index 47524aa28..cf129bad8 100644 --- a/vllm_spyre_next/examples/torch_spyre_inference.py +++ b/vllm_spyre_next/examples/torch_spyre_inference.py @@ -99,7 +99,7 @@ def main(): t0 = time.time() outputs = llm.generate(prompts, sampling_params) print( - "Time elaspsed for %d tokens is %.2f sec" + "Time elapsed for %d tokens is %.2f sec" % (len(outputs[0].outputs[0].token_ids), time.time() - t0) ) print("===============") diff --git a/vllm_spyre_next/vllm_spyre_next/__init__.py b/vllm_spyre_next/vllm_spyre_next/__init__.py index 5fda4cc0a..e221fe858 100644 --- a/vllm_spyre_next/vllm_spyre_next/__init__.py +++ b/vllm_spyre_next/vllm_spyre_next/__init__.py @@ -23,13 +23,13 @@ def register_ops(): def _init_logging(): """Setup logging, extending from the vLLM logging config""" - config = dict[str, Any]() + config: dict[str, Any] = {} if VLLM_CONFIGURE_LOGGING: config = {**DEFAULT_LOGGING_CONFIG} if VLLM_LOGGING_CONFIG_PATH: - # Error checks must be done already in vllm.logger.py + # Error checks must already be done in vllm.logger with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file: config = json.loads(file.read()) diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py index 4ba164b2d..82bcb434e 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/rms_norm.py @@ -16,9 +16,10 @@ Spyre Device Constraints: - Minimum batch size: 64 (due to spyre constraint, automatically padded) - - Device dtype: float16 (converted for CPU) - - Output dtype: bfloat16 (converted on CPU) - - Algorithm: Transpose-based computation with torch.ops.spyre.full() + - Computations performed in torch.float16: + Input (dtype defined by model / user) converted to torch.float16 for + operations on spyre and then converted back to original dtype for cpu. + - Epsilon as tensor: Instead of a scalar, a tensor is created via torch.full() Limitations: Currently the implementation in `forward_spyre` is similar to the @@ -113,18 +114,15 @@ def forward_spyre( hidden_size: int, weight: torch.Tensor | None = None, residual: torch.Tensor | None = None, - variance_size_override: int | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - """Spyre-optimized RMS norm using transpose-based computation (active implementation). + """Spyre-optimized RMS norm implementation. Based on upstream vLLM's forward_static (vllm/model_executor/layers/layernorm.py) - but adapted for Spyre device with transpose operations and torch.ops.spyre.full(). - Compiled separately via torch.compile in __init__. + but adapted for Spyre device. Compiled separately via torch.compile in __init__. Key differences from upstream: - - Uses transpose(-1, -2) for computation efficiency on Spyre - - Creates epsilon tensor via torch.ops.spyre.full() instead of scalar - - No dtype promotion support (torch-spyre limitation) + - Creates epsilon tensor via torch.full() instead of scalar + - No dtype promotion support to torch.float32 (torch-spyre limitation) """ if residual is not None: x = x + residual @@ -159,7 +157,7 @@ def _forward_spyre_impl( 1. Minimum batch size: Pads to 64 if needed 2. Device transfer: CPU -> Spyre convert to float16 3. Kernel execution: Calls compiled maybe_compiled_forward_spyre - 4. Result transfer: Spyre -> CPU, trim padding, convert to bfloat16 + 4. Result transfer: Spyre -> CPU, trim padding, convert to input dtype Limitations: - variance_size_override not implemented (raises NotImplementedError) @@ -169,7 +167,7 @@ def _forward_spyre_impl( residual: Optional residual Returns: - Normalized output [batch_size, hidden_size] in bfloat16 + Normalized output [batch_size, hidden_size] in input dtype """ x_dtype = x.dtype x_device = x.device diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py index 2774e8dfa..ffaa8a63b 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/silu_and_mul.py @@ -15,8 +15,9 @@ - Separate Compilation: forward_spyre is compiled independently via maybe_compile Spyre Device Constraints: - - Device dtype: float16 (via convert_for_spyre) - - Output dtype: matches input dtype (converted on CPU) + - Computations performed in torch.float16: + Input (dtype defined by model / user) converted to torch.float16 for + operations on spyre and then converted back to original dtype for cpu. Output Shape Note: Unlike RMSNorm (same input/output shape), SiluAndMul halves the last dimension: @@ -66,6 +67,13 @@ def __init__(self, *args, **kwargs): self._layer_name = register_layer(self, "spyre_siluandmul") + logger.debug_once( + "SpyreSiluAndMul: Dispatch: enabled=%s, Forward method=%s, Compiled=%s", + self.enabled(), + self._forward_method.__name__, + self.maybe_compiled_forward_spyre is not self.forward_spyre, + ) + def forward_oot(self, x: torch.Tensor) -> torch.Tensor: """OOT forward pass using custom op to bypass torch.compile. diff --git a/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py b/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py index 183159331..167ab04f3 100644 --- a/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py +++ b/vllm_spyre_next/vllm_spyre_next/custom_ops/vocab_parallel_embedding.py @@ -95,6 +95,13 @@ def __init__(self, *args, **kwargs): self._layer_name = register_layer(self, "spyre_vocab_parallel_embedding") + logger.debug_once( + "SpyreVocabParallelEmbedding: Dispatch: enabled=%s, Forward method=%s, Compiled=%s", + self.enabled(), + self._forward_method.__name__, + self.maybe_compiled_forward_spyre is not self.forward_spyre, + ) + def forward_oot(self, x: torch.Tensor) -> torch.Tensor: """OOT forward pass using custom op to bypass torch.compile. diff --git a/vllm_spyre_next/vllm_spyre_next/platform.py b/vllm_spyre_next/vllm_spyre_next/platform.py index 024d225e7..7c000e110 100644 --- a/vllm_spyre_next/vllm_spyre_next/platform.py +++ b/vllm_spyre_next/vllm_spyre_next/platform.py @@ -81,10 +81,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config if parallel_config.worker_cls == "auto": # "auto" defaults to the CPUWorker as we inherit from the CpuPlatform - # from vllm_spyre_next.v1.worker.spyre_worker import TorchSpyreWorker + # Override with TorchSpyreWorker for Spyre-specific functionality worker_class = "vllm_spyre_next.v1.worker.spyre_worker.TorchSpyreWorker" - # if a torch spyre specific worker class is needed it can be loaded with - # worker_class = "vllm_spyre_next.v1.worker.spyre_worker.TorchSpyreWorker" logger.info("Loading worker from: %s", worker_class) parallel_config.worker_cls = worker_class diff --git a/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py b/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py index 52c44bbf3..eea6b7fe4 100644 --- a/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py +++ b/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py @@ -160,7 +160,7 @@ def _get_paths_from_yaml() -> str: def _cache_root() -> Path: """ - Cache directory for cloned tests (sticky between runs) + Cache directory for cloned tests (persists across runs) """ # Respect XDG if present, fallback to ~/.cache xdg = os.environ.get("XDG_CACHE_HOME") @@ -171,7 +171,8 @@ def _cache_root() -> Path: def _extract_vllm_commit_from_pyproject() -> str: """ Extract the vLLM git reference from pyproject.toml [tool.uv.sources] section. - Returns None if not found or parseable. + Raises FileNotFoundError if pyproject.toml is missing, or KeyError + if the expected source entry is not found. """ pyproject_path = Path(__file__).parent.parent.parent / "pyproject.toml" if not pyproject_path.exists(): @@ -498,7 +499,7 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item elif allow_entry.mode == "xfail_strict": item.add_marker(pytest.mark.xfail(strict=True)) - # Reorder tests so that tests with "model" in the name run first + # Reorder tests so that tests with "uses_subprocess" marker run first _reorder_tests_by_name(items) @@ -582,7 +583,7 @@ def _spyre_default_vllm_config(monkeypatch): compilation_config=CompilationConfig(custom_ops=["all"]), ) with set_current_vllm_config(config), set_forward_context(None, config): - # Set forward context so custom ops can access no_compile_layers + # Set forward context so custom ops can access the vllm config yield