Skip to content

Commit

Permalink
[wgsl-in] Avoid splatting all binary operator expressions (#2440)
Browse files Browse the repository at this point in the history
* [wgsl-in] Avoid splatting all binary operator expressions

Fixes #2439.

* [wgsl-in] Expand binary_op_splat function comment
  • Loading branch information
fornwall authored Aug 18, 2023
1 parent f6e99a4 commit 3da9355
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,13 +504,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}

/// Insert splats, if needed by the non-'*' operations.
///
/// See the "Binary arithmetic expressions with mixed scalar and vector operands"
/// table in the WebGPU Shading Language specification for relevant operators.
///
/// Multiply is not handled here as backends are expected to handle vec*scalar
/// operations, so inserting splats into the IR increases size needlessly.
fn binary_op_splat(
&mut self,
op: crate::BinaryOperator,
left: &mut Handle<crate::Expression>,
right: &mut Handle<crate::Expression>,
) -> Result<(), Error<'source>> {
if op != crate::BinaryOperator::Multiply {
if matches!(
op,
crate::BinaryOperator::Add
| crate::BinaryOperator::Subtract
| crate::BinaryOperator::Divide
| crate::BinaryOperator::Modulo
) {
self.grow_types(*left)?.grow_types(*right)?;

let left_size = match *self.resolved_inner(*left) {
Expand Down
48 changes: 48 additions & 0 deletions src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,54 @@ fn parse_expressions() {
}").unwrap();
}

#[test]
fn binary_expression_mixed_scalar_and_vector_operands() {
for (operand, expect_splat) in [
('<', false),
('>', false),
('&', false),
('|', false),
('+', true),
('-', true),
('*', false),
('/', true),
('%', true),
] {
let module = parse_str(&format!(
"
const some_vec = vec3<f32>(1.0, 1.0, 1.0);
@fragment
fn main() -> @location(0) vec4<f32> {{
if (all(1.0 {operand} some_vec)) {{
return vec4(0.0);
}}
return vec4(1.0);
}}
"
))
.unwrap();

let expressions = &&module.entry_points[0].function.expressions;

let found_expressions = expressions
.iter()
.filter(|&(_, e)| {
if let crate::Expression::Binary { left, .. } = *e {
matches!(
(expect_splat, &expressions[left]),
(false, &crate::Expression::Literal(crate::Literal::F32(..)))
| (true, &crate::Expression::Splat { .. })
)
} else {
false
}
})
.count();

assert_eq!(found_expressions, 1);
}
}

#[test]
fn parse_pointers() {
parse_str(
Expand Down

0 comments on commit 3da9355

Please sign in to comment.