diff --git a/cts_runner/test.lst b/cts_runner/test.lst index e9e338eb7b5..40174b75fec 100644 --- a/cts_runner/test.lst +++ b/cts_runner/test.lst @@ -209,6 +209,7 @@ webgpu:shader,execution,flow_control,return:* // Fails on Metal in CI only, not when running locally. fails-if(metal) webgpu:shader,execution,robust_access_vertex:vertex_buffer_access:indexed=true;indirect=false;drawCallTestParameter="baseVertex";type="float32x4";additionalBuffers=4;partialLastNumber=false;offsetVertexBuffer=true webgpu:shader,validation,const_assert,const_assert:* +webgpu:shader,validation,expression,access,array:early_eval_errors:case="override_in_bounds" webgpu:shader,validation,expression,binary,short_circuiting_and_or:array_override:op="%26%26";a_val=1;b_val=1 webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_types:* webgpu:shader,validation,expression,binary,short_circuiting_and_or:scalar_vector:op="%26%26";lhs="bool";rhs="bool" diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 045bb9fa15b..fdcc0b4d4db 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -280,7 +280,7 @@ fn process_workgroup_size_override( Some(h) => { ep.workgroup_size[i] = module .to_ctx() - .eval_expr_to_u32(adjusted_global_expressions[h]) + .get_const_val(adjusted_global_expressions[h]) .map(|n| { if n == 0 { Err(PipelineConstantError::NegativeWorkgroupSize) @@ -309,13 +309,13 @@ fn process_mesh_shader_overrides( if let Some(r#override) = mesh_info.max_vertices_override { mesh_info.max_vertices = module .to_ctx() - .eval_expr_to_u32(adjusted_global_expressions[r#override]) + .get_const_val(adjusted_global_expressions[r#override]) .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?; } if let Some(r#override) = mesh_info.max_primitives_override { mesh_info.max_primitives = module .to_ctx() - .eval_expr_to_u32(adjusted_global_expressions[r#override]) + .get_const_val(adjusted_global_expressions[r#override]) .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?; } } diff --git a/naga/src/front/glsl/context.rs b/naga/src/front/glsl/context.rs index 94f09ba022d..8cb4ce53223 100644 --- a/naga/src/front/glsl/context.rs +++ b/naga/src/front/glsl/context.rs @@ -561,7 +561,7 @@ impl<'a> Context<'a> { _ => self .module .to_ctx() - .eval_expr_to_u32_from(index, &self.expressions) + .get_const_val_from(index, &self.expressions) .ok(), }; diff --git a/naga/src/front/glsl/parser.rs b/naga/src/front/glsl/parser.rs index 79001d8700f..959a9b8e60d 100644 --- a/naga/src/front/glsl/parser.rs +++ b/naga/src/front/glsl/parser.rs @@ -14,7 +14,7 @@ use super::{ variables::{GlobalOrConstant, VarDeclaration}, Frontend, Result, }; -use crate::{arena::Handle, proc::U32EvalError, Expression, Module, Span, Type}; +use crate::{arena::Handle, proc::ConstValueError, Expression, Module, Span, Type}; mod declarations; mod expressions; @@ -211,15 +211,15 @@ impl<'source> ParsingContext<'source> { ctx.global_expression_kind_tracker, )?; - let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr); + let res = ctx.module.to_ctx().get_const_val(const_expr); let int = match res { Ok(value) => Ok(value), - Err(U32EvalError::Negative) => Err(Error { + Err(ConstValueError::Negative) => Err(Error { kind: ErrorKind::SemanticError("int constant overflows".into()), meta, }), - Err(U32EvalError::NonConst) => Err(Error { + Err(ConstValueError::NonConst | ConstValueError::InvalidType) => Err(Error { kind: ErrorKind::SemanticError("Expected a uint constant".into()), meta, }), diff --git a/naga/src/front/spv/next_block.rs b/naga/src/front/spv/next_block.rs index 2a995090149..d3bc5db686d 100644 --- a/naga/src/front/spv/next_block.rs +++ b/naga/src/front/spv/next_block.rs @@ -277,7 +277,7 @@ impl> Frontend { let index_maybe = match *index_expr_data { crate::Expression::Constant(const_handle) => Some( ctx.gctx() - .eval_expr_to_u32(ctx.module.constants[const_handle].init) + .get_const_val(ctx.module.constants[const_handle].init) .map_err(|_| { Error::InvalidAccess(crate::Expression::Constant( const_handle, diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 14525c76fb7..c1a0293a1fd 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -535,50 +535,28 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .map_err(|e| Box::new(Error::ConstantEvaluatorError(e.into(), span))) } - fn const_eval_expr_to_u32( + fn get_const_val>( &self, handle: Handle, - ) -> core::result::Result { + ) -> core::result::Result { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => { if !ctx.local_expression_kind_tracker.is_const(handle) { - return Err(proc::U32EvalError::NonConst); + return Err(proc::ConstValueError::NonConst); } self.module .to_ctx() - .eval_expr_to_u32_from(handle, &ctx.function.expressions) + .get_const_val_from(handle, &ctx.function.expressions) } ExpressionContextType::Constant(Some(ref ctx)) => { assert!(ctx.local_expression_kind_tracker.is_const(handle)); self.module .to_ctx() - .eval_expr_to_u32_from(handle, &ctx.function.expressions) + .get_const_val_from(handle, &ctx.function.expressions) } - ExpressionContextType::Constant(None) => self.module.to_ctx().eval_expr_to_u32(handle), - ExpressionContextType::Override => Err(proc::U32EvalError::NonConst), - } - } - - fn const_eval_expr_to_bool(&self, handle: Handle) -> Option { - match self.expr_type { - ExpressionContextType::Runtime(ref ctx) => { - if !ctx.local_expression_kind_tracker.is_const(handle) { - return None; - } - - self.module - .to_ctx() - .eval_expr_to_bool_from(handle, &ctx.function.expressions) - } - ExpressionContextType::Constant(Some(ref ctx)) => { - assert!(ctx.local_expression_kind_tracker.is_const(handle)); - self.module - .to_ctx() - .eval_expr_to_bool_from(handle, &ctx.function.expressions) - } - ExpressionContextType::Constant(None) => self.module.to_ctx().eval_expr_to_bool(handle), - ExpressionContextType::Override => None, + ExpressionContextType::Constant(None) => self.module.to_ctx().get_const_val(handle), + ExpressionContextType::Override => Err(proc::ConstValueError::NonConst), } } @@ -716,12 +694,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { let index = self .module .to_ctx() - .eval_expr_to_u32_from(expr, &rctx.function.expressions) + .get_const_val_from::(expr, &rctx.function.expressions) .map_err(|err| match err { - proc::U32EvalError::NonConst => { + proc::ConstValueError::NonConst | proc::ConstValueError::InvalidType => { Error::ExpectedConstExprConcreteIntegerScalar(component_span) } - proc::U32EvalError::Negative => Error::ExpectedNonNegative(component_span), + proc::ConstValueError::Negative => { + Error::ExpectedNonNegative(component_span) + } })?; ir::SwizzleComponent::XYZW .get(index as usize) @@ -1417,11 +1397,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { match ctx .module .to_ctx() - .eval_expr_to_bool_from(condition, &ctx.module.global_expressions) + .get_const_val_from(condition, &ctx.module.global_expressions) { - Some(true) => Ok(()), - Some(false) => Err(Error::ConstAssertFailed(span)), - _ => Err(Error::NotBool(span)), + Ok(true) => Ok(()), + Ok(false) => Err(Error::ConstAssertFailed(span)), + Err(proc::ConstValueError::NonConst | proc::ConstValueError::Negative) => { + unreachable!() + } + Err(proc::ConstValueError::InvalidType) => Err(Error::NotBool(span)), }?; } } @@ -1940,14 +1923,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { match ctx .module .to_ctx() - .eval_expr_to_literal_from(expr, &ctx.function.expressions) + .get_const_val_from(expr, &ctx.function.expressions) { - Some(ir::Literal::I32(value)) => { - ir::SwitchValue::I32(value) - } - Some(ir::Literal::U32(value)) => { - ir::SwitchValue::U32(value) - } + Ok(ir::Literal::I32(value)) => ir::SwitchValue::I32(value), + Ok(ir::Literal::U32(value)) => ir::SwitchValue::U32(value), _ => { return Err(Box::new(Error::InvalidSwitchCase { span, @@ -2168,11 +2147,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { match ctx .module .to_ctx() - .eval_expr_to_bool_from(condition, &ctx.function.expressions) + .get_const_val_from(condition, &ctx.function.expressions) { - Some(true) => Ok(()), - Some(false) => Err(Error::ConstAssertFailed(span)), - _ => Err(Error::NotBool(span)), + Ok(true) => Ok(()), + Ok(false) => Err(Error::ConstAssertFailed(span)), + Err(proc::ConstValueError::NonConst | proc::ConstValueError::Negative) => { + unreachable!() + } + Err(proc::ConstValueError::InvalidType) => Err(Error::NotBool(span)), }?; block.extend(emitter.finish(&ctx.function.expressions)); @@ -2370,7 +2352,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } - lowered_base.try_map(|base| match ctx.const_eval_expr_to_u32(index).ok() { + lowered_base.try_map(|base| match ctx.get_const_val(index).ok() { Some(index) => Ok::<_, Box>(ir::Expression::AccessIndex { base, index }), None => { // When an abstract array value e is indexed by an expression @@ -2565,7 +2547,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { result_var, ))) } else { - let left_val = ctx.const_eval_expr_to_bool(left); + let left_val: Option = ctx.get_const_val(left).ok(); if left_val.is_some_and(|left_val| { op == crate::BinaryOperator::LogicalAnd && !left_val @@ -4201,10 +4183,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let value = ctx .module .to_ctx() - .eval_expr_to_u32(expr) + .get_const_val(expr) .map_err(|err| match err { - proc::U32EvalError::NonConst => Error::ExpectedConstExprConcreteIntegerScalar(span), - proc::U32EvalError::Negative => Error::ExpectedNonNegative(span), + proc::ConstValueError::NonConst | proc::ConstValueError::InvalidType => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } + proc::ConstValueError::Negative => Error::ExpectedNonNegative(span), })?; Ok((value, span)) } @@ -4220,12 +4204,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let const_expr = self.expression(expr, &mut ctx.as_const()); match const_expr { Ok(value) => { - let len = ctx.const_eval_expr_to_u32(value).map_err(|err| { + let len = ctx.get_const_val(value).map_err(|err| { Box::new(match err { - proc::U32EvalError::NonConst => { + proc::ConstValueError::NonConst + | proc::ConstValueError::InvalidType => { Error::ExpectedConstExprConcreteIntegerScalar(span) } - proc::U32EvalError::Negative => { + proc::ConstValueError::Negative => { Error::ExpectedPositiveArrayLength(span) } }) diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 4c471179add..c9ff886542b 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -1268,7 +1268,11 @@ impl<'a> ConstantEvaluator<'a> { let base = self.check_and_get(base)?; let index = self.check_and_get(index)?; - self.access(base, self.constant_index(index)?, span) + let index_val: u32 = self + .to_ctx() + .get_const_val_from(index, self.expressions) + .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?; + self.access(base, index_val as usize, span) } Expression::Swizzle { size, @@ -2116,24 +2120,6 @@ impl<'a> ConstantEvaluator<'a> { } } - fn constant_index(&self, expr: Handle) -> Result { - match self.expressions[expr] { - Expression::ZeroValue(ty) - if matches!( - self.types[ty].inner, - TypeInner::Scalar(crate::Scalar { - kind: ScalarKind::Uint, - .. - }) - ) => - { - Ok(0) - } - Expression::Literal(Literal::U32(index)) => Ok(index as usize), - _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy), - } - } - /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions. /// /// [`ZeroValue`]: Expression::ZeroValue diff --git a/naga/src/proc/index.rs b/naga/src/proc/index.rs index 87eaac7775c..991b8ca03d7 100644 --- a/naga/src/proc/index.rs +++ b/naga/src/proc/index.rs @@ -483,7 +483,7 @@ impl GuardedIndex { expressions: &crate::Arena, module: &crate::Module, ) -> Self { - match module.to_ctx().eval_expr_to_u32_from(expr, expressions) { + match module.to_ctx().get_const_val_from(expr, expressions) { Ok(value) => Self::Known(value), Err(_) => Self::Expression(expr), } diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index bc1e72e1132..482914b5290 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -172,6 +172,29 @@ impl crate::Literal { } } +impl TryFrom for u32 { + type Error = ConstValueError; + + fn try_from(value: crate::Literal) -> Result { + match value { + crate::Literal::U32(value) => Ok(value), + crate::Literal::I32(value) => value.try_into().map_err(|_| ConstValueError::Negative), + _ => Err(ConstValueError::InvalidType), + } + } +} + +impl TryFrom for bool { + type Error = ConstValueError; + + fn try_from(value: crate::Literal) -> Result { + match value { + crate::Literal::Bool(value) => Ok(value), + _ => Err(ConstValueError::InvalidType), + } + } +} + impl super::AddressSpace { pub fn access(self) -> crate::StorageAccess { use crate::StorageAccess as Sa; @@ -425,9 +448,16 @@ impl crate::Module { } #[derive(Debug)] -pub(super) enum U32EvalError { +pub enum ConstValueError { NonConst, Negative, + InvalidType, +} + +impl From for ConstValueError { + fn from(_: core::convert::Infallible) -> Self { + unreachable!() + } } #[derive(Clone, Copy)] @@ -452,63 +482,26 @@ impl GlobalCtx<'_> { )), allow(dead_code) )] - pub(super) fn eval_expr_to_u32( - &self, - handle: crate::Handle, - ) -> Result { - self.eval_expr_to_u32_from(handle, self.global_expressions) - } - - /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`. - pub(super) fn eval_expr_to_u32_from( - &self, - handle: crate::Handle, - arena: &crate::Arena, - ) -> Result { - match self.eval_expr_to_literal_from(handle, arena) { - Some(crate::Literal::U32(value)) => Ok(value), - Some(crate::Literal::I32(value)) => { - value.try_into().map_err(|_| U32EvalError::Negative) - } - _ => Err(U32EvalError::NonConst), - } - } - - /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `bool`. - #[cfg_attr(not(feature = "wgsl-in"), allow(dead_code))] - pub(super) fn eval_expr_to_bool( - &self, - handle: crate::Handle, - ) -> Option { - self.eval_expr_to_bool_from(handle, self.global_expressions) - } - - /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `bool`. - #[cfg_attr(not(feature = "wgsl-in"), allow(dead_code))] - pub(super) fn eval_expr_to_bool_from( + pub(super) fn get_const_val( &self, handle: crate::Handle, - arena: &crate::Arena, - ) -> Option { - match self.eval_expr_to_literal_from(handle, arena) { - Some(crate::Literal::Bool(value)) => Some(value), - _ => None, - } + ) -> Result + where + T: TryFrom, + E: Into, + { + self.get_const_val_from(handle, self.global_expressions) } - #[expect(dead_code)] - pub(crate) fn eval_expr_to_literal( - &self, - handle: crate::Handle, - ) -> Option { - self.eval_expr_to_literal_from(handle, self.global_expressions) - } - - pub(super) fn eval_expr_to_literal_from( + pub(super) fn get_const_val_from( &self, handle: crate::Handle, arena: &crate::Arena, - ) -> Option { + ) -> Result + where + T: TryFrom, + E: Into, + { fn get( gctx: GlobalCtx, handle: crate::Handle, @@ -523,11 +516,15 @@ impl GlobalCtx<'_> { _ => None, } } - match arena[handle] { + let value = match arena[handle] { crate::Expression::Constant(c) => { get(*self, self.constants[c].init, self.global_expressions) } _ => get(*self, handle, arena), + }; + match value { + Some(v) => v.try_into().map_err(Into::into), + None => Err(ConstValueError::NonConst), } } @@ -561,9 +558,11 @@ impl crate::ArraySize { let Some(expr) = gctx.overrides[handle].init else { return Err(ResolveArraySizeError::NonConstArrayLength); }; - let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err { - U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength, - U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength, + let length = gctx.get_const_val(expr).map_err(|err| match err { + ConstValueError::NonConst => ResolveArraySizeError::NonConstArrayLength, + ConstValueError::Negative | ConstValueError::InvalidType => { + ResolveArraySizeError::ExpectedPositiveArrayLength + } })?; if length == 0 { diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 13ebb8461b7..7f2d4bb0dc9 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -287,7 +287,7 @@ impl super::Validator { // If index is const we can do check for non-negative index match module .to_ctx() - .eval_expr_to_u32_from(index, &function.expressions) + .get_const_val_from(index, &function.expressions) { Ok(value) => { let length = if self.overrides_resolved { @@ -303,10 +303,13 @@ impl super::Validator { } } } - Err(crate::proc::U32EvalError::Negative) => { + Err(crate::proc::ConstValueError::Negative) => { return Err(ExpressionError::NegativeIndex(base)) } - Err(crate::proc::U32EvalError::NonConst) => {} + Err(crate::proc::ConstValueError::NonConst) => {} + Err(crate::proc::ConstValueError::InvalidType) => { + return Err(ExpressionError::InvalidIndexType(index)) + } } ShaderStages::all()