diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py b/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py index c7574c1c854..c3bea7dd1c4 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py @@ -681,13 +681,8 @@ def forward( neg_noise_pred = neg_noise_pred[:, : latents.size(1)] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py index cc25c6b7043..00d3288501b 100644 --- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py +++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py @@ -1062,12 +1062,8 @@ def forward( noise_pred = noise_pred[:, : latents.size(1) :] # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - if latents.dtype != latents_dtype and torch.backends.mps.is_available(): - latents = latents.to(latents_dtype) - if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: diff --git a/vllm_omni/diffusion/models/mammoth_moda2/rope_real.py b/vllm_omni/diffusion/models/mammoth_moda2/rope_real.py index d16181a6913..64cc4324869 100644 --- a/vllm_omni/diffusion/models/mammoth_moda2/rope_real.py +++ b/vllm_omni/diffusion/models/mammoth_moda2/rope_real.py @@ -18,6 +18,8 @@ from einops import repeat from torch import nn +from vllm_omni.platforms import current_omni_platform + def apply_real_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor: """ @@ -119,7 +121,7 @@ def get_freqs_real( axes_dim: tuple[int, int, int], axes_lens: tuple[int, int, int], theta: int ) -> list[tuple[torch.Tensor, torch.Tensor]]: freqs_real = [] - freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + freqs_dtype = torch.float64 if current_omni_platform.supports_float64() else torch.float32 for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): cos_emb, sin_emb = get_1d_rotary_pos_embed_real(d, e, theta=theta, freqs_dtype=freqs_dtype) freqs_real.append((cos_emb, sin_emb)) diff --git a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py index b626ca1d85b..9ff681a3c0b 100644 --- a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py +++ b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py @@ -19,6 +19,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) @@ -411,7 +412,7 @@ def get_freqs_cis( axes_dim: tuple[int, int, int], axes_lens: tuple[int, int, int], theta: int ) -> list[torch.Tensor]: freqs_cis = [] - freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + freqs_dtype = torch.float64 if current_omni_platform.supports_float64() else torch.float32 for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype) freqs_cis.append(emb) diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index c4e3b40cdd5..cf1761123d7 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -29,6 +29,7 @@ SequenceParallelOutput, ) from vllm_omni.diffusion.forward_context import get_forward_context +from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__) @@ -171,7 +172,7 @@ def __init__( # Split dimensions for temporal, height, width h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim - freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + freqs_dtype = torch.float64 if current_omni_platform.supports_float64() else torch.float32 freqs_cos = [] freqs_sin = [] diff --git a/vllm_omni/platforms/interface.py b/vllm_omni/platforms/interface.py index 4325851e5fb..4df297fa021 100644 --- a/vllm_omni/platforms/interface.py +++ b/vllm_omni/platforms/interface.py @@ -117,6 +117,10 @@ def get_free_memory(cls, device: torch.device | None = None) -> int: def supports_cpu_offload(cls) -> bool: return True + @classmethod + def supports_float64(cls) -> bool: + return True + @classmethod def set_device_control_env_var(cls, devices: str | int | None) -> None: import os diff --git a/vllm_omni/platforms/musa/platform.py b/vllm_omni/platforms/musa/platform.py index 932ce62d27e..3bd520c61b9 100644 --- a/vllm_omni/platforms/musa/platform.py +++ b/vllm_omni/platforms/musa/platform.py @@ -81,6 +81,11 @@ def supports_torch_inductor(cls) -> bool: """MUSA supports torch.compile with inductor backend.""" return True + @classmethod + def supports_float64(cls) -> bool: + """MUSA does not support float64 yet.""" + return False + @classmethod def get_torch_device(cls, local_rank: int | None = None) -> torch.device: """Get the torch device for MUSA platform.