Skip to content

Commit

Permalink
[wgsl-in] Handle repeated or missing @workgroup_size (#2435)
Browse files Browse the repository at this point in the history
  • Loading branch information
fornwall authored Aug 17, 2023
1 parent 7a19f3a commit f6e99a4
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 13 deletions.
11 changes: 10 additions & 1 deletion src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ pub enum Error<'a> {
Other,
ExpectedArraySize(Span),
NonPositiveArrayLength(Span),
MissingWorkgroupSize(Span),
}

impl<'a> Error<'a> {
Expand Down Expand Up @@ -433,7 +434,7 @@ impl<'a> Error<'a> {
},
Error::RepeatedAttribute(bad_span) => ParseError {
message: format!("repeated attribute: '{}'", &source[bad_span]),
labels: vec![(bad_span, "repated attribute".into())],
labels: vec![(bad_span, "repeated attribute".into())],
notes: vec![],
},
Error::UnknownAttribute(bad_span) => ParseError {
Expand Down Expand Up @@ -704,6 +705,14 @@ impl<'a> Error<'a> {
labels: vec![(span, "must be greater than zero".into())],
notes: vec![],
},
Error::MissingWorkgroupSize(span) => ParseError {
message: "workgroup size is missing on compute shader entry point".to_string(),
labels: vec![(
span,
"must be paired with a @workgroup_size attribute".into(),
)],
notes: vec![],
},
}
}
}
2 changes: 1 addition & 1 deletion src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
name: f.name.name.to_string(),
stage: entry.stage,
early_depth_test: entry.early_depth_test,
workgroup_size: entry.workgroup_size,
workgroup_size: entry.workgroup_size.unwrap_or([0, 0, 0]),
function,
});
Ok(LoweredGlobalDecl::EntryPoint)
Expand Down
2 changes: 1 addition & 1 deletion src/front/wgsl/parse/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ pub struct FunctionResult<'a> {
pub struct EntryPoint {
pub stage: crate::ShaderStage,
pub early_depth_test: Option<crate::EarlyDepthTest>,
pub workgroup_size: [u32; 3],
pub workgroup_size: Option<[u32; 3]>,
}

#[cfg(doc)]
Expand Down
30 changes: 20 additions & 10 deletions src/front/wgsl/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::front::wgsl::error::{Error, ExpectedToken};
use crate::front::wgsl::parse::lexer::{Lexer, Token};
use crate::front::wgsl::parse::number::Number;
use crate::front::SymbolTable;
use crate::{Arena, FastHashSet, Handle, Span};
use crate::{Arena, FastHashSet, Handle, ShaderStage, Span};

pub mod ast;
pub mod conv;
Expand Down Expand Up @@ -2158,7 +2158,8 @@ impl Parser {
// read attributes
let mut binding = None;
let mut stage = ParsedAttribute::default();
let mut workgroup_size = [0u32; 3];
let mut compute_span = Span::new(0, 0);
let mut workgroup_size = ParsedAttribute::default();
let mut early_depth_test = ParsedAttribute::default();
let (mut bind_index, mut bind_group) =
(ParsedAttribute::default(), ParsedAttribute::default());
Expand All @@ -2184,11 +2185,12 @@ impl Parser {
}
("compute", name_span) => {
stage.set(crate::ShaderStage::Compute, name_span)?;
compute_span = name_span;
}
("workgroup_size", _) => {
("workgroup_size", name_span) => {
lexer.expect(Token::Paren('('))?;
workgroup_size = [1u32; 3];
for (i, size) in workgroup_size.iter_mut().enumerate() {
let mut new_workgroup_size = [1u32; 3];
for (i, size) in new_workgroup_size.iter_mut().enumerate() {
*size = Self::generic_non_negative_int_literal(lexer)?;
match lexer.next() {
(Token::Paren(')'), _) => break,
Expand All @@ -2201,6 +2203,7 @@ impl Parser {
}
}
}
workgroup_size.set(new_workgroup_size, name_span)?;
}
("early_depth_test", name_span) => {
let conservative = if lexer.skip(Token::Paren('(')) {
Expand Down Expand Up @@ -2281,11 +2284,18 @@ impl Parser {
(Token::Word("fn"), _) => {
let function = self.function_decl(lexer, out, &mut dependencies)?;
Some(ast::GlobalDeclKind::Fn(ast::Function {
entry_point: stage.value.map(|stage| ast::EntryPoint {
stage,
early_depth_test: early_depth_test.value,
workgroup_size,
}),
entry_point: if let Some(stage) = stage.value {
if stage == ShaderStage::Compute && workgroup_size.value.is_none() {
return Err(Error::MissingWorkgroupSize(compute_span));
}
Some(ast::EntryPoint {
stage,
early_depth_test: early_depth_test.value,
workgroup_size: workgroup_size.value,
})
} else {
None
},
..function
}))
}
Expand Down
16 changes: 16 additions & 0 deletions src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ fn parse_repeated_attributes() {
("size(16)", template_struct),
("vertex", template_stage),
("early_depth_test(less_equal)", template_resource),
("workgroup_size(1)", template_stage),
] {
let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}"));
let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32;
Expand All @@ -548,3 +549,18 @@ fn parse_repeated_attributes() {
));
}
}

#[test]
fn parse_missing_workgroup_size() {
use crate::{
front::wgsl::{error::Error, Frontend},
Span,
};

let shader = "@compute fn vs() -> vec4<f32> { return vec4<f32>(0.0); }";
let result = Frontend::new().inner(shader);
assert!(matches!(
result.unwrap_err(),
Error::MissingWorkgroupSize(span) if span == Span::new(1, 8)
));
}

0 comments on commit f6e99a4

Please sign in to comment.