diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 10fd5d72aa..648affb221 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2079,8 +2079,47 @@ impl<'w> BlockContext<'w> { value_id, ) } - crate::AtomicFunction::Exchange { compare: Some(_) } => { - return Err(Error::FeatureNotImplemented("atomic CompareExchange")); + crate::AtomicFunction::Exchange { compare: Some(cmp) } => { + // TODO: look this up from the atomic expression's scalar type so that it works with i32 as well + let scalar_u32 = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: None, + })); + let bool_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + pointer_space: None, + })); + + let cas_result_id = self.gen_id(); + let equality_result_id = self.gen_id(); + let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange); + cas_instr.set_type(scalar_u32); + cas_instr.set_result(cas_result_id); + cas_instr.add_operand(pointer_id); + cas_instr.add_operand(scope_constant_id); + cas_instr.add_operand(semantics_id); // semantics if equal + cas_instr.add_operand(semantics_id); // semantics if not equal + cas_instr.add_operand(value_id); + cas_instr.add_operand(self.cached[cmp]); + block.body.push(cas_instr); + block.body.push(Instruction::binary( + spirv::Op::IEqual, + bool_type_id, + equality_result_id, + cas_result_id, + self.cached[cmp], + )); + Instruction::composite_construct( + result_type_id, + id, + &[cas_result_id, equality_result_id], + ) } }; diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 2873e6c73c..91feb9dfd6 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -1628,7 +1628,7 @@ impl Parser { crate::TypeInner::Scalar { kind, width } => crate::Expression::AtomicResult { kind, width, - comparison: false, + comparison: None, }, _ => return Err(Error::InvalidAtomicOperandType(value_span)), }; @@ -1857,10 +1857,50 @@ impl Parser { let expression = match *ctx.resolve_type(value)? { crate::TypeInner::Scalar { kind, width } => { + let bool_ty = ctx.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + }, + }, + NagaSpan::UNDEFINED, + ); + let scalar_ty = ctx.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind, width }, + }, + NagaSpan::UNDEFINED, + ); + let struct_ty = ctx.types.insert( + crate::Type { + name: Some("__atomic_compare_exchange_result".to_string()), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("old_value".to_string()), + ty: scalar_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("exchanged".to_string()), + ty: bool_ty, + binding: None, + offset: 4, + }, + ], + span: 8, + }, + }, + NagaSpan::UNDEFINED, + ); crate::Expression::AtomicResult { kind, width, - comparison: true, + comparison: Some(struct_ty), } } _ => return Err(Error::InvalidAtomicOperandType(value_span)), diff --git a/src/lib.rs b/src/lib.rs index e122d1224c..3c61fb2e69 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1404,7 +1404,7 @@ pub enum Expression { AtomicResult { kind: ScalarKind, width: Bytes, - comparison: bool, + comparison: Option>, }, /// Get the length of an array. /// The expression must resolve to a pointer to an array with a dynamic size. diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 9a5922ea76..7669708e8a 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -649,12 +649,8 @@ impl<'a> ResolveContext<'a> { width, comparison, } => { - if comparison { - TypeResolution::Value(Ti::Vector { - size: crate::VectorSize::Bi, - kind, - width, - }) + if let Some(struct_ty) = comparison { + TypeResolution::Handle(struct_ty) } else { TypeResolution::Value(Ti::Scalar { kind, width }) } diff --git a/src/valid/function.rs b/src/valid/function.rs index 5684f670fe..ee02315e7a 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -363,11 +363,17 @@ impl super::Validator { .into_other()); } match context.expressions[result] { - //TODO: support atomic result with comparison + //TODO: does the result of an atomicCompareExchange need additional validation, or does the existing validation for + // the struct type it returns suffice? crate::Expression::AtomicResult { kind, width, - comparison: false, + comparison: Some(_), + } if kind == ptr_kind && width == ptr_width => {} + crate::Expression::AtomicResult { + kind, + width, + comparison: None, } if kind == ptr_kind && width == ptr_width => {} _ => { return Err(AtomicError::ResultTypeMismatch(result)