diff --git a/mistralrs-core/src/models/phi3_5_moe.rs b/mistralrs-core/src/models/phi3_5_moe.rs
index 890460f613..acdc88bcf6 100644
--- a/mistralrs-core/src/models/phi3_5_moe.rs
+++ b/mistralrs-core/src/models/phi3_5_moe.rs
@@ -436,7 +436,7 @@ impl MoeMlp {
             self.router_jitter_noise,
         )?;
 
-        let mut final_hidden_states = Tensor::zeros((bs * seq, hidden), xs.dtype(), xs.device())?;
+        let mut final_hidden_states = Tensor::zeros((bs * seq, hidden), xs.dtype(), xs_dev)?;
 
         // One hot encode the selected experts to create an expert mask
         // this will be used to easily index which expert to activate
@@ -471,8 +471,11 @@ impl MoeMlp {
             let current_hidden_states = exp_out.broadcast_mul(&current_routing_weights)?;
 
             final_hidden_states = final_hidden_states.index_add(
-                &top_x.contiguous()?,
-                &current_hidden_states.squeeze(0)?.to_dtype(xs.dtype())?,
+                &top_x.contiguous()?.to_device(xs_dev)?,
+                &current_hidden_states
+                    .squeeze(0)?
+                    .to_dtype(xs.dtype())?
+                    .to_device(xs_dev)?,
                 0,
             )?;
         }