Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga wgsl] Let unary operators accept and produce abstract types. #4870

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ This feature allowed you to call `global_id` on any wgpu opaque handle to get a

#### Naga

- Naga's WGSL front end now allows binary operators to produce values with abstract types, rather than concretizing thir operands. By @jimblandy in [#4850](https://github.com/gfx-rs/wgpu/pull/4850).
- Naga's WGSL front end now allows operators to produce values with abstract types, rather than concretizing thir operands. By @jimblandy in [#4850](https://github.com/gfx-rs/wgpu/pull/4850) and [#4870](https://github.com/gfx-rs/wgpu/pull/4870).

- Naga's WGSL front and back ends now have experimental support for 64-bit floating-point literals: `1.0lf` denotes an `f64` value. There has been experimental support for an `f64` type for a while, but until now there was no syntax for writing literals with that type. As before, Naga module validation rejects `f64` values unless `naga::valid::Capabilities::FLOAT64` is requested. By @jimblandy in [#4747](https://github.com/gfx-rs/wgpu/pull/4747).

Expand Down
11 changes: 10 additions & 1 deletion naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,16 @@ impl<W: Write> Writer<W> {
Expression::Literal(literal) => match literal {
crate::Literal::F32(value) => write!(self.out, "{}f", value)?,
crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
crate::Literal::I32(value) => write!(self.out, "{}i", value)?,
crate::Literal::I32(value) => {
// `-2147483648i` is not valid WGSL. The most negative `i32`
// value can only be expressed in WGSL using AbstractInt and
// a unary negation operator.
if value == i32::MIN {
write!(self.out, "i32(-2147483648)")?;
} else {
write!(self.out, "{}i", value)?;
}
}
crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
crate::Literal::F64(value) => write!(self.out, "{:?}lf", value)?,
crate::Literal::I64(_) => {
Expand Down
2 changes: 1 addition & 1 deletion naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(Typed::Plain(handle));
}
ast::Expression::Unary { op, expr } => {
let expr = self.expression(expr, ctx)?;
let expr = self.expression_for_abstract(expr, ctx)?;
Typed::Plain(crate::Expression::Unary { op, expr })
}
ast::Expression::AddrOf(expr) => {
Expand Down
103 changes: 42 additions & 61 deletions naga/src/front/wgsl/parse/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ fn consume_token(input: &str, generic: bool) -> (Token<'_>, &str) {
let og_chars = chars.as_str();
match chars.next() {
Some('>') => (Token::Arrow, chars.as_str()),
Some('0'..='9' | '.') => consume_number(input),
Some('-') => (Token::DecrementOperation, chars.as_str()),
Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
_ => (Token::Operation(cur), og_chars),
Expand Down Expand Up @@ -496,44 +495,60 @@ fn test_numbers() {

// MIN / MAX //

// min / max decimal signed integer
// min / max decimal integer
sub_test(
"-2147483648i 2147483647i -2147483649i 2147483648i",
"0i 2147483647i 2147483648i",
&[
Token::Number(Ok(Number::I32(i32::MIN))),
Token::Number(Ok(Number::I32(0))),
Token::Number(Ok(Number::I32(i32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
Token::Number(Err(NumberError::NotRepresentable)),
],
);
// min / max decimal unsigned integer
sub_test(
"0u 4294967295u -1u 4294967296u",
"0u 4294967295u 4294967296u",
&[
Token::Number(Ok(Number::U32(u32::MIN))),
Token::Number(Ok(Number::U32(u32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
Token::Number(Err(NumberError::NotRepresentable)),
],
);

// min / max hexadecimal signed integer
sub_test(
"-0x80000000i 0x7FFFFFFFi -0x80000001i 0x80000000i",
"0x0i 0x7FFFFFFFi 0x80000000i",
&[
Token::Number(Ok(Number::I32(i32::MIN))),
Token::Number(Ok(Number::I32(0))),
Token::Number(Ok(Number::I32(i32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
Token::Number(Err(NumberError::NotRepresentable)),
],
);
// min / max hexadecimal unsigned integer
sub_test(
"0x0u 0xFFFFFFFFu -0x1u 0x100000000u",
"0x0u 0xFFFFFFFFu 0x100000000u",
&[
Token::Number(Ok(Number::U32(u32::MIN))),
Token::Number(Ok(Number::U32(u32::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
],
);

// min/max decimal abstract int
sub_test(
"0 9223372036854775807 9223372036854775808",
&[
Token::Number(Ok(Number::AbstractInt(0))),
Token::Number(Ok(Number::AbstractInt(i64::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
],
);

// min/max hexadecimal abstract int
sub_test(
"0 0x7fffffffffffffff 0x8000000000000000",
&[
Token::Number(Ok(Number::AbstractInt(0))),
Token::Number(Ok(Number::AbstractInt(i64::MAX))),
Token::Number(Err(NumberError::NotRepresentable)),
],
);
Expand All @@ -548,77 +563,43 @@ fn test_numbers() {
const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994;
/// ≈ 1 + 2^−23
const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001;
/// ≈ -(2^127 * (2 − 2^−23))
const SMALLEST_NORMAL_F32: f32 = f32::MIN;
/// ≈ 2^127 * (2 − 2^−23)
const LARGEST_NORMAL_F32: f32 = f32::MAX;

// decimal floating point
sub_test(
"1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f -3.40282347e+38f 3.40282347e+38f",
"1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f 3.40282347e+38f",
&[
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_F32_LESS_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_F32_LARGER_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))),
Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))),
Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))),
Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))),
Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))),
Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))),
],
);
sub_test(
"-3.40282367e+38f 3.40282367e+38f",
"3.40282367e+38f",
&[
Token::Number(Err(NumberError::NotRepresentable)), // ≈ -2^128
Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128
],
);

// hexadecimal floating point
sub_test(
"0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f -0xFFFFFFp+104f 0xFFFFFFp+104f",
"0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f 0xFFFFFFp+104f",
&[
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_SUBNORMAL_F32,
))),
Token::Number(Ok(Number::F32(
SMALLEST_POSITIVE_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_F32_LESS_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_F32_LARGER_THAN_ONE,
))),
Token::Number(Ok(Number::F32(
SMALLEST_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(
LARGEST_NORMAL_F32,
))),
Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))),
Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))),
Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))),
Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))),
Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))),
Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))),
],
);
sub_test(
"-0x1p128f 0x1p128f 0x1.000001p0f",
"0x1p128f 0x1.000001p0f",
&[
Token::Number(Err(NumberError::NotRepresentable)), // = -2^128
Token::Number(Err(NumberError::NotRepresentable)), // = 2^128
Token::Number(Err(NumberError::NotRepresentable)),
],
Expand Down
59 changes: 18 additions & 41 deletions naga/src/front/wgsl/parse/number.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::borrow::Cow;

use crate::front::wgsl::error::NumberError;
use crate::front::wgsl::parse::lexer::Token;

Expand All @@ -20,8 +18,6 @@ pub enum Number {
F64(f64),
}

// TODO: when implementing Creation-Time Expressions, remove the ability to match the minus sign

pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
let (result, rest) = parse(input);
(Token::Number(result), rest)
Expand Down Expand Up @@ -64,7 +60,9 @@ enum FloatKind {
// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? /

// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing
// -?(?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))
// (?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))

// Leading signs are handled as unary operators.

fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
/// returns `true` and consumes `X` bytes from the given byte buffer
Expand Down Expand Up @@ -152,8 +150,6 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {

let general_extract = ExtractSubStr::start(input, bytes);

let is_negative = consume!(bytes, b'-');

if consume!(bytes, b'0', b'x' | b'X') {
let digits_extract = ExtractSubStr::start(input, bytes);

Expand Down Expand Up @@ -216,10 +212,7 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
} else {
let kind = consume_map!(bytes, [b'i' => IntKind::I32, b'u' => IntKind::U32]);

(
parse_hex_int(is_negative, digits, kind),
rest_to_str!(bytes),
)
(parse_hex_int(digits, kind), rest_to_str!(bytes))
}
}
} else {
Expand Down Expand Up @@ -272,7 +265,7 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
return (Err(NumberError::Invalid), rest_to_str!(bytes));
}

let digits_with_sign = general_extract.end(bytes);
let digits = general_extract.end(bytes);

let kind = consume_map!(bytes, [
b'i' => Kind::Int(IntKind::I32),
Expand All @@ -282,17 +275,14 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
b'l', b'f' => Kind::Float(FloatKind::F64),
]);

(
parse_dec(is_negative, digits_with_sign, kind),
rest_to_str!(bytes),
)
(parse_dec(digits, kind), rest_to_str!(bytes))
}
}
}
}

fn parse_hex_float_missing_exponent(
// format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
// format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
significand: &str,
kind: Option<FloatKind>,
) -> Result<Number, NumberError> {
Expand All @@ -301,7 +291,7 @@ fn parse_hex_float_missing_exponent(
}

fn parse_hex_float_missing_period(
// format: -?0[xX] [0-9a-fA-F]+
// format: 0[xX] [0-9a-fA-F]+
significand: &str,
// format: [pP][+-]?[0-9]+
exponent: &str,
Expand All @@ -312,29 +302,22 @@ fn parse_hex_float_missing_period(
}

fn parse_hex_int(
is_negative: bool,
// format: [0-9a-fA-F]+
digits: &str,
kind: Option<IntKind>,
) -> Result<Number, NumberError> {
let digits_with_sign = if is_negative {
Cow::Owned(format!("-{digits}"))
} else {
Cow::Borrowed(digits)
};
parse_int(&digits_with_sign, kind, 16, is_negative)
parse_int(digits, kind, 16)
}

fn parse_dec(
is_negative: bool,
// format: -? ( [0-9] | [1-9][0-9]+ )
digits_with_sign: &str,
// format: ( [0-9] | [1-9][0-9]+ )
digits: &str,
kind: Option<Kind>,
) -> Result<Number, NumberError> {
match kind {
None => parse_int(digits_with_sign, None, 10, is_negative),
Some(Kind::Int(kind)) => parse_int(digits_with_sign, Some(kind), 10, is_negative),
Some(Kind::Float(kind)) => parse_dec_float(digits_with_sign, Some(kind)),
None => parse_int(digits, None, 10),
Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10),
Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)),
}
}

Expand Down Expand Up @@ -363,7 +346,7 @@ fn parse_dec(

// Therefore we only check for overflow manually for decimal floating point literals

// input format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
// input format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
match kind {
None => match hexf_parse::parse_hexf64(input, false) {
Expand All @@ -385,8 +368,8 @@ fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, Numbe
}
}

// input format: -? ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
// | -? [0-9]+ [eE][+-]?[0-9]+
// input format: ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
// | [0-9]+ [eE][+-]?[0-9]+
fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
match kind {
None => {
Expand All @@ -411,12 +394,7 @@ fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, Numbe
}
}

fn parse_int(
input: &str,
kind: Option<IntKind>,
radix: u32,
is_negative: bool,
) -> Result<Number, NumberError> {
fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> {
fn map_err(e: core::num::ParseIntError) -> NumberError {
match *e.kind() {
core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
Expand All @@ -434,7 +412,6 @@ fn parse_int(
Ok(num) => Ok(Number::I32(num)),
Err(e) => Err(map_err(e)),
},
Some(IntKind::U32) if is_negative => Err(NumberError::NotRepresentable),
Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
Ok(num) => Ok(Number::U32(num)),
Err(e) => Err(map_err(e)),
Expand Down
5 changes: 4 additions & 1 deletion naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1191,8 +1191,10 @@ impl<'a> ConstantEvaluator<'a> {
let expr = match self.expressions[expr] {
Expression::Literal(value) => Expression::Literal(match op {
UnaryOperator::Negate => match value {
Literal::I32(v) => Literal::I32(-v),
Literal::I32(v) => Literal::I32(v.wrapping_neg()),
Literal::F32(v) => Literal::F32(-v),
Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
_ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
},
UnaryOperator::LogicalNot => match value {
Expand All @@ -1202,6 +1204,7 @@ impl<'a> ConstantEvaluator<'a> {
UnaryOperator::BitwiseNot => match value {
Literal::I32(v) => Literal::I32(!v),
Literal::U32(v) => Literal::U32(!v),
Literal::AbstractInt(v) => Literal::AbstractInt(!v),
_ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
},
}),
Expand Down
Loading
Loading