diff --git a/vllm_omni/diffusion/models/helios/scheduling_helios.py b/vllm_omni/diffusion/models/helios/scheduling_helios.py index 69bcaa1a2d..847dea2d0e 100644 --- a/vllm_omni/diffusion/models/helios/scheduling_helios.py +++ b/vllm_omni/diffusion/models/helios/scheduling_helios.py @@ -410,12 +410,13 @@ def multistep_uni_p_bh_update( D1s.append((mi - m0) / rk) rks.append(1.0) - rks = torch.tensor(rks, device=device) + rks = torch.tensor(rks, device=device, dtype=torch.float32) R = [] b = [] hh = -h if self.predict_x0 else h + hh = hh.float() h_phi_1 = torch.expm1(hh) h_phi_k = h_phi_1 / hh - 1 @@ -435,7 +436,7 @@ def multistep_uni_p_bh_update( h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) - b = torch.tensor(b, device=device) + b = torch.tensor(b, device=device, dtype=torch.float32) if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) @@ -523,12 +524,13 @@ def multistep_uni_c_bh_update( D1s.append((mi - m0) / rk) rks.append(1.0) - rks = torch.tensor(rks, device=device) + rks = torch.tensor(rks, device=device, dtype=torch.float32) R = [] b = [] hh = -h if self.predict_x0 else h + hh = hh.float() h_phi_1 = torch.expm1(hh) h_phi_k = h_phi_1 / hh - 1 @@ -548,7 +550,7 @@ def multistep_uni_c_bh_update( h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) - b = torch.tensor(b, device=device) + b = torch.tensor(b, device=device, dtype=torch.float32) if len(D1s) > 0: D1s = torch.stack(D1s, dim=1)