Skip to content

Commit

Permalink
[naga wgsl] Let unary operators accept and produce abstract types.
Browse files Browse the repository at this point in the history
Fixes #4445.
Fixes #4492.
Fixes #4435.
  • Loading branch information
jimblandy authored and teoxoy committed Dec 14, 2023
1 parent c4b4387 commit d9d051b
Show file tree
Hide file tree
Showing 11 changed files with 259 additions and 188 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ This feature allowed you to call `global_id` on any wgpu opaque handle to get a

#### Naga

- Naga's WGSL front end now allows binary operators to produce values with abstract types, rather than concretizing thir operands. By @jimblandy in [#4850](https://github.com/gfx-rs/wgpu/pull/4850).
- Naga's WGSL front end now allows operators to produce values with abstract types, rather than concretizing thir operands. By @jimblandy in [#4850](https://github.com/gfx-rs/wgpu/pull/4850) and [#4870](https://github.com/gfx-rs/wgpu/pull/4870).

- Naga's WGSL front and back ends now have experimental support for 64-bit floating-point literals: `1.0lf` denotes an `f64` value. There has been experimental support for an `f64` type for a while, but until now there was no syntax for writing literals with that type. As before, Naga module validation rejects `f64` values unless `naga::valid::Capabilities::FLOAT64` is requested. By @jimblandy in [#4747](https://github.com/gfx-rs/wgpu/pull/4747).

Expand Down
11 changes: 10 additions & 1 deletion naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,16 @@ impl<W: Write> Writer<W> {
Expression::Literal(literal) => match literal {
crate::Literal::F32(value) => write!(self.out, "{}f", value)?,
crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
crate::Literal::I32(value) => write!(self.out, "{}i", value)?,
crate::Literal::I32(value) => {
// `-2147483648i` is not valid WGSL. The most negative `i32`
// value can only be expressed in WGSL using AbstractInt and
// a unary negation operator.
if value == i32::MIN {
write!(self.out, "i32(-2147483648)")?;
} else {
write!(self.out, "{}i", value)?;
}
}
crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
crate::Literal::F64(value) => write!(self.out, "{:?}lf", value)?,
crate::Literal::I64(_) => {
Expand Down
2 changes: 1 addition & 1 deletion naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(Typed::Plain(handle));
}
ast::Expression::Unary { op, expr } => {
let expr = self.expression(expr, ctx)?;
let expr = self.expression_for_abstract(expr, ctx)?;
Typed::Plain(crate::Expression::Unary { op, expr })
}
ast::Expression::AddrOf(expr) => {
Expand Down
103 changes: 42 additions & 61 deletions naga/src/front/wgsl/parse/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ fn consume_token(input: &str, generic: bool) -> (Token<'_>, &str) {
let og_chars = chars.as_str();
match chars.next() {
Some('>') => (Token::Arrow, chars.as_str()),
Some('0'..='9' | '.') => consume_number(input),
Some('-') => (Token::DecrementOperation, chars.as_str()),
Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
_ => (Token::Operation(cur), og_chars),
Expand Down Expand Up @@ -496,44 +495,60 @@ fn test_numbers() {

// MIN / MAX //

// min / max decimal signed integer
// min / max decimal integer
sub_test(
"-2147483648i 2147483647i -2147483649i 2147483648i",
"0i 2147483647i 2147483648i",
&[
Token::Number(Ok(Number::I32(i32::MIN))),
Token::Number(Ok(Number::I32(0))),
Token::Number(Ok(Number::I32(i32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
Token::Number(Err(NumberError::NotRepresentable)),
],
);
// min / max decimal unsigned integer
sub_test(
"0u 4294967295u -1u 4294967296u",
"0u 4294967295u 4294967296u",
&[
Token::Number(Ok(Number::U32(u32::MIN))),
Token::Number(Ok(Number::U32(u32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
Token::Number(Err(NumberError::NotRepresentable)),
],
);

// min / max hexadecimal signed integer
sub_test(
"-0x80000000i 0x7FFFFFFFi -0x80000001i 0x80000000i",
"0x0i 0x7FFFFFFFi 0x80000000i",
&[
Token::Number(Ok(Number::I32(i32::MIN))),
Token::Number(Ok(Number::I32(0))),
Token::Number(Ok(Number::I32(i32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
Token::Number(Err(NumberError::NotRepresentable)),
],
);
// min / max hexadecimal unsigned integer
sub_test(
"0x0u 0xFFFFFFFFu -0x1u 0x100000000u",
"0x0u 0xFFFFFFFFu 0x100000000u",
&[
Token::Number(Ok(Number::U32(u32::MIN))),
Token::Number(Ok(Number::U32(u32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
],
);

// min/max decimal abstract int
sub_test(
"0 9223372036854775807 9223372036854775808",
&[
Token::Number(Ok(Number::AbstractInt(0))),
Token::Number(Ok(Number::AbstractInt(i64::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
],
);

// min/max hexadecimal abstract int
sub_test(
"0 0x7fffffffffffffff 0x8000000000000000",
&[
Token::Number(Ok(Number::AbstractInt(0))),
Token::Number(Ok(Number::AbstractInt(i64::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
],
);
Expand All @@ -548,77 +563,43 @@ fn test_numbers() {
const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994;
/// ≈ 1 + 2^−23
const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001;
/// ≈ -(2^127 * (2 − 2^−23))
const SMALLEST_NORMAL_F32: f32 = f32::MIN;
/// ≈ 2^127 * (2 − 2^−23)
const LARGEST_NORMAL_F32: f32 = f32::MAX;

// decimal floating point
sub_test(
"1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f -3.40282347e+38f 3.40282347e+38f",
"1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f 3.40282347e+38f",
&[
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_F32_LESS_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_F32_LARGER_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))),
Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))),
Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))),
Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))),
Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))),
Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))),
],
);
sub_test(
"-3.40282367e+38f 3.40282367e+38f",
"3.40282367e+38f",
&[
Token::Number(Err(NumberError::NotRepresentable)), // ≈ -2^128
Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128
],
);

// hexadecimal floating point
sub_test(
"0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f -0xFFFFFFp+104f 0xFFFFFFp+104f",
"0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f 0xFFFFFFp+104f",
&[
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_F32_LESS_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_F32_LARGER_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))),
Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))),
Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))),
Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))),
Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))),
Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))),
],
);
sub_test(
"-0x1p128f 0x1p128f 0x1.000001p0f",
"0x1p128f 0x1.000001p0f",
&[
Token::Number(Err(NumberError::NotRepresentable)), // = -2^128
Token::Number(Err(NumberError::NotRepresentable)), // = 2^128
Token::Number(Err(NumberError::NotRepresentable)),
],
Expand Down
59 changes: 18 additions & 41 deletions naga/src/front/wgsl/parse/number.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::borrow::Cow;

use crate::front::wgsl::error::NumberError;
use crate::front::wgsl::parse::lexer::Token;

Expand All @@ -20,8 +18,6 @@ pub enum Number {
F64(f64),
}

// TODO: when implementing Creation-Time Expressions, remove the ability to match the minus sign

pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
let (result, rest) = parse(input);
(Token::Number(result), rest)
Expand Down Expand Up @@ -64,7 +60,9 @@ enum FloatKind {
// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? /

// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing
// -?(?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))
// (?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))

// Leading signs are handled as unary operators.

fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
/// returns `true` and consumes `X` bytes from the given byte buffer
Expand Down Expand Up @@ -152,8 +150,6 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {

let general_extract = ExtractSubStr::start(input, bytes);

let is_negative = consume!(bytes, b'-');

if consume!(bytes, b'0', b'x' | b'X') {
let digits_extract = ExtractSubStr::start(input, bytes);

Expand Down Expand Up @@ -216,10 +212,7 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
} else {
let kind = consume_map!(bytes, [b'i' => IntKind::I32, b'u' => IntKind::U32]);

(
parse_hex_int(is_negative, digits, kind),
rest_to_str!(bytes),
)
(parse_hex_int(digits, kind), rest_to_str!(bytes))
}
}
} else {
Expand Down Expand Up @@ -272,7 +265,7 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}

let digits_with_sign = general_extract.end(bytes);
let digits = general_extract.end(bytes);

let kind = consume_map!(bytes, [
b'i' => Kind::Int(IntKind::I32),
Expand All @@ -282,17 +275,14 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
b'l', b'f' => Kind::Float(FloatKind::F64),
]);

(
parse_dec(is_negative, digits_with_sign, kind),
rest_to_str!(bytes),
)
(parse_dec(digits, kind), rest_to_str!(bytes))
}
}
}
}

fn parse_hex_float_missing_exponent(
// format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
// format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
significand: &str,
kind: Option<FloatKind>,
) -> Result<Number, NumberError> {
Expand All @@ -301,7 +291,7 @@ fn parse_hex_float_missing_exponent(
}

fn parse_hex_float_missing_period(
// format: -?0[xX] [0-9a-fA-F]+
// format: 0[xX] [0-9a-fA-F]+
significand: &str,
// format: [pP][+-]?[0-9]+
exponent: &str,
Expand All @@ -312,29 +302,22 @@ fn parse_hex_float_missing_period(
}

fn parse_hex_int(
is_negative: bool,
// format: [0-9a-fA-F]+
digits: &str,
kind: Option<IntKind>,
) -> Result<Number, NumberError> {
let digits_with_sign = if is_negative {
Cow::Owned(format!("-{digits}"))
} else {
Cow::Borrowed(digits)
};
parse_int(&digits_with_sign, kind, 16, is_negative)
parse_int(digits, kind, 16)
}

fn parse_dec(
is_negative: bool,
// format: -? ( [0-9] | [1-9][0-9]+ )
digits_with_sign: &str,
// format: ( [0-9] | [1-9][0-9]+ )
digits: &str,
kind: Option<Kind>,
) -> Result<Number, NumberError> {
match kind {
None => parse_int(digits_with_sign, None, 10, is_negative),
Some(Kind::Int(kind)) => parse_int(digits_with_sign, Some(kind), 10, is_negative),
Some(Kind::Float(kind)) => parse_dec_float(digits_with_sign, Some(kind)),
None => parse_int(digits, None, 10),
Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10),
Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)),
}
}

Expand Down Expand Up @@ -363,7 +346,7 @@ fn parse_dec(

// Therefore we only check for overflow manually for decimal floating point literals

// input format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
// input format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
match kind {
None => match hexf_parse::parse_hexf64(input, false) {
Expand All @@ -385,8 +368,8 @@ fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, Numbe
}
}

// input format: -? ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
// | -? [0-9]+ [eE][+-]?[0-9]+
// input format: ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
// | [0-9]+ [eE][+-]?[0-9]+
fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
match kind {
None => {
Expand All @@ -411,12 +394,7 @@ fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, Numbe
}
}

fn parse_int(
input: &str,
kind: Option<IntKind>,
radix: u32,
is_negative: bool,
) -> Result<Number, NumberError> {
fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> {
fn map_err(e: core::num::ParseIntError) -> NumberError {
match *e.kind() {
core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
Expand All @@ -434,7 +412,6 @@ fn parse_int(
Ok(num) => Ok(Number::I32(num)),
Err(e) => Err(map_err(e)),
},
Some(IntKind::U32) if is_negative => Err(NumberError::NotRepresentable),
Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
Ok(num) => Ok(Number::U32(num)),
Err(e) => Err(map_err(e)),
Expand Down
5 changes: 4 additions & 1 deletion naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1191,8 +1191,10 @@ impl<'a> ConstantEvaluator<'a> {
let expr = match self.expressions[expr] {
Expression::Literal(value) => Expression::Literal(match op {
UnaryOperator::Negate => match value {
Literal::I32(v) => Literal::I32(-v),
Literal::I32(v) => Literal::I32(v.wrapping_neg()),
Literal::F32(v) => Literal::F32(-v),
Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
_ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
},
UnaryOperator::LogicalNot => match value {
Expand All @@ -1202,6 +1204,7 @@ impl<'a> ConstantEvaluator<'a> {
UnaryOperator::BitwiseNot => match value {
Literal::I32(v) => Literal::I32(!v),
Literal::U32(v) => Literal::U32(!v),
Literal::AbstractInt(v) => Literal::AbstractInt(!v),
_ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
},
}),
Expand Down
Loading

0 comments on commit d9d051b

Please sign in to comment.