diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 884c4100c9..f3a7b22dd3 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -504,13 +504,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { } /// Insert splats, if needed by the non-'*' operations. + /// + /// See the "Binary arithmetic expressions with mixed scalar and vector operands" + /// table in the WebGPU Shading Language specification for relevant operators. + /// + /// Multiply is not handled here as backends are expected to handle vec*scalar + /// operations, so inserting splats into the IR increases size needlessly. fn binary_op_splat( &mut self, op: crate::BinaryOperator, left: &mut Handle, right: &mut Handle, ) -> Result<(), Error<'source>> { - if op != crate::BinaryOperator::Multiply { + if matches!( + op, + crate::BinaryOperator::Add + | crate::BinaryOperator::Subtract + | crate::BinaryOperator::Divide + | crate::BinaryOperator::Modulo + ) { self.grow_types(*left)?.grow_types(*right)?; let left_size = match *self.resolved_inner(*left) { diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index 084655d666..02fc110cae 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -387,6 +387,54 @@ fn parse_expressions() { }").unwrap(); } +#[test] +fn binary_expression_mixed_scalar_and_vector_operands() { + for (operand, expect_splat) in [ + ('<', false), + ('>', false), + ('&', false), + ('|', false), + ('+', true), + ('-', true), + ('*', false), + ('/', true), + ('%', true), + ] { + let module = parse_str(&format!( + " + const some_vec = vec3(1.0, 1.0, 1.0); + @fragment + fn main() -> @location(0) vec4 {{ + if (all(1.0 {operand} some_vec)) {{ + return vec4(0.0); + }} + return vec4(1.0); + }} + " + )) + .unwrap(); + + let expressions = &&module.entry_points[0].function.expressions; + + let found_expressions = expressions + .iter() + .filter(|&(_, e)| { + if let crate::Expression::Binary { left, .. } = *e { + matches!( + (expect_splat, &expressions[left]), + (false, &crate::Expression::Literal(crate::Literal::F32(..))) + | (true, &crate::Expression::Splat { .. }) + ) + } else { + false + } + }) + .count(); + + assert_eq!(found_expressions, 1); + } +} + #[test] fn parse_pointers() { parse_str(