Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,13 +681,8 @@ def forward(
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
latents = latents.to(latents_dtype)

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
Expand Down
4 changes: 0 additions & 4 deletions vllm_omni/diffusion/models/flux2/pipeline_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,12 +1062,8 @@ def forward(
noise_pred = noise_pred[:, : latents.size(1) :]

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

if latents.dtype != latents_dtype and torch.backends.mps.is_available():
latents = latents.to(latents_dtype)

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
Expand Down
4 changes: 3 additions & 1 deletion vllm_omni/diffusion/models/mammoth_moda2/rope_real.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from einops import repeat
from torch import nn

from vllm_omni.platforms import current_omni_platform


def apply_real_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -119,7 +121,7 @@ def get_freqs_real(
axes_dim: tuple[int, int, int], axes_lens: tuple[int, int, int], theta: int
) -> list[tuple[torch.Tensor, torch.Tensor]]:
freqs_real = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_dtype = torch.float64 if current_omni_platform.supports_float64() else torch.float32
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
cos_emb, sin_emb = get_1d_rotary_pos_embed_real(d, e, theta=theta, freqs_dtype=freqs_dtype)
freqs_real.append((cos_emb, sin_emb))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.platforms import current_omni_platform

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -411,7 +412,7 @@ def get_freqs_cis(
axes_dim: tuple[int, int, int], axes_lens: tuple[int, int, int], theta: int
) -> list[torch.Tensor]:
freqs_cis = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_dtype = torch.float64 if current_omni_platform.supports_float64() else torch.float32
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
freqs_cis.append(emb)
Expand Down
3 changes: 2 additions & 1 deletion vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
SequenceParallelOutput,
)
from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.platforms import current_omni_platform

logger = init_logger(__name__)

Expand Down Expand Up @@ -171,7 +172,7 @@ def __init__(
# Split dimensions for temporal, height, width
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_dtype = torch.float64 if current_omni_platform.supports_float64() else torch.float32

freqs_cos = []
freqs_sin = []
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ def get_free_memory(cls, device: torch.device | None = None) -> int:
def supports_cpu_offload(cls) -> bool:
return True

@classmethod
def supports_float64(cls) -> bool:
return True

@classmethod
def set_device_control_env_var(cls, devices: str | int | None) -> None:
import os
Expand Down
5 changes: 5 additions & 0 deletions vllm_omni/platforms/musa/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def supports_torch_inductor(cls) -> bool:
"""MUSA supports torch.compile with inductor backend."""
return True

@classmethod
def supports_float64(cls) -> bool:
"""MUSA does not support float64 yet."""
return False

@classmethod
def get_torch_device(cls, local_rank: int | None = None) -> torch.device:
"""Get the torch device for MUSA platform.
Expand Down
Loading