From 717117315fb2a65e7f161842f3c4c03ff54b5c07 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 11 Dec 2025 10:21:04 -0800 Subject: [PATCH 1/3] refactor(naga): Extract `call_builtin` to a separate function --- naga/src/front/wgsl/lower/mod.rs | 1530 +++++++++++++++--------------- 1 file changed, 767 insertions(+), 763 deletions(-) diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 690e6d5075..416bc37360 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2008,8 +2008,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { function, arguments, None, - &mut ctx.as_expression(block, &mut emitter), true, + &mut ctx.as_expression(block, &mut emitter), )?; block.extend(emitter.finish(&ctx.function.expressions)); return Ok(()); @@ -2336,7 +2336,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { result_ty, } => { let handle = self - .call(span, function, arguments, result_ty, ctx, false)? + .call(span, function, arguments, result_ty, false, ctx)? .ok_or(Error::FunctionReturnsVoid(function.span))?; return Ok(Typed::Plain(handle)); } @@ -2654,6 +2654,769 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } + fn call_builtin( + &mut self, + span: Span, + function: &ast::Ident<'source>, + arguments: &[Handle>], + result_ty: Option<(Handle>, Span)>, + is_statement: bool, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<'source, Option>> { + let expr = if let Some(fun) = conv::map_relational_fun(function.name) { + let mut args = ctx.prepare_args(arguments, 1, span); + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + // Check for no-op all(bool) and any(bool): + let argument_unmodified = matches!( + fun, + ir::RelationalFunction::All | ir::RelationalFunction::Any + ) && { + matches!( + resolve_inner!(ctx, argument), + &ir::TypeInner::Scalar(ir::Scalar { + kind: ir::ScalarKind::Bool, + .. + }) + ) + }; + + if argument_unmodified { + return Ok(Some(argument)); + } else { + ir::Expression::Relational { fun, argument } + } + } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = self.expression(args.next()?, ctx)?; + args.finish()?; + + ir::Expression::Derivative { axis, ctrl, expr } + } else if let Some(fun) = conv::map_standard_fun(function.name) { + self.math_function_helper(span, fun, arguments, ctx)? + } else if let Some(fun) = Texture::map(function.name) { + self.texture_sample_helper(fun, arguments, span, ctx)? + } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { + return Ok(Some( + self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, + )); + } else if let Some(mode) = SubgroupGather::map(function.name) { + return Ok(Some( + self.subgroup_gather_helper(span, mode, arguments, ctx)?, + )); + } else if let Some(fun) = ir::AtomicFunction::map(function.name) { + return self.atomic_helper(span, fun, arguments, is_statement, ctx); + } else { + match function.name { + "select" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let reject_orig = args.next()?; + let accept_orig = args.next()?; + let mut values = [ + self.expression_for_abstract(reject_orig, ctx)?, + self.expression_for_abstract(accept_orig, ctx)?, + ]; + let condition = self.expression(args.next()?, ctx)?; + + args.finish()?; + + let diagnostic_details = + |ctx: &ExpressionContext<'_, '_, '_>, + ty_res: &proc::TypeResolution, + orig_expr| { + ( + ctx.ast_expressions.get_span(orig_expr), + format!("`{}`", ctx.as_diagnostic_display(ty_res)), + ) + }; + for (&value, orig_value) in + values.iter().zip([reject_orig, accept_orig]) + { + let value_ty_res = resolve!(ctx, value); + if value_ty_res + .inner_with(&ctx.module.types) + .vector_size_and_scalar() + .is_none() + { + let (arg_span, arg_type) = + diagnostic_details(ctx, value_ty_res, orig_value); + return Err(Box::new(Error::SelectUnexpectedArgumentType { + arg_span, + arg_type, + })); + } + } + let mut consensus_scalar = ctx + .automatic_conversion_consensus(&values) + .map_err(|_idx| { + let [reject, accept] = values; + let [(reject_span, reject_type), (accept_span, accept_type)] = + [(reject_orig, reject), (accept_orig, accept)].map( + |(orig_expr, expr)| { + let ty_res = &ctx.typifier()[expr]; + diagnostic_details(ctx, ty_res, orig_expr) + }, + ); + Error::SelectRejectAndAcceptHaveNoCommonType { + reject_span, + reject_type, + accept_span, + accept_type, + } + })?; + if !ctx.is_const(condition) { + consensus_scalar = consensus_scalar.concretize(); + } + + ctx.convert_slice_to_common_leaf_scalar(&mut values, consensus_scalar)?; + + let [reject, accept] = values; + + ir::Expression::Select { + reject, + accept, + condition, + } + } + "arrayLength" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = self.expression(args.next()?, ctx)?; + args.finish()?; + + ir::Expression::ArrayLength(expr) + } + "atomicLoad" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let (pointer, _scalar) = self.atomic_pointer(args.next()?, ctx)?; + args.finish()?; + + ir::Expression::Load { pointer } + } + "atomicStore" => { + let mut args = ctx.prepare_args(arguments, 2, span); + let (pointer, scalar) = self.atomic_pointer(args.next()?, ctx)?; + let value = + self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + rctx.block + .push(ir::Statement::Store { pointer, value }, span); + return Ok(None); + } + "atomicCompareExchangeWeak" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let (pointer, scalar) = self.atomic_pointer(args.next()?, ctx)?; + + let compare = + self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; + + let value = args.next()?; + let value_span = ctx.ast_expressions.get_span(value); + let value = self.expression_with_leaf_scalar(value, scalar, ctx)?; + + args.finish()?; + + let expression = match *resolve_inner!(ctx, value) { + ir::TypeInner::Scalar(scalar) => ir::Expression::AtomicResult { + ty: ctx.module.generate_predeclared_type( + ir::PredeclaredType::AtomicCompareExchangeWeakResult( + scalar, + ), + ), + comparison: true, + }, + _ => { + return Err(Box::new(Error::InvalidAtomicOperandType( + value_span, + ))) + } + }; + + let result = ctx.interrupt_emitter(expression, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + ir::Statement::Atomic { + pointer, + fun: ir::AtomicFunction::Exchange { + compare: Some(compare), + }, + value, + result: Some(result), + }, + span, + ); + return Ok(Some(result)); + } + "textureAtomicMin" | "textureAtomicMax" | "textureAtomicAdd" + | "textureAtomicAnd" | "textureAtomicOr" | "textureAtomicXor" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = self.expression(image, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (_, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let value = self.expression(args.next()?, ctx)?; + + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + let stmt = ir::Statement::ImageAtomic { + image, + coordinate, + array_index, + fun: match function.name { + "textureAtomicMin" => ir::AtomicFunction::Min, + "textureAtomicMax" => ir::AtomicFunction::Max, + "textureAtomicAdd" => ir::AtomicFunction::Add, + "textureAtomicAnd" => ir::AtomicFunction::And, + "textureAtomicOr" => ir::AtomicFunction::InclusiveOr, + "textureAtomicXor" => ir::AtomicFunction::ExclusiveOr, + _ => unreachable!(), + }, + value, + }; + rctx.block.push(stmt, span); + return Ok(None); + } + "storageBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(ir::Statement::ControlBarrier(ir::Barrier::STORAGE), span); + return Ok(None); + } + "workgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(ir::Statement::ControlBarrier(ir::Barrier::WORK_GROUP), span); + return Ok(None); + } + "subgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(ir::Statement::ControlBarrier(ir::Barrier::SUB_GROUP), span); + return Ok(None); + } + "textureBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(ir::Statement::ControlBarrier(ir::Barrier::TEXTURE), span); + return Ok(None); + } + "workgroupUniformLoad" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = args.next()?; + args.finish()?; + + let pointer = self.expression(expr, ctx)?; + let result_ty = match *resolve_inner!(ctx, pointer) { + ir::TypeInner::Pointer { + base, + space: ir::AddressSpace::WorkGroup, + } => match ctx.module.types[base].inner { + // Match `Expression::Load` semantics: + // loading through a pointer to `atomic` produces a `T`. + ir::TypeInner::Atomic(scalar) => ctx.module.types.insert( + ir::Type { + name: None, + inner: ir::TypeInner::Scalar(scalar), + }, + span, + ), + _ => base, + }, + ir::TypeInner::ValuePointer { + size, + scalar, + space: ir::AddressSpace::WorkGroup, + } => ctx.module.types.insert( + ir::Type { + name: None, + inner: match size { + Some(size) => ir::TypeInner::Vector { size, scalar }, + None => ir::TypeInner::Scalar(scalar), + }, + }, + span, + ), + _ => { + let span = ctx.ast_expressions.get_span(expr); + return Err(Box::new(Error::InvalidWorkGroupUniformLoad(span))); + } + }; + let result = ctx.interrupt_emitter( + ir::Expression::WorkGroupUniformLoadResult { ty: result_ty }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + ir::Statement::WorkGroupUniformLoad { pointer, result }, + span, + ); + + return Ok(Some(result)); + } + "textureStore" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = self.expression(image, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (class, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + let scalar = if let ir::ImageClass::Storage { format, .. } = class { + format.into() + } else { + return Err(Box::new(Error::NotStorageTexture(image_span))); + }; + + let value = + self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; + + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + let stmt = ir::Statement::ImageStore { + image, + coordinate, + array_index, + value, + }; + rctx.block.push(stmt, span); + return Ok(None); + } + "textureLoad" => { + let mut args = ctx.prepare_args(arguments, 2, span); + + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = self.expression(image, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (class, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let level = class + .is_mipmapped() + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let sample = class + .is_multisampled() + .then(|| self.expression(args.next()?, ctx)) + .transpose()?; + + args.finish()?; + + ir::Expression::ImageLoad { + image, + coordinate, + array_index, + level, + sample, + } + } + "textureDimensions" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + let level = args + .next() + .map(|arg| self.expression(arg, ctx)) + .ok() + .transpose()?; + args.finish()?; + + ir::Expression::ImageQuery { + image, + query: ir::ImageQuery::Size { level }, + } + } + "textureNumLevels" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + ir::Expression::ImageQuery { + image, + query: ir::ImageQuery::NumLevels, + } + } + "textureNumLayers" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + ir::Expression::ImageQuery { + image, + query: ir::ImageQuery::NumLayers, + } + } + "textureNumSamples" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + ir::Expression::ImageQuery { + image, + query: ir::ImageQuery::NumSamples, + } + } + "rayQueryInitialize" => { + let mut args = ctx.prepare_args(arguments, 3, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + let acceleration_structure = self.expression(args.next()?, ctx)?; + let descriptor = self.expression(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_ray_desc_type(); + let fun = ir::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + }; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + rctx.block + .push(ir::Statement::RayQuery { query, fun }, span); + return Ok(None); + } + "getCommittedHitVertexPositions" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_vertex_return_type(); + + ir::Expression::RayQueryVertexPositions { + query, + committed: true, + } + } + "getCandidateHitVertexPositions" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_vertex_return_type(); + + ir::Expression::RayQueryVertexPositions { + query, + committed: false, + } + } + "rayQueryProceed" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let result = + ctx.interrupt_emitter(ir::Expression::RayQueryProceedResult, span)?; + let fun = ir::RayQueryFunction::Proceed { result }; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(ir::Statement::RayQuery { query, fun }, span); + return Ok(Some(result)); + } + "rayQueryGenerateIntersection" => { + let mut args = ctx.prepare_args(arguments, 2, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + let hit_t = self.expression(args.next()?, ctx)?; + args.finish()?; + + let fun = ir::RayQueryFunction::GenerateIntersection { hit_t }; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(ir::Statement::RayQuery { query, fun }, span); + return Ok(None); + } + "rayQueryConfirmIntersection" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let fun = ir::RayQueryFunction::ConfirmIntersection; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(ir::Statement::RayQuery { query, fun }, span); + return Ok(None); + } + "rayQueryTerminate" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let fun = ir::RayQueryFunction::Terminate; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(ir::Statement::RayQuery { query, fun }, span); + return Ok(None); + } + "rayQueryGetCommittedIntersection" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_ray_intersection_type(); + ir::Expression::RayQueryGetIntersection { + query, + committed: true, + } + } + "rayQueryGetCandidateIntersection" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_ray_intersection_type(); + ir::Expression::RayQueryGetIntersection { + query, + committed: false, + } + } + "RayDesc" => { + let ty = ctx.module.generate_ray_desc_type(); + let handle = self.construct( + span, + &ast::ConstructorType::Type(ty), + function.span, + arguments, + ctx, + )?; + return Ok(Some(handle)); + } + "subgroupBallot" => { + let mut args = ctx.prepare_args(arguments, 0, span); + let predicate = if arguments.len() == 1 { + Some(self.expression(args.next()?, ctx)?) + } else { + None + }; + args.finish()?; + + let result = + ctx.interrupt_emitter(ir::Expression::SubgroupBallotResult, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(ir::Statement::SubgroupBallot { result, predicate }, span); + return Ok(Some(result)); + } + "quadSwapX" => { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode: crate::GatherMode::QuadSwap(crate::Direction::X), + argument, + result, + }, + span, + ); + return Ok(Some(result)); + } + "quadSwapY" => { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode: crate::GatherMode::QuadSwap(crate::Direction::Y), + argument, + result, + }, + span, + ); + return Ok(Some(result)); + } + "quadSwapDiagonal" => { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode: crate::GatherMode::QuadSwap(crate::Direction::Diagonal), + argument, + result, + }, + span, + ); + return Ok(Some(result)); + } + "coopLoad" | "coopLoadT" => { + let row_major = function.name.ends_with("T"); + let mut args = ctx.prepare_args(arguments, 1, span); + let pointer = self.expression(args.next()?, ctx)?; + let (matrix_ty, matrix_span) = result_ty.expect("generic argument"); + let (columns, rows, role) = match ctx.types[matrix_ty] { + ast::Type::CooperativeMatrix { + columns, + rows, + role, + .. + } => (columns, rows, role), + _ => return Err(Box::new(Error::InvalidCooperativeLoadType(matrix_span))), + }; + let stride = if args.total_args > 1 { + self.expression(args.next()?, ctx)? + } else { + // Infer the stride from the matrix type + let stride = if row_major { + columns as u32 + } else { + rows as u32 + }; + ctx.append_expression( + ir::Expression::Literal(ir::Literal::U32(stride)), + Span::UNDEFINED, + )? + }; + args.finish()?; + + crate::Expression::CooperativeLoad { + columns, + rows, + role, + data: crate::CooperativeData { + pointer, + stride, + row_major, + }, + } + } + "coopStore" | "coopStoreT" => { + let row_major = function.name.ends_with("T"); + + let mut args = ctx.prepare_args(arguments, 2, span); + let target = self.expression(args.next()?, ctx)?; + let pointer = self.expression(args.next()?, ctx)?; + let stride = if args.total_args > 2 { + self.expression(args.next()?, ctx)? + } else { + // Infer the stride from the matrix type + let stride = match *resolve_inner!(ctx, target) { + ir::TypeInner::CooperativeMatrix { columns, rows, .. } => { + if row_major { + columns as u32 + } else { + rows as u32 + } + } + _ => 0, + }; + ctx.append_expression( + ir::Expression::Literal(ir::Literal::U32(stride)), + Span::UNDEFINED, + )? + }; + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::CooperativeStore { + target, + data: crate::CooperativeData { + pointer, + stride, + row_major, + }, + }, + span, + ); + return Ok(None); + } + "coopMultiplyAdd" => { + let mut args = ctx.prepare_args(arguments, 3, span); + let a = self.expression(args.next()?, ctx)?; + let b = self.expression(args.next()?, ctx)?; + let c = self.expression(args.next()?, ctx)?; + args.finish()?; + + ir::Expression::CooperativeMultiplyAdd { a, b, c } + } + _ => return Err(Box::new(Error::UnknownIdent(function.span, function.name))), + } + }; + + let expr = ctx.append_expression(expr, span)?; + Ok(Some(expr)) + } + /// Generate Naga IR for call expressions and statements, and type /// constructor expressions. /// @@ -2678,8 +3441,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { function: &ast::Ident<'source>, arguments: &[Handle>], result_ty: Option<(Handle>, Span)>, - ctx: &mut ExpressionContext<'source, '_, '_>, is_statement: bool, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Option>> { let function_span = function.span; match ctx.globals.get(function.name) { @@ -2762,767 +3525,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(result) } - None => { - let span = function_span; - let expr = if let Some(fun) = conv::map_relational_fun(function.name) { - let mut args = ctx.prepare_args(arguments, 1, span); - let argument = self.expression(args.next()?, ctx)?; - args.finish()?; - - // Check for no-op all(bool) and any(bool): - let argument_unmodified = matches!( - fun, - ir::RelationalFunction::All | ir::RelationalFunction::Any - ) && { - matches!( - resolve_inner!(ctx, argument), - &ir::TypeInner::Scalar(ir::Scalar { - kind: ir::ScalarKind::Bool, - .. - }) - ) - }; - - if argument_unmodified { - return Ok(Some(argument)); - } else { - ir::Expression::Relational { fun, argument } - } - } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) { - let mut args = ctx.prepare_args(arguments, 1, span); - let expr = self.expression(args.next()?, ctx)?; - args.finish()?; - - ir::Expression::Derivative { axis, ctrl, expr } - } else if let Some(fun) = conv::map_standard_fun(function.name) { - self.math_function_helper(span, fun, arguments, ctx)? - } else if let Some(fun) = Texture::map(function.name) { - self.texture_sample_helper(fun, arguments, span, ctx)? - } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { - return Ok(Some( - self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, - )); - } else if let Some(mode) = SubgroupGather::map(function.name) { - return Ok(Some( - self.subgroup_gather_helper(span, mode, arguments, ctx)?, - )); - } else if let Some(fun) = ir::AtomicFunction::map(function.name) { - return self.atomic_helper(span, fun, arguments, is_statement, ctx); - } else { - match function.name { - "select" => { - let mut args = ctx.prepare_args(arguments, 3, span); - - let reject_orig = args.next()?; - let accept_orig = args.next()?; - let mut values = [ - self.expression_for_abstract(reject_orig, ctx)?, - self.expression_for_abstract(accept_orig, ctx)?, - ]; - let condition = self.expression(args.next()?, ctx)?; - - args.finish()?; - - let diagnostic_details = - |ctx: &ExpressionContext<'_, '_, '_>, - ty_res: &proc::TypeResolution, - orig_expr| { - ( - ctx.ast_expressions.get_span(orig_expr), - format!("`{}`", ctx.as_diagnostic_display(ty_res)), - ) - }; - for (&value, orig_value) in - values.iter().zip([reject_orig, accept_orig]) - { - let value_ty_res = resolve!(ctx, value); - if value_ty_res - .inner_with(&ctx.module.types) - .vector_size_and_scalar() - .is_none() - { - let (arg_span, arg_type) = - diagnostic_details(ctx, value_ty_res, orig_value); - return Err(Box::new(Error::SelectUnexpectedArgumentType { - arg_span, - arg_type, - })); - } - } - let mut consensus_scalar = ctx - .automatic_conversion_consensus(&values) - .map_err(|_idx| { - let [reject, accept] = values; - let [(reject_span, reject_type), (accept_span, accept_type)] = - [(reject_orig, reject), (accept_orig, accept)].map( - |(orig_expr, expr)| { - let ty_res = &ctx.typifier()[expr]; - diagnostic_details(ctx, ty_res, orig_expr) - }, - ); - Error::SelectRejectAndAcceptHaveNoCommonType { - reject_span, - reject_type, - accept_span, - accept_type, - } - })?; - if !ctx.is_const(condition) { - consensus_scalar = consensus_scalar.concretize(); - } - - ctx.convert_slice_to_common_leaf_scalar(&mut values, consensus_scalar)?; - let [reject, accept] = values; - - ir::Expression::Select { - reject, - accept, - condition, - } - } - "arrayLength" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let expr = self.expression(args.next()?, ctx)?; - args.finish()?; - - ir::Expression::ArrayLength(expr) - } - "atomicLoad" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let (pointer, _scalar) = self.atomic_pointer(args.next()?, ctx)?; - args.finish()?; - - ir::Expression::Load { pointer } - } - "atomicStore" => { - let mut args = ctx.prepare_args(arguments, 2, span); - let (pointer, scalar) = self.atomic_pointer(args.next()?, ctx)?; - let value = - self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; - args.finish()?; - - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .extend(rctx.emitter.finish(&rctx.function.expressions)); - rctx.emitter.start(&rctx.function.expressions); - rctx.block - .push(ir::Statement::Store { pointer, value }, span); - return Ok(None); - } - "atomicCompareExchangeWeak" => { - let mut args = ctx.prepare_args(arguments, 3, span); - - let (pointer, scalar) = self.atomic_pointer(args.next()?, ctx)?; - - let compare = - self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; - - let value = args.next()?; - let value_span = ctx.ast_expressions.get_span(value); - let value = self.expression_with_leaf_scalar(value, scalar, ctx)?; - - args.finish()?; - - let expression = match *resolve_inner!(ctx, value) { - ir::TypeInner::Scalar(scalar) => ir::Expression::AtomicResult { - ty: ctx.module.generate_predeclared_type( - ir::PredeclaredType::AtomicCompareExchangeWeakResult( - scalar, - ), - ), - comparison: true, - }, - _ => { - return Err(Box::new(Error::InvalidAtomicOperandType( - value_span, - ))) - } - }; - - let result = ctx.interrupt_emitter(expression, span)?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - ir::Statement::Atomic { - pointer, - fun: ir::AtomicFunction::Exchange { - compare: Some(compare), - }, - value, - result: Some(result), - }, - span, - ); - return Ok(Some(result)); - } - "textureAtomicMin" | "textureAtomicMax" | "textureAtomicAdd" - | "textureAtomicAnd" | "textureAtomicOr" | "textureAtomicXor" => { - let mut args = ctx.prepare_args(arguments, 3, span); - - let image = args.next()?; - let image_span = ctx.ast_expressions.get_span(image); - let image = self.expression(image, ctx)?; - - let coordinate = self.expression(args.next()?, ctx)?; - - let (_, arrayed) = ctx.image_data(image, image_span)?; - let array_index = arrayed - .then(|| { - args.min_args += 1; - self.expression(args.next()?, ctx) - }) - .transpose()?; - - let value = self.expression(args.next()?, ctx)?; - - args.finish()?; - - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .extend(rctx.emitter.finish(&rctx.function.expressions)); - rctx.emitter.start(&rctx.function.expressions); - let stmt = ir::Statement::ImageAtomic { - image, - coordinate, - array_index, - fun: match function.name { - "textureAtomicMin" => ir::AtomicFunction::Min, - "textureAtomicMax" => ir::AtomicFunction::Max, - "textureAtomicAdd" => ir::AtomicFunction::Add, - "textureAtomicAnd" => ir::AtomicFunction::And, - "textureAtomicOr" => ir::AtomicFunction::InclusiveOr, - "textureAtomicXor" => ir::AtomicFunction::ExclusiveOr, - _ => unreachable!(), - }, - value, - }; - rctx.block.push(stmt, span); - return Ok(None); - } - "storageBarrier" => { - ctx.prepare_args(arguments, 0, span).finish()?; - - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .push(ir::Statement::ControlBarrier(ir::Barrier::STORAGE), span); - return Ok(None); - } - "workgroupBarrier" => { - ctx.prepare_args(arguments, 0, span).finish()?; - - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .push(ir::Statement::ControlBarrier(ir::Barrier::WORK_GROUP), span); - return Ok(None); - } - "subgroupBarrier" => { - ctx.prepare_args(arguments, 0, span).finish()?; - - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .push(ir::Statement::ControlBarrier(ir::Barrier::SUB_GROUP), span); - return Ok(None); - } - "textureBarrier" => { - ctx.prepare_args(arguments, 0, span).finish()?; - - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .push(ir::Statement::ControlBarrier(ir::Barrier::TEXTURE), span); - return Ok(None); - } - "workgroupUniformLoad" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let expr = args.next()?; - args.finish()?; - - let pointer = self.expression(expr, ctx)?; - let result_ty = match *resolve_inner!(ctx, pointer) { - ir::TypeInner::Pointer { - base, - space: ir::AddressSpace::WorkGroup, - } => match ctx.module.types[base].inner { - // Match `Expression::Load` semantics: - // loading through a pointer to `atomic` produces a `T`. - ir::TypeInner::Atomic(scalar) => ctx.module.types.insert( - ir::Type { - name: None, - inner: ir::TypeInner::Scalar(scalar), - }, - span, - ), - _ => base, - }, - ir::TypeInner::ValuePointer { - size, - scalar, - space: ir::AddressSpace::WorkGroup, - } => ctx.module.types.insert( - ir::Type { - name: None, - inner: match size { - Some(size) => ir::TypeInner::Vector { size, scalar }, - None => ir::TypeInner::Scalar(scalar), - }, - }, - span, - ), - _ => { - let span = ctx.ast_expressions.get_span(expr); - return Err(Box::new(Error::InvalidWorkGroupUniformLoad(span))); - } - }; - let result = ctx.interrupt_emitter( - ir::Expression::WorkGroupUniformLoadResult { ty: result_ty }, - span, - )?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - ir::Statement::WorkGroupUniformLoad { pointer, result }, - span, - ); - - return Ok(Some(result)); - } - "textureStore" => { - let mut args = ctx.prepare_args(arguments, 3, span); - - let image = args.next()?; - let image_span = ctx.ast_expressions.get_span(image); - let image = self.expression(image, ctx)?; - - let coordinate = self.expression(args.next()?, ctx)?; - - let (class, arrayed) = ctx.image_data(image, image_span)?; - let array_index = arrayed - .then(|| { - args.min_args += 1; - self.expression(args.next()?, ctx) - }) - .transpose()?; - let scalar = if let ir::ImageClass::Storage { format, .. } = class { - format.into() - } else { - return Err(Box::new(Error::NotStorageTexture(image_span))); - }; - - let value = - self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; - - args.finish()?; - - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .extend(rctx.emitter.finish(&rctx.function.expressions)); - rctx.emitter.start(&rctx.function.expressions); - let stmt = ir::Statement::ImageStore { - image, - coordinate, - array_index, - value, - }; - rctx.block.push(stmt, span); - return Ok(None); - } - "textureLoad" => { - let mut args = ctx.prepare_args(arguments, 2, span); - - let image = args.next()?; - let image_span = ctx.ast_expressions.get_span(image); - let image = self.expression(image, ctx)?; - - let coordinate = self.expression(args.next()?, ctx)?; - - let (class, arrayed) = ctx.image_data(image, image_span)?; - let array_index = arrayed - .then(|| { - args.min_args += 1; - self.expression(args.next()?, ctx) - }) - .transpose()?; - - let level = class - .is_mipmapped() - .then(|| { - args.min_args += 1; - self.expression(args.next()?, ctx) - }) - .transpose()?; - - let sample = class - .is_multisampled() - .then(|| self.expression(args.next()?, ctx)) - .transpose()?; - - args.finish()?; - - ir::Expression::ImageLoad { - image, - coordinate, - array_index, - level, - sample, - } - } - "textureDimensions" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let image = self.expression(args.next()?, ctx)?; - let level = args - .next() - .map(|arg| self.expression(arg, ctx)) - .ok() - .transpose()?; - args.finish()?; - - ir::Expression::ImageQuery { - image, - query: ir::ImageQuery::Size { level }, - } - } - "textureNumLevels" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let image = self.expression(args.next()?, ctx)?; - args.finish()?; - - ir::Expression::ImageQuery { - image, - query: ir::ImageQuery::NumLevels, - } - } - "textureNumLayers" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let image = self.expression(args.next()?, ctx)?; - args.finish()?; - - ir::Expression::ImageQuery { - image, - query: ir::ImageQuery::NumLayers, - } - } - "textureNumSamples" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let image = self.expression(args.next()?, ctx)?; - args.finish()?; - - ir::Expression::ImageQuery { - image, - query: ir::ImageQuery::NumSamples, - } - } - "rayQueryInitialize" => { - let mut args = ctx.prepare_args(arguments, 3, span); - let query = self.ray_query_pointer(args.next()?, ctx)?; - let acceleration_structure = self.expression(args.next()?, ctx)?; - let descriptor = self.expression(args.next()?, ctx)?; - args.finish()?; - - let _ = ctx.module.generate_ray_desc_type(); - let fun = ir::RayQueryFunction::Initialize { - acceleration_structure, - descriptor, - }; - - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .extend(rctx.emitter.finish(&rctx.function.expressions)); - rctx.emitter.start(&rctx.function.expressions); - rctx.block - .push(ir::Statement::RayQuery { query, fun }, span); - return Ok(None); - } - "getCommittedHitVertexPositions" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.ray_query_pointer(args.next()?, ctx)?; - args.finish()?; - - let _ = ctx.module.generate_vertex_return_type(); - - ir::Expression::RayQueryVertexPositions { - query, - committed: true, - } - } - "getCandidateHitVertexPositions" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.ray_query_pointer(args.next()?, ctx)?; - args.finish()?; - - let _ = ctx.module.generate_vertex_return_type(); - - ir::Expression::RayQueryVertexPositions { - query, - committed: false, - } - } - "rayQueryProceed" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.ray_query_pointer(args.next()?, ctx)?; - args.finish()?; - - let result = - ctx.interrupt_emitter(ir::Expression::RayQueryProceedResult, span)?; - let fun = ir::RayQueryFunction::Proceed { result }; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .push(ir::Statement::RayQuery { query, fun }, span); - return Ok(Some(result)); - } - "rayQueryGenerateIntersection" => { - let mut args = ctx.prepare_args(arguments, 2, span); - let query = self.ray_query_pointer(args.next()?, ctx)?; - let hit_t = self.expression(args.next()?, ctx)?; - args.finish()?; - - let fun = ir::RayQueryFunction::GenerateIntersection { hit_t }; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .push(ir::Statement::RayQuery { query, fun }, span); - return Ok(None); - } - "rayQueryConfirmIntersection" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.ray_query_pointer(args.next()?, ctx)?; - args.finish()?; - - let fun = ir::RayQueryFunction::ConfirmIntersection; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .push(ir::Statement::RayQuery { query, fun }, span); - return Ok(None); - } - "rayQueryTerminate" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.ray_query_pointer(args.next()?, ctx)?; - args.finish()?; - - let fun = ir::RayQueryFunction::Terminate; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .push(ir::Statement::RayQuery { query, fun }, span); - return Ok(None); - } - "rayQueryGetCommittedIntersection" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.ray_query_pointer(args.next()?, ctx)?; - args.finish()?; - - let _ = ctx.module.generate_ray_intersection_type(); - ir::Expression::RayQueryGetIntersection { - query, - committed: true, - } - } - "rayQueryGetCandidateIntersection" => { - let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.ray_query_pointer(args.next()?, ctx)?; - args.finish()?; - - let _ = ctx.module.generate_ray_intersection_type(); - ir::Expression::RayQueryGetIntersection { - query, - committed: false, - } - } - "RayDesc" => { - let ty = ctx.module.generate_ray_desc_type(); - let handle = self.construct( - span, - &ast::ConstructorType::Type(ty), - function.span, - arguments, - ctx, - )?; - return Ok(Some(handle)); - } - "subgroupBallot" => { - let mut args = ctx.prepare_args(arguments, 0, span); - let predicate = if arguments.len() == 1 { - Some(self.expression(args.next()?, ctx)?) - } else { - None - }; - args.finish()?; - - let result = - ctx.interrupt_emitter(ir::Expression::SubgroupBallotResult, span)?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block - .push(ir::Statement::SubgroupBallot { result, predicate }, span); - return Ok(Some(result)); - } - "quadSwapX" => { - let mut args = ctx.prepare_args(arguments, 1, span); - - let argument = self.expression(args.next()?, ctx)?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - )?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::SubgroupGather { - mode: crate::GatherMode::QuadSwap(crate::Direction::X), - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } - "quadSwapY" => { - let mut args = ctx.prepare_args(arguments, 1, span); - - let argument = self.expression(args.next()?, ctx)?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - )?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::SubgroupGather { - mode: crate::GatherMode::QuadSwap(crate::Direction::Y), - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } - "quadSwapDiagonal" => { - let mut args = ctx.prepare_args(arguments, 1, span); - - let argument = self.expression(args.next()?, ctx)?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - )?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::SubgroupGather { - mode: crate::GatherMode::QuadSwap(crate::Direction::Diagonal), - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } - "coopLoad" | "coopLoadT" => { - let row_major = function.name.ends_with("T"); - let mut args = ctx.prepare_args(arguments, 1, span); - let pointer = self.expression(args.next()?, ctx)?; - let (matrix_ty, matrix_span) = result_ty.expect("generic argument"); - let (columns, rows, role) = match ctx.types[matrix_ty] { - ast::Type::CooperativeMatrix { - columns, - rows, - role, - .. - } => (columns, rows, role), - _ => { - return Err(Box::new(Error::InvalidCooperativeLoadType( - matrix_span, - ))) - } - }; - let stride = if args.total_args > 1 { - self.expression(args.next()?, ctx)? - } else { - // Infer the stride from the matrix type - let stride = if row_major { - columns as u32 - } else { - rows as u32 - }; - ctx.append_expression( - ir::Expression::Literal(ir::Literal::U32(stride)), - Span::UNDEFINED, - )? - }; - args.finish()?; - - crate::Expression::CooperativeLoad { - columns, - rows, - role, - data: crate::CooperativeData { - pointer, - stride, - row_major, - }, - } - } - "coopStore" | "coopStoreT" => { - let row_major = function.name.ends_with("T"); - - let mut args = ctx.prepare_args(arguments, 2, span); - let target = self.expression(args.next()?, ctx)?; - let pointer = self.expression(args.next()?, ctx)?; - let stride = if args.total_args > 2 { - self.expression(args.next()?, ctx)? - } else { - // Infer the stride from the matrix type - let stride = match *resolve_inner!(ctx, target) { - ir::TypeInner::CooperativeMatrix { columns, rows, .. } => { - if row_major { - columns as u32 - } else { - rows as u32 - } - } - _ => 0, - }; - ctx.append_expression( - ir::Expression::Literal(ir::Literal::U32(stride)), - Span::UNDEFINED, - )? - }; - args.finish()?; - - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::CooperativeStore { - target, - data: crate::CooperativeData { - pointer, - stride, - row_major, - }, - }, - span, - ); - return Ok(None); - } - "coopMultiplyAdd" => { - let mut args = ctx.prepare_args(arguments, 3, span); - let a = self.expression(args.next()?, ctx)?; - let b = self.expression(args.next()?, ctx)?; - let c = self.expression(args.next()?, ctx)?; - args.finish()?; - - ir::Expression::CooperativeMultiplyAdd { a, b, c } - } - _ => { - return Err(Box::new(Error::UnknownIdent(function.span, function.name))) - } - } - }; - - let expr = ctx.append_expression(expr, span)?; - Ok(Some(expr)) - } + None => self.call_builtin(span, function, arguments, result_ty, is_statement, ctx) } } From 872901a2f92b7e85bafd1c81e4a2e0f49a147d0a Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 11 Dec 2025 11:02:32 -0800 Subject: [PATCH 2/3] refactor(naga): Eliminate early returns in `call_builtin` --- naga/src/front/wgsl/lower/mod.rs | 149 ++++++++++++++++++------------- 1 file changed, 86 insertions(+), 63 deletions(-) diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 416bc37360..ca76301bf5 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1182,6 +1182,24 @@ enum AbstractRule { Allow, } +/// Either a raw IR expression or an already-appended expression handle. Used by `call_builtin`. +enum MaybeHandle { + Value(T), + Handle(Handle), +} + +impl From for MaybeHandle { + fn from(value: T) -> Self { + MaybeHandle::Value(value) + } +} + +impl From> for MaybeHandle { + fn from(handle: Handle) -> Self { + MaybeHandle::Handle(handle) + } +} + pub struct Lowerer<'source, 'temp> { index: &'temp Index<'source>, } @@ -2656,14 +2674,18 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn call_builtin( &mut self, - span: Span, function: &ast::Ident<'source>, arguments: &[Handle>], result_ty: Option<(Handle>, Span)>, is_statement: bool, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Option>> { - let expr = if let Some(fun) = conv::map_relational_fun(function.name) { + // We report all diagnostics associated with builtins with a span for + // just the function identifier, unlike for other kinds of calls, where + // we sometimes report with the span for the entire call expression. + let span = function.span; + + let result: Option> = if let Some(fun) = conv::map_relational_fun(function.name) { let mut args = ctx.prepare_args(arguments, 1, span); let argument = self.expression(args.next()?, ctx)?; args.finish()?; @@ -2683,30 +2705,26 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; if argument_unmodified { - return Ok(Some(argument)); + Some(argument.into()) } else { - ir::Expression::Relational { fun, argument } + Some(ir::Expression::Relational { fun, argument }.into()) } } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) { let mut args = ctx.prepare_args(arguments, 1, span); let expr = self.expression(args.next()?, ctx)?; args.finish()?; - ir::Expression::Derivative { axis, ctrl, expr } + Some(ir::Expression::Derivative { axis, ctrl, expr }.into()) } else if let Some(fun) = conv::map_standard_fun(function.name) { - self.math_function_helper(span, fun, arguments, ctx)? + Some(self.math_function_helper(span, fun, arguments, ctx)?.into()) } else if let Some(fun) = Texture::map(function.name) { - self.texture_sample_helper(fun, arguments, span, ctx)? + Some(self.texture_sample_helper(fun, arguments, span, ctx)?.into()) } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { - return Ok(Some( - self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, - )); + Some(self.subgroup_operation_helper(span, op, cop, arguments, ctx)?.into()) } else if let Some(mode) = SubgroupGather::map(function.name) { - return Ok(Some( - self.subgroup_gather_helper(span, mode, arguments, ctx)?, - )); + Some(self.subgroup_gather_helper(span, mode, arguments, ctx)?.into()) } else if let Some(fun) = ir::AtomicFunction::map(function.name) { - return self.atomic_helper(span, fun, arguments, is_statement, ctx); + self.atomic_helper(span, fun, arguments, is_statement, ctx)?.map(Into::into) } else { match function.name { "select" => { @@ -2774,25 +2792,26 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let [reject, accept] = values; - ir::Expression::Select { + Some(ir::Expression::Select { reject, accept, condition, } + .into()) } "arrayLength" => { let mut args = ctx.prepare_args(arguments, 1, span); let expr = self.expression(args.next()?, ctx)?; args.finish()?; - ir::Expression::ArrayLength(expr) + Some(ir::Expression::ArrayLength(expr).into()) } "atomicLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); let (pointer, _scalar) = self.atomic_pointer(args.next()?, ctx)?; args.finish()?; - ir::Expression::Load { pointer } + Some(ir::Expression::Load { pointer }.into()) } "atomicStore" => { let mut args = ctx.prepare_args(arguments, 2, span); @@ -2807,7 +2826,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rctx.emitter.start(&rctx.function.expressions); rctx.block .push(ir::Statement::Store { pointer, value }, span); - return Ok(None); + None } "atomicCompareExchangeWeak" => { let mut args = ctx.prepare_args(arguments, 3, span); @@ -2852,7 +2871,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - return Ok(Some(result)); + Some(result.into()) } "textureAtomicMin" | "textureAtomicMax" | "textureAtomicAdd" | "textureAtomicAnd" | "textureAtomicOr" | "textureAtomicXor" => { @@ -2896,7 +2915,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { value, }; rctx.block.push(stmt, span); - return Ok(None); + None } "storageBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; @@ -2904,7 +2923,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::ControlBarrier(ir::Barrier::STORAGE), span); - return Ok(None); + None } "workgroupBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; @@ -2912,7 +2931,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::ControlBarrier(ir::Barrier::WORK_GROUP), span); - return Ok(None); + None } "subgroupBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; @@ -2920,7 +2939,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::ControlBarrier(ir::Barrier::SUB_GROUP), span); - return Ok(None); + None } "textureBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; @@ -2928,7 +2947,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::ControlBarrier(ir::Barrier::TEXTURE), span); - return Ok(None); + None } "workgroupUniformLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -2981,7 +3000,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, ); - return Ok(Some(result)); + Some(result.into()) } "textureStore" => { let mut args = ctx.prepare_args(arguments, 3, span); @@ -3021,7 +3040,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { value, }; rctx.block.push(stmt, span); - return Ok(None); + None } "textureLoad" => { let mut args = ctx.prepare_args(arguments, 2, span); @@ -3055,13 +3074,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; - ir::Expression::ImageLoad { + Some(ir::Expression::ImageLoad { image, coordinate, array_index, level, sample, - } + }.into()) } "textureDimensions" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3073,40 +3092,40 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .transpose()?; args.finish()?; - ir::Expression::ImageQuery { + Some(ir::Expression::ImageQuery { image, query: ir::ImageQuery::Size { level }, - } + }.into()) } "textureNumLevels" => { let mut args = ctx.prepare_args(arguments, 1, span); let image = self.expression(args.next()?, ctx)?; args.finish()?; - ir::Expression::ImageQuery { + Some(ir::Expression::ImageQuery { image, query: ir::ImageQuery::NumLevels, - } + }.into()) } "textureNumLayers" => { let mut args = ctx.prepare_args(arguments, 1, span); let image = self.expression(args.next()?, ctx)?; args.finish()?; - ir::Expression::ImageQuery { + Some(ir::Expression::ImageQuery { image, query: ir::ImageQuery::NumLayers, - } + }.into()) } "textureNumSamples" => { let mut args = ctx.prepare_args(arguments, 1, span); let image = self.expression(args.next()?, ctx)?; args.finish()?; - ir::Expression::ImageQuery { + Some(ir::Expression::ImageQuery { image, query: ir::ImageQuery::NumSamples, - } + }.into()) } "rayQueryInitialize" => { let mut args = ctx.prepare_args(arguments, 3, span); @@ -3127,7 +3146,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rctx.emitter.start(&rctx.function.expressions); rctx.block .push(ir::Statement::RayQuery { query, fun }, span); - return Ok(None); + None } "getCommittedHitVertexPositions" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3136,10 +3155,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let _ = ctx.module.generate_vertex_return_type(); - ir::Expression::RayQueryVertexPositions { + Some(ir::Expression::RayQueryVertexPositions { query, committed: true, - } + }.into()) } "getCandidateHitVertexPositions" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3148,10 +3167,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let _ = ctx.module.generate_vertex_return_type(); - ir::Expression::RayQueryVertexPositions { + Some(ir::Expression::RayQueryVertexPositions { query, committed: false, - } + }.into()) } "rayQueryProceed" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3164,7 +3183,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, span); - return Ok(Some(result)); + Some(result.into()) } "rayQueryGenerateIntersection" => { let mut args = ctx.prepare_args(arguments, 2, span); @@ -3176,7 +3195,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, span); - return Ok(None); + None } "rayQueryConfirmIntersection" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3187,7 +3206,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, span); - return Ok(None); + None } "rayQueryTerminate" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3198,7 +3217,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, span); - return Ok(None); + None } "rayQueryGetCommittedIntersection" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3206,10 +3225,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; let _ = ctx.module.generate_ray_intersection_type(); - ir::Expression::RayQueryGetIntersection { + Some(ir::Expression::RayQueryGetIntersection { query, committed: true, - } + }.into()) } "rayQueryGetCandidateIntersection" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3217,21 +3236,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; let _ = ctx.module.generate_ray_intersection_type(); - ir::Expression::RayQueryGetIntersection { + Some(ir::Expression::RayQueryGetIntersection { query, committed: false, - } + }.into()) } "RayDesc" => { let ty = ctx.module.generate_ray_desc_type(); let handle = self.construct( span, &ast::ConstructorType::Type(ty), - function.span, + span, arguments, ctx, )?; - return Ok(Some(handle)); + Some(handle.into()) } "subgroupBallot" => { let mut args = ctx.prepare_args(arguments, 0, span); @@ -3247,7 +3266,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::SubgroupBallot { result, predicate }, span); - return Ok(Some(result)); + Some(result.into()) } "quadSwapX" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3270,7 +3289,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - return Ok(Some(result)); + Some(result.into()) } "quadSwapY" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3293,7 +3312,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - return Ok(Some(result)); + Some(result.into()) } "quadSwapDiagonal" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3316,7 +3335,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - return Ok(Some(result)); + Some(result.into()) } "coopLoad" | "coopLoadT" => { let row_major = function.name.ends_with("T"); @@ -3348,7 +3367,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; args.finish()?; - crate::Expression::CooperativeLoad { + Some(crate::Expression::CooperativeLoad { columns, rows, role, @@ -3357,7 +3376,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { stride, row_major, }, - } + }.into()) } "coopStore" | "coopStoreT" => { let row_major = function.name.ends_with("T"); @@ -3398,7 +3417,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - return Ok(None); + None } "coopMultiplyAdd" => { let mut args = ctx.prepare_args(arguments, 3, span); @@ -3407,14 +3426,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let c = self.expression(args.next()?, ctx)?; args.finish()?; - ir::Expression::CooperativeMultiplyAdd { a, b, c } + Some(ir::Expression::CooperativeMultiplyAdd { a, b, c }.into()) } _ => return Err(Box::new(Error::UnknownIdent(function.span, function.name))), } }; - let expr = ctx.append_expression(expr, span)?; - Ok(Some(expr)) + match result { + Some(MaybeHandle::Value(expr)) => Ok(Some(ctx.append_expression(expr, span)?)), + Some(MaybeHandle::Handle(handle)) => Ok(Some(handle)), + None => Ok(None), + } } /// Generate Naga IR for call expressions and statements, and type @@ -3444,6 +3466,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { is_statement: bool, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Option>> { + // `span` is for the entire call expression, while `function_span` is just + // for the function identifier. let function_span = function.span; match ctx.globals.get(function.name) { Some(&LoweredGlobalDecl::Type(ty)) => { @@ -3525,8 +3549,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(result) } - - None => self.call_builtin(span, function, arguments, result_ty, is_statement, ctx) + None => self.call_builtin(function, arguments, result_ty, is_statement, ctx) } } From 88b11420c73a9f0c74cbe804465e9532dd6341c9 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 11 Dec 2025 11:04:22 -0800 Subject: [PATCH 3/3] naga: Enforce `@must_use` for built-in functions --- CHANGELOG.md | 1 + cts_runner/test.lst | 29 +- naga/src/front/wgsl/lower/mod.rs | 306 +++++++++++++------- naga/tests/in/wgsl/subgroup-operations.wgsl | 50 ++-- naga/tests/naga/wgsl_errors.rs | 12 +- 5 files changed, 250 insertions(+), 148 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b0aee5d2cf..73d984c1f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,6 +65,7 @@ Bottom level categories: - The validator checks that override-sized arrays have a positive size, if overrides have been resolved. By @andyleiserson in [#8822](https://github.com/gfx-rs/wgpu/pull/8822). - Fix some cases where f16 constants were not working. By @andyleiserson in [#8816](https://github.com/gfx-rs/wgpu/pull/8816). +- Naga now enforces the `@must_use` attribute on WGSL built-in functions, when applicable. You can waive the error with a phony assignment, e.g., `_ = subgroupElect()`. By @andyleiserson in [#8713](https://github.com/gfx-rs/wgpu/pull/8713). #### GLES diff --git a/cts_runner/test.lst b/cts_runner/test.lst index 4028a1ad45..5857ce6833 100644 --- a/cts_runner/test.lst +++ b/cts_runner/test.lst @@ -251,7 +251,24 @@ webgpu:shader,validation,expression,binary,short_circuiting_and_or:array_overrid webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_types:* webgpu:shader,validation,expression,binary,short_circuiting_and_or:scalar_vector:op="%26%26";lhs="bool";rhs="bool" webgpu:shader,validation,expression,call,builtin,all:arguments:test="ptr_deref" -webgpu:shader,validation,expression,call,builtin,max:values:* +webgpu:shader,validation,expression,call,builtin,arrayLength:* +webgpu:shader,validation,expression,call,builtin,barriers:* +webgpu:shader,validation,expression,call,builtin,cos:* +webgpu:shader,validation,expression,call,builtin,floor:* +webgpu:shader,validation,expression,call,builtin,fract:* +webgpu:shader,validation,expression,call,builtin,max:* +webgpu:shader,validation,expression,call,builtin,min:* +webgpu:shader,validation,expression,call,builtin,radians:* +webgpu:shader,validation,expression,call,builtin,sign:* +webgpu:shader,validation,expression,call,builtin,sin:* +webgpu:shader,validation,expression,call,builtin,step:* +webgpu:shader,validation,expression,call,builtin,tan:* +webgpu:shader,validation,expression,call,builtin,tanh:* +webgpu:shader,validation,expression,call,builtin,textureNumLayers:* +webgpu:shader,validation,expression,call,builtin,textureNumLevels:* +webgpu:shader,validation,expression,call,builtin,textureNumSamples:* +webgpu:shader,validation,expression,call,builtin,textureStore:* +webgpu:shader,validation,expression,call,builtin,trunc:* // FAIL: others in `value_constructor` due to https://github.com/gfx-rs/wgpu/issues/4720, possibly more webgpu:shader,validation,expression,call,builtin,value_constructor:array_value:* webgpu:shader,validation,expression,call,builtin,value_constructor:matrix_zero_value:* @@ -260,15 +277,7 @@ webgpu:shader,validation,expression,call,builtin,value_constructor:scalar_zero_v webgpu:shader,validation,expression,call,builtin,value_constructor:struct_value:* webgpu:shader,validation,expression,call,builtin,value_constructor:vector_splat:* webgpu:shader,validation,expression,call,builtin,value_constructor:vector_zero_value:* -// NOTE: This is supposed to be an exhaustive listing underneath -// `webgpu:shader,validation,expression,call,builtin,workgroupUniformLoad:*`, so exceptions can be -// worked around. -//FAIL: https://github.com/gfx-rs/wgpu/pull/8713 -// webgpu:shader,validation,expression,call,builtin,workgroupUniformLoad:must_use:use=false -webgpu:shader,validation,expression,call,builtin,workgroupUniformLoad:must_use:use=true -webgpu:shader,validation,expression,call,builtin,workgroupUniformLoad:no_atomics:* -webgpu:shader,validation,expression,call,builtin,workgroupUniformLoad:only_in_compute:* -webgpu:shader,validation,expression,call,builtin,workgroupUniformLoad:param_constructible_only:* +webgpu:shader,validation,expression,call,builtin,workgroupUniformLoad:* webgpu:shader,validation,extension,dual_source_blending:blend_src_syntax_validation:* webgpu:shader,validation,statement,statement_behavior:invalid_statements:body="break_if" webgpu:shader,validation,statement,statement_behavior:invalid_statements:body="break" diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index ca76301bf5..df9a9869f5 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2680,12 +2680,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { is_statement: bool, ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<'source, Option>> { + const MUST_USE_YES: bool = true; + const MUST_USE_NO: bool = false; + // We report all diagnostics associated with builtins with a span for // just the function identifier, unlike for other kinds of calls, where // we sometimes report with the span for the entire call expression. let span = function.span; - let result: Option> = if let Some(fun) = conv::map_relational_fun(function.name) { + let (result, must_use) = if let Some(fun) = conv::map_relational_fun(function.name) { let mut args = ctx.prepare_args(arguments, 1, span); let argument = self.expression(args.next()?, ctx)?; args.finish()?; @@ -2704,27 +2707,55 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ) }; - if argument_unmodified { + let result = if argument_unmodified { Some(argument.into()) } else { Some(ir::Expression::Relational { fun, argument }.into()) - } + }; + (result, MUST_USE_YES) } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) { let mut args = ctx.prepare_args(arguments, 1, span); let expr = self.expression(args.next()?, ctx)?; args.finish()?; - Some(ir::Expression::Derivative { axis, ctrl, expr }.into()) + ( + Some(ir::Expression::Derivative { axis, ctrl, expr }.into()), + MUST_USE_YES, + ) } else if let Some(fun) = conv::map_standard_fun(function.name) { - Some(self.math_function_helper(span, fun, arguments, ctx)?.into()) + ( + Some(self.math_function_helper(span, fun, arguments, ctx)?.into()), + MUST_USE_YES, + ) } else if let Some(fun) = Texture::map(function.name) { - Some(self.texture_sample_helper(fun, arguments, span, ctx)?.into()) + ( + Some( + self.texture_sample_helper(fun, arguments, span, ctx)? + .into(), + ), + MUST_USE_YES, + ) } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { - Some(self.subgroup_operation_helper(span, op, cop, arguments, ctx)?.into()) + ( + Some( + self.subgroup_operation_helper(span, op, cop, arguments, ctx)? + .into(), + ), + MUST_USE_YES, + ) } else if let Some(mode) = SubgroupGather::map(function.name) { - Some(self.subgroup_gather_helper(span, mode, arguments, ctx)?.into()) + ( + Some( + self.subgroup_gather_helper(span, mode, arguments, ctx)? + .into(), + ), + MUST_USE_YES, + ) } else if let Some(fun) = ir::AtomicFunction::map(function.name) { - self.atomic_helper(span, fun, arguments, is_statement, ctx)?.map(Into::into) + let result = self + .atomic_helper(span, fun, arguments, is_statement, ctx)? + .map(Into::into); + (result, MUST_USE_NO) } else { match function.name { "select" => { @@ -2742,16 +2773,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let diagnostic_details = |ctx: &ExpressionContext<'_, '_, '_>, - ty_res: &proc::TypeResolution, - orig_expr| { + ty_res: &proc::TypeResolution, + orig_expr| { ( ctx.ast_expressions.get_span(orig_expr), format!("`{}`", ctx.as_diagnostic_display(ty_res)), ) }; - for (&value, orig_value) in - values.iter().zip([reject_orig, accept_orig]) - { + for (&value, orig_value) in values.iter().zip([reject_orig, accept_orig]) { let value_ty_res = resolve!(ctx, value); if value_ty_res .inner_with(&ctx.module.types) @@ -2792,32 +2821,36 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let [reject, accept] = values; - Some(ir::Expression::Select { - reject, - accept, - condition, - } - .into()) + ( + Some( + ir::Expression::Select { + reject, + accept, + condition, + } + .into(), + ), + MUST_USE_YES, + ) } "arrayLength" => { let mut args = ctx.prepare_args(arguments, 1, span); let expr = self.expression(args.next()?, ctx)?; args.finish()?; - Some(ir::Expression::ArrayLength(expr).into()) + (Some(ir::Expression::ArrayLength(expr).into()), MUST_USE_YES) } "atomicLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); let (pointer, _scalar) = self.atomic_pointer(args.next()?, ctx)?; args.finish()?; - Some(ir::Expression::Load { pointer }.into()) + (Some(ir::Expression::Load { pointer }.into()), MUST_USE_NO) } "atomicStore" => { let mut args = ctx.prepare_args(arguments, 2, span); let (pointer, scalar) = self.atomic_pointer(args.next()?, ctx)?; - let value = - self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; + let value = self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; args.finish()?; let rctx = ctx.runtime_expression_ctx(span)?; @@ -2826,15 +2859,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rctx.emitter.start(&rctx.function.expressions); rctx.block .push(ir::Statement::Store { pointer, value }, span); - None + (None, MUST_USE_NO) } "atomicCompareExchangeWeak" => { let mut args = ctx.prepare_args(arguments, 3, span); let (pointer, scalar) = self.atomic_pointer(args.next()?, ctx)?; - let compare = - self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; + let compare = self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; let value = args.next()?; let value_span = ctx.ast_expressions.get_span(value); @@ -2845,17 +2877,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expression = match *resolve_inner!(ctx, value) { ir::TypeInner::Scalar(scalar) => ir::Expression::AtomicResult { ty: ctx.module.generate_predeclared_type( - ir::PredeclaredType::AtomicCompareExchangeWeakResult( - scalar, - ), + ir::PredeclaredType::AtomicCompareExchangeWeakResult(scalar), ), comparison: true, }, - _ => { - return Err(Box::new(Error::InvalidAtomicOperandType( - value_span, - ))) - } + _ => return Err(Box::new(Error::InvalidAtomicOperandType(value_span))), }; let result = ctx.interrupt_emitter(expression, span)?; @@ -2871,7 +2897,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - Some(result.into()) + (Some(result.into()), MUST_USE_NO) } "textureAtomicMin" | "textureAtomicMax" | "textureAtomicAdd" | "textureAtomicAnd" | "textureAtomicOr" | "textureAtomicXor" => { @@ -2915,7 +2941,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { value, }; rctx.block.push(stmt, span); - None + (None, MUST_USE_NO) } "storageBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; @@ -2923,7 +2949,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::ControlBarrier(ir::Barrier::STORAGE), span); - None + (None, MUST_USE_NO) } "workgroupBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; @@ -2931,7 +2957,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::ControlBarrier(ir::Barrier::WORK_GROUP), span); - None + (None, MUST_USE_NO) } "subgroupBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; @@ -2939,7 +2965,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::ControlBarrier(ir::Barrier::SUB_GROUP), span); - None + (None, MUST_USE_NO) } "textureBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; @@ -2947,7 +2973,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::ControlBarrier(ir::Barrier::TEXTURE), span); - None + (None, MUST_USE_NO) } "workgroupUniformLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3000,7 +3026,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, ); - Some(result.into()) + (Some(result.into()), MUST_USE_YES) } "textureStore" => { let mut args = ctx.prepare_args(arguments, 3, span); @@ -3024,8 +3050,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Err(Box::new(Error::NotStorageTexture(image_span))); }; - let value = - self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; + let value = self.expression_with_leaf_scalar(args.next()?, scalar, ctx)?; args.finish()?; @@ -3040,7 +3065,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { value, }; rctx.block.push(stmt, span); - None + (None, MUST_USE_NO) } "textureLoad" => { let mut args = ctx.prepare_args(arguments, 2, span); @@ -3074,13 +3099,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; - Some(ir::Expression::ImageLoad { - image, - coordinate, - array_index, - level, - sample, - }.into()) + ( + Some( + ir::Expression::ImageLoad { + image, + coordinate, + array_index, + level, + sample, + } + .into(), + ), + MUST_USE_YES, + ) } "textureDimensions" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3092,40 +3123,64 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .transpose()?; args.finish()?; - Some(ir::Expression::ImageQuery { - image, - query: ir::ImageQuery::Size { level }, - }.into()) + ( + Some( + ir::Expression::ImageQuery { + image, + query: ir::ImageQuery::Size { level }, + } + .into(), + ), + MUST_USE_YES, + ) } "textureNumLevels" => { let mut args = ctx.prepare_args(arguments, 1, span); let image = self.expression(args.next()?, ctx)?; args.finish()?; - Some(ir::Expression::ImageQuery { - image, - query: ir::ImageQuery::NumLevels, - }.into()) + ( + Some( + ir::Expression::ImageQuery { + image, + query: ir::ImageQuery::NumLevels, + } + .into(), + ), + MUST_USE_YES, + ) } "textureNumLayers" => { let mut args = ctx.prepare_args(arguments, 1, span); let image = self.expression(args.next()?, ctx)?; args.finish()?; - Some(ir::Expression::ImageQuery { - image, - query: ir::ImageQuery::NumLayers, - }.into()) + ( + Some( + ir::Expression::ImageQuery { + image, + query: ir::ImageQuery::NumLayers, + } + .into(), + ), + MUST_USE_YES, + ) } "textureNumSamples" => { let mut args = ctx.prepare_args(arguments, 1, span); let image = self.expression(args.next()?, ctx)?; args.finish()?; - Some(ir::Expression::ImageQuery { - image, - query: ir::ImageQuery::NumSamples, - }.into()) + ( + Some( + ir::Expression::ImageQuery { + image, + query: ir::ImageQuery::NumSamples, + } + .into(), + ), + MUST_USE_YES, + ) } "rayQueryInitialize" => { let mut args = ctx.prepare_args(arguments, 3, span); @@ -3146,7 +3201,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rctx.emitter.start(&rctx.function.expressions); rctx.block .push(ir::Statement::RayQuery { query, fun }, span); - None + (None, MUST_USE_NO) } "getCommittedHitVertexPositions" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3155,10 +3210,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let _ = ctx.module.generate_vertex_return_type(); - Some(ir::Expression::RayQueryVertexPositions { - query, - committed: true, - }.into()) + ( + Some( + ir::Expression::RayQueryVertexPositions { + query, + committed: true, + } + .into(), + ), + MUST_USE_NO, + ) } "getCandidateHitVertexPositions" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3167,10 +3228,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let _ = ctx.module.generate_vertex_return_type(); - Some(ir::Expression::RayQueryVertexPositions { - query, - committed: false, - }.into()) + ( + Some( + ir::Expression::RayQueryVertexPositions { + query, + committed: false, + } + .into(), + ), + MUST_USE_NO, + ) } "rayQueryProceed" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3183,7 +3250,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, span); - Some(result.into()) + (Some(result.into()), MUST_USE_NO) } "rayQueryGenerateIntersection" => { let mut args = ctx.prepare_args(arguments, 2, span); @@ -3195,7 +3262,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, span); - None + (None, MUST_USE_NO) } "rayQueryConfirmIntersection" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3206,7 +3273,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, span); - None + (None, MUST_USE_NO) } "rayQueryTerminate" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3217,7 +3284,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::RayQuery { query, fun }, span); - None + (None, MUST_USE_NO) } "rayQueryGetCommittedIntersection" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3225,10 +3292,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; let _ = ctx.module.generate_ray_intersection_type(); - Some(ir::Expression::RayQueryGetIntersection { - query, - committed: true, - }.into()) + ( + Some( + ir::Expression::RayQueryGetIntersection { + query, + committed: true, + } + .into(), + ), + MUST_USE_NO, + ) } "rayQueryGetCandidateIntersection" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3236,10 +3309,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; let _ = ctx.module.generate_ray_intersection_type(); - Some(ir::Expression::RayQueryGetIntersection { - query, - committed: false, - }.into()) + ( + Some( + ir::Expression::RayQueryGetIntersection { + query, + committed: false, + } + .into(), + ), + MUST_USE_NO, + ) } "RayDesc" => { let ty = ctx.module.generate_ray_desc_type(); @@ -3250,7 +3329,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { arguments, ctx, )?; - Some(handle.into()) + (Some(handle.into()), MUST_USE_NO) } "subgroupBallot" => { let mut args = ctx.prepare_args(arguments, 0, span); @@ -3266,7 +3345,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(ir::Statement::SubgroupBallot { result, predicate }, span); - Some(result.into()) + (Some(result.into()), MUST_USE_YES) } "quadSwapX" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3289,7 +3368,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - Some(result.into()) + (Some(result.into()), MUST_USE_YES) } "quadSwapY" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3312,7 +3391,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - Some(result.into()) + (Some(result.into()), MUST_USE_YES) } "quadSwapDiagonal" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3335,7 +3414,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - Some(result.into()) + (Some(result.into()), MUST_USE_YES) } "coopLoad" | "coopLoadT" => { let row_major = function.name.ends_with("T"); @@ -3367,16 +3446,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; args.finish()?; - Some(crate::Expression::CooperativeLoad { - columns, - rows, - role, - data: crate::CooperativeData { - pointer, - stride, - row_major, - }, - }.into()) + ( + Some( + crate::Expression::CooperativeLoad { + columns, + rows, + role, + data: crate::CooperativeData { + pointer, + stride, + row_major, + }, + } + .into(), + ), + MUST_USE_YES, + ) } "coopStore" | "coopStoreT" => { let row_major = function.name.ends_with("T"); @@ -3417,7 +3502,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - None + (None, MUST_USE_NO) } "coopMultiplyAdd" => { let mut args = ctx.prepare_args(arguments, 3, span); @@ -3426,12 +3511,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let c = self.expression(args.next()?, ctx)?; args.finish()?; - Some(ir::Expression::CooperativeMultiplyAdd { a, b, c }.into()) + ( + Some(ir::Expression::CooperativeMultiplyAdd { a, b, c }.into()), + MUST_USE_YES, + ) } _ => return Err(Box::new(Error::UnknownIdent(function.span, function.name))), } }; + if must_use && is_statement { + return Err(Box::new(Error::FunctionMustUseUnused(function.span))); + } + match result { Some(MaybeHandle::Value(expr)) => Ok(Some(ctx.append_expression(expr, span)?)), Some(MaybeHandle::Handle(handle)) => Ok(Some(handle)), @@ -3549,7 +3641,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(result) } - None => self.call_builtin(function, arguments, result_ty, is_statement, ctx) + None => self.call_builtin(function, arguments, result_ty, is_statement, ctx), } } diff --git a/naga/tests/in/wgsl/subgroup-operations.wgsl b/naga/tests/in/wgsl/subgroup-operations.wgsl index 26b3d98e84..fae0197f5f 100644 --- a/naga/tests/in/wgsl/subgroup-operations.wgsl +++ b/naga/tests/in/wgsl/subgroup-operations.wgsl @@ -11,32 +11,32 @@ fn main( ) { subgroupBarrier(); - subgroupBallot((subgroup_invocation_id & 1u) == 1u); - subgroupBallot(); + _ = subgroupBallot((subgroup_invocation_id & 1u) == 1u); + _ = subgroupBallot(); - subgroupAll(subgroup_invocation_id != 0u); - subgroupAny(subgroup_invocation_id == 0u); - subgroupAdd(subgroup_invocation_id); - subgroupMul(subgroup_invocation_id); - subgroupMin(subgroup_invocation_id); - subgroupMax(subgroup_invocation_id); - subgroupAnd(subgroup_invocation_id); - subgroupOr(subgroup_invocation_id); - subgroupXor(subgroup_invocation_id); - subgroupExclusiveAdd(subgroup_invocation_id); - subgroupExclusiveMul(subgroup_invocation_id); - subgroupInclusiveAdd(subgroup_invocation_id); - subgroupInclusiveMul(subgroup_invocation_id); + _ = subgroupAll(subgroup_invocation_id != 0u); + _ = subgroupAny(subgroup_invocation_id == 0u); + _ = subgroupAdd(subgroup_invocation_id); + _ = subgroupMul(subgroup_invocation_id); + _ = subgroupMin(subgroup_invocation_id); + _ = subgroupMax(subgroup_invocation_id); + _ = subgroupAnd(subgroup_invocation_id); + _ = subgroupOr(subgroup_invocation_id); + _ = subgroupXor(subgroup_invocation_id); + _ = subgroupExclusiveAdd(subgroup_invocation_id); + _ = subgroupExclusiveMul(subgroup_invocation_id); + _ = subgroupInclusiveAdd(subgroup_invocation_id); + _ = subgroupInclusiveMul(subgroup_invocation_id); - subgroupBroadcastFirst(subgroup_invocation_id); - subgroupBroadcast(subgroup_invocation_id, 4u); - subgroupShuffle(subgroup_invocation_id, sizes.subgroup_size - 1u - subgroup_invocation_id); - subgroupShuffleDown(subgroup_invocation_id, 1u); - subgroupShuffleUp(subgroup_invocation_id, 1u); - subgroupShuffleXor(subgroup_invocation_id, sizes.subgroup_size - 1u); + _ = subgroupBroadcastFirst(subgroup_invocation_id); + _ = subgroupBroadcast(subgroup_invocation_id, 4u); + _ = subgroupShuffle(subgroup_invocation_id, sizes.subgroup_size - 1u - subgroup_invocation_id); + _ = subgroupShuffleDown(subgroup_invocation_id, 1u); + _ = subgroupShuffleUp(subgroup_invocation_id, 1u); + _ = subgroupShuffleXor(subgroup_invocation_id, sizes.subgroup_size - 1u); - quadBroadcast(subgroup_invocation_id, 4u); - quadSwapX(subgroup_invocation_id); - quadSwapY(subgroup_invocation_id); - quadSwapDiagonal(subgroup_invocation_id); + _ = quadBroadcast(subgroup_invocation_id, 4u); + _ = quadSwapX(subgroup_invocation_id); + _ = quadSwapY(subgroup_invocation_id); + _ = quadSwapDiagonal(subgroup_invocation_id); } diff --git a/naga/tests/naga/wgsl_errors.rs b/naga/tests/naga/wgsl_errors.rs index 96c4da2cdb..5a9a2e99fb 100644 --- a/naga/tests/naga/wgsl_errors.rs +++ b/naga/tests/naga/wgsl_errors.rs @@ -3923,7 +3923,7 @@ fn subgroup_capability() { &format!(" {stage_attr} fn main() {{ - subgroupBallot(); + _ = subgroupBallot(); }} "), Err(naga::valid::ValidationError::EntryPoint { @@ -3948,7 +3948,7 @@ fn subgroup_capability() { " {stage_attr} fn main() {{ - subgroupBallot(); + _ = subgroupBallot(); }} " ), @@ -3965,7 +3965,7 @@ fn subgroup_capability() { " @vertex fn main() -> @builtin(position) vec4 {{ - subgroupBallot(); + _ = subgroupBallot(); return vec4(); }} ": @@ -3977,7 +3977,7 @@ fn subgroup_capability() { " @vertex fn main() -> @builtin(position) vec4 {{ - subgroupBallot(); + _ = subgroupBallot(); return vec4(); }} ", @@ -4074,7 +4074,7 @@ fn subgroup_invalid_broadcast() { check_validation! { r#" fn main(id: u32) { - subgroupBroadcast(123, id); + _ = subgroupBroadcast(123, id); } "#: Err(naga::valid::ValidationError::Function { @@ -4088,7 +4088,7 @@ fn subgroup_invalid_broadcast() { check_validation! { r#" fn main(id: u32) { - quadBroadcast(123, id); + _ = quadBroadcast(123, id); } "#: Err(naga::valid::ValidationError::Function {