diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 84c0d993e1..707a1ff1db 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -510,7 +510,13 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { left: &mut Handle, right: &mut Handle, ) -> Result<(), Error<'source>> { - if op != crate::BinaryOperator::Multiply { + if matches!( + op, + crate::BinaryOperator::Add + | crate::BinaryOperator::Subtract + | crate::BinaryOperator::Divide + | crate::BinaryOperator::Modulo + ) { self.grow_types(*left)?.grow_types(*right)?; let left_size = match *self.resolved_inner(*left) { diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index 80ea261434..a430bb1579 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -387,6 +387,51 @@ fn parse_expressions() { }").unwrap(); } +#[test] +fn binary_expression_mixed_scalar_and_vector_operands() { + for (operand, expect_splat) in [ + ('<', false), + ('>', false), + ('+', true), + ('-', true), + ('/', true), + ('*', false), + ] { + let module = parse_str(&format!( + " + const some_vec = vec3(1.0, 1.0, 1.0); + @fragment + fn main() -> @location(0) vec4 {{ + if (all(1.0 {operand} some_vec)) {{ + return vec4(0.0); + }} + return vec4(1.0); + }} + " + )) + .unwrap(); + + let expressions = &&module.entry_points[0].function.expressions; + + let found_expressions = expressions + .iter() + .filter(|&(_, e)| { + if let crate::Expression::Binary { left, .. } = *e { + matches!( + (expect_splat, &expressions[left]), + (false, &crate::Expression::Literal(crate::Literal::F32(..))) + | (true, &crate::Expression::Splat { .. }) + ) + } else { + false + } + }) + .count(); + + assert_eq!(found_expressions, 1); + } +} + #[test] fn parse_pointers() { parse_str(