diff --git a/src/compose/error.rs b/src/compose/error.rs index 8292d25..185f41a 100644 --- a/src/compose/error.rs +++ b/src/compose/error.rs @@ -134,6 +134,12 @@ pub enum ComposerErrorInner { }, #[error("#define statements are only allowed at the start of the top-level shaders")] DefineInModule(usize), + #[error("Invalid WGSL directive '{directive}' at line {position}: {reason}")] + InvalidWgslDirective { + directive: String, + position: usize, + reason: String, + }, } struct ErrorSources<'a> { @@ -239,7 +245,8 @@ impl ComposerError { | ComposerErrorInner::OverrideNotVirtual { pos, .. } | ComposerErrorInner::GlslInvalidVersion(pos) | ComposerErrorInner::DefineInModule(pos) - | ComposerErrorInner::InvalidShaderDefDefinitionValue { pos, .. } => { + | ComposerErrorInner::InvalidShaderDefDefinitionValue { pos, .. } + | ComposerErrorInner::InvalidWgslDirective { position: pos, .. } => { (vec![Label::primary((), *pos..*pos)], vec![]) } ComposerErrorInner::WgslBackError(e) => { diff --git a/src/compose/mod.rs b/src/compose/mod.rs index c1203b7..59f4aea 100644 --- a/src/compose/mod.rs +++ b/src/compose/mod.rs @@ -140,6 +140,9 @@ use crate::{ pub use self::error::{ComposerError, ComposerErrorInner, ErrSource}; use self::preprocess::Preprocessor; +pub use self::wgsl_directives::{ + DiagnosticDirective, EnableDirective, RequiresDirective, WgslDirectives, +}; pub mod comment_strip_iter; pub mod error; @@ -147,6 +150,7 @@ pub mod parse_imports; pub mod preprocess; mod test; pub mod tokenizer; +pub mod wgsl_directives; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, Default)] pub enum ShaderLanguage { @@ -270,8 +274,6 @@ pub struct ComposableModuleDefinition { modules: HashMap, // used in spans when this module is included module_index: usize, - // preprocessor meta data - // metadata: PreprocessorMetaData, } impl ComposableModuleDefinition { @@ -580,11 +582,17 @@ impl Composer { language: ShaderLanguage, imports: &[ImportDefinition], shader_defs: &HashMap, + wgsl_directives: Option, ) -> Result { debug!("creating IR for {} with defs: {:?}", name, shader_defs); + let mut wgsl_string = String::new(); + if let Some(wgsl_directives) = &wgsl_directives { + trace!("adding WGSL directives for {}", name); + wgsl_string = wgsl_directives.to_wgsl_string(); + } let mut module_string = match language { - ShaderLanguage::Wgsl => String::new(), + ShaderLanguage::Wgsl => wgsl_string, #[cfg(feature = "glsl")] ShaderLanguage::Glsl => String::from("#version 450\n"), }; @@ -842,6 +850,7 @@ impl Composer { demote_entrypoints: bool, source: &str, imports: Vec, + wgsl_directives: Option, ) -> Result { let mut imports: Vec<_> = imports .into_iter() @@ -975,6 +984,7 @@ impl Composer { module_definition.language, &imports, shader_defs, + wgsl_directives, )?; // from here on errors need to be reported using the modified source with start_offset @@ -1376,6 +1386,7 @@ impl Composer { true, &preprocessed_source, imports, + None, ) .map_err(|err| err.into()) } @@ -1519,6 +1530,7 @@ impl Composer { name: module_name, mut imports, mut effective_defs, + cleaned_source, .. } = self .preprocessor @@ -1595,7 +1607,7 @@ impl Composer { let module_set = ComposableModuleDefinition { name: module_name.clone(), - sanitized_source: substituted_source, + sanitized_source: cleaned_source, file_path: file_path.to_owned(), language, effective_defs: effective_defs.into_iter().collect(), @@ -1650,7 +1662,13 @@ impl Composer { let sanitized_source = self.sanitize_and_set_auto_bindings(source); - let PreprocessorMetaData { name, defines, .. } = self + let PreprocessorMetaData { + name, + defines, + wgsl_directives, + cleaned_source, + .. + } = self .preprocessor .get_preprocessor_metadata(&sanitized_source, true) .map_err(|inner| ComposerError { @@ -1667,7 +1685,7 @@ impl Composer { let PreprocessOutput { imports, .. } = self .preprocessor - .preprocess(&sanitized_source, &shader_defs) + .preprocess(&cleaned_source, &shader_defs) .map_err(|inner| ComposerError { inner, source: ErrSource::Constructing { @@ -1734,7 +1752,7 @@ impl Composer { let definition = ComposableModuleDefinition { name, - sanitized_source: sanitized_source.clone(), + sanitized_source: cleaned_source.clone(), language: shader_type.into(), file_path: file_path.to_owned(), module_index: 0, @@ -1751,7 +1769,7 @@ impl Composer { imports, } = self .preprocessor - .preprocess(&sanitized_source, &shader_defs) + .preprocess(&cleaned_source, &shader_defs) .map_err(|inner| ComposerError { inner, source: ErrSource::Constructing { @@ -1770,6 +1788,7 @@ impl Composer { false, &preprocessed_source, imports, + Some(wgsl_directives), ) .map_err(|e| ComposerError { inner: e.inner, diff --git a/src/compose/preprocess.rs b/src/compose/preprocess.rs index 1fbed3e..ebedf3c 100644 --- a/src/compose/preprocess.rs +++ b/src/compose/preprocess.rs @@ -6,6 +6,7 @@ use regex::Regex; use super::{ comment_strip_iter::CommentReplaceExt, parse_imports::{parse_imports, substitute_identifiers}, + wgsl_directives::{DiagnosticDirective, EnableDirective, RequiresDirective, WgslDirectives}, ComposerErrorInner, ImportDefWithOffset, ShaderDefValue, }; @@ -22,6 +23,9 @@ pub struct Preprocessor { import_regex: Regex, define_import_path_regex: Regex, define_shader_def_regex: Regex, + enable_regex: Regex, + requires_regex: Regex, + diagnostic_regex: Regex, } impl Default for Preprocessor { @@ -42,6 +46,10 @@ impl Default for Preprocessor { define_import_path_regex: Regex::new(r"^\s*#\s*define_import_path\s+([^\s]+)").unwrap(), define_shader_def_regex: Regex::new(r"^\s*#\s*define\s+([\w|\d|_]+)\s*([-\w|\d]+)?") .unwrap(), + enable_regex: Regex::new(r"^\s*enable\s+([^;]+)\s*;").unwrap(), + requires_regex: Regex::new(r"^\s*requires\s+([^;]+)\s*;").unwrap(), + diagnostic_regex: Regex::new(r"^\s*diagnostic\s*\(\s*([^,]+)\s*,\s*([^)]+)\s*\)\s*;") + .unwrap(), } } } @@ -52,6 +60,8 @@ pub struct PreprocessorMetaData { pub imports: Vec, pub defines: HashMap, pub effective_defs: HashSet, + pub wgsl_directives: WgslDirectives, + pub cleaned_source: String, } enum ScopeLevel { @@ -133,6 +143,145 @@ pub struct PreprocessOutput { } impl Preprocessor { + fn parse_enable_directive( + &self, + line: &str, + line_idx: usize, + ) -> Result, ComposerErrorInner> { + if let Some(cap) = self.enable_regex.captures(line) { + let extensions_str = cap.get(1).unwrap().as_str().trim(); + let extensions: Vec = extensions_str + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + + if extensions.is_empty() { + return Err(ComposerErrorInner::InvalidWgslDirective { + directive: line.to_string(), + position: line_idx, + reason: "No extensions specified".to_string(), + }); + } + + Ok(Some(EnableDirective { + extensions, + source_location: line_idx, + })) + } else { + Ok(None) + } + } + + fn parse_requires_directive( + &self, + line: &str, + line_idx: usize, + ) -> Result, ComposerErrorInner> { + if let Some(cap) = self.requires_regex.captures(line) { + let extensions_str = cap.get(1).unwrap().as_str().trim(); + let extensions: Vec = extensions_str + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + + if extensions.is_empty() { + return Err(ComposerErrorInner::InvalidWgslDirective { + directive: line.to_string(), + position: line_idx, + reason: "No extensions specified".to_string(), + }); + } + + Ok(Some(RequiresDirective { + extensions, + source_location: line_idx, + })) + } else { + Ok(None) + } + } + + fn parse_diagnostic_directive( + &self, + line: &str, + line_idx: usize, + ) -> Result, ComposerErrorInner> { + if let Some(cap) = self.diagnostic_regex.captures(line) { + let severity = cap.get(1).unwrap().as_str().trim().to_string(); + let rule = cap.get(2).unwrap().as_str().trim().to_string(); + + match severity.as_str() { + "off" | "info" | "warn" | "error" => {} + _ => { + return Err(ComposerErrorInner::InvalidWgslDirective { + directive: line.to_string(), + position: line_idx, + reason: format!( + "Invalid severity '{}'. Must be one of: off, info, warn, error", + severity + ), + }); + } + } + + Ok(Some(DiagnosticDirective { + severity, + rule, + source_location: line_idx, + })) + } else { + Ok(None) + } + } + + pub fn extract_wgsl_directives( + &self, + source: &str, + ) -> Result<(String, WgslDirectives), ComposerErrorInner> { + let mut directives = WgslDirectives::default(); + let mut cleaned_lines = Vec::new(); + let mut in_directive_section = true; + + for (line_idx, line) in source.lines().enumerate() { + let trimmed = line.trim(); + + if trimmed.is_empty() || trimmed.starts_with("//") { + cleaned_lines.push(line); + continue; + } + + if in_directive_section { + if let Some(enable) = self.parse_enable_directive(trimmed, line_idx)? { + cleaned_lines.push(""); + directives.enables.push(enable); + continue; + } else if let Some(requires) = self.parse_requires_directive(trimmed, line_idx)? { + cleaned_lines.push(""); + directives.requires.push(requires); + continue; + } else if let Some(diagnostic) = + self.parse_diagnostic_directive(trimmed, line_idx)? + { + cleaned_lines.push(""); + directives.diagnostics.push(diagnostic); + continue; + } else if !trimmed.starts_with("enable") + && !trimmed.starts_with("requires") + && !trimmed.starts_with("diagnostic") + { + in_directive_section = false; + } + } + + cleaned_lines.push(line); + } + + let cleaned_source = cleaned_lines.join("\n"); + Ok((cleaned_source, directives)) + } + fn check_scope<'a>( &self, shader_defs: &HashMap, @@ -379,6 +528,8 @@ impl Preprocessor { shader_str: &str, allow_defines: bool, ) -> Result { + let (shader_str, wgsl_directives) = self.extract_wgsl_directives(shader_str)?; + let mut declared_imports = IndexMap::default(); let mut used_imports = IndexMap::default(); let mut name = None; @@ -477,6 +628,8 @@ impl Preprocessor { imports: used_imports.into_values().collect(), defines, effective_defs, + wgsl_directives, + cleaned_source: shader_str.to_string(), }) } } diff --git a/src/compose/test.rs b/src/compose/test.rs index 9985090..0f37a1e 100644 --- a/src/compose/test.rs +++ b/src/compose/test.rs @@ -5,6 +5,7 @@ mod test { use std::io::Write; use std::{borrow::Cow, collections::HashMap}; + use naga::valid::Capabilities; use wgpu::{ BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BufferDescriptor, BufferUsages, CommandEncoderDescriptor, ComputePassDescriptor, @@ -1557,4 +1558,33 @@ mod test { f32::from_le_bytes(view.try_into().unwrap()) } + + #[test] + fn wgsl_dual_source_blending() { + let mut composer = Composer::default(); + composer + .capabilities + .set(Capabilities::DUAL_SOURCE_BLENDING, true); + + let module = composer + .make_naga_module(NagaModuleDescriptor { + source: include_str!("tests/dual_source_blending/blending.wgsl"), + file_path: "tests/dual_source_blending/blending.wgsl", + ..Default::default() + }) + .unwrap(); + + let info = composer.create_validator().validate(&module).unwrap(); + let wgsl = naga::back::wgsl::write_string( + &module, + &info, + naga::back::wgsl::WriterFlags::EXPLICIT_TYPES, + ) + .unwrap(); + + // let mut f = std::fs::File::create("wgsl_dual_source_blending.txt").unwrap(); + // f.write_all(wgsl.as_bytes()).unwrap(); + // drop(f); + output_eq!(wgsl, "tests/expected/wgsl_dual_source_blending.txt"); + } } diff --git a/src/compose/tests/dual_source_blending/blending.wgsl b/src/compose/tests/dual_source_blending/blending.wgsl new file mode 100644 index 0000000..7a43569 --- /dev/null +++ b/src/compose/tests/dual_source_blending/blending.wgsl @@ -0,0 +1,13 @@ +enable dual_source_blending; + +struct FragmentOutput { + @location(0) @blend_src(0) source0_: vec4, + @location(0) @blend_src(1) source1_: vec4, +} + +@fragment +fn fragment( + @builtin(position) frag_coord: vec4, +) -> FragmentOutput { + return FragmentOutput(frag_coord, frag_coord); +} diff --git a/src/compose/tests/expected/wgsl_dual_source_blending.txt b/src/compose/tests/expected/wgsl_dual_source_blending.txt new file mode 100644 index 0000000..6058fb8 --- /dev/null +++ b/src/compose/tests/expected/wgsl_dual_source_blending.txt @@ -0,0 +1,10 @@ +enable dual_source_blending; +struct FragmentOutput { + @location(0) @blend_src(0) source0_: vec4, + @location(0) @blend_src(1) source1_: vec4, +} + +@fragment +fn fragment(@builtin(position) frag_coord: vec4) -> FragmentOutput { + return FragmentOutput(frag_coord, frag_coord); +} diff --git a/src/compose/wgsl_directives.rs b/src/compose/wgsl_directives.rs new file mode 100644 index 0000000..59dc0ef --- /dev/null +++ b/src/compose/wgsl_directives.rs @@ -0,0 +1,105 @@ +use std::collections::HashSet; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct EnableDirective { + pub extensions: Vec, + pub source_location: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RequiresDirective { + pub extensions: Vec, + pub source_location: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DiagnosticDirective { + pub severity: String, + pub rule: String, + pub source_location: usize, +} + +#[derive(Debug, Clone, Default)] +pub struct WgslDirectives { + pub enables: Vec, + pub requires: Vec, + pub diagnostics: Vec, +} + +impl WgslDirectives { + pub fn to_wgsl_string(&self) -> String { + let mut result = String::new(); + + let mut all_enables = HashSet::new(); + for enable in &self.enables { + all_enables.extend(enable.extensions.iter().cloned()); + } + if !all_enables.is_empty() { + let mut enables: Vec<_> = all_enables.into_iter().collect(); + enables.sort(); + result.push_str(&format!("enable {};\n", enables.join(", "))); + } + + let mut all_requires = HashSet::new(); + for requires in &self.requires { + all_requires.extend(requires.extensions.iter().cloned()); + } + if !all_requires.is_empty() { + let mut requires: Vec<_> = all_requires.into_iter().collect(); + requires.sort(); + result.push_str(&format!("requires {};\n", requires.join(", "))); + } + + for diagnostic in &self.diagnostics { + result.push_str(&format!( + "diagnostic({}, {});\n", + diagnostic.severity, diagnostic.rule + )); + } + + if !result.is_empty() { + result.push('\n'); // Add blank line after directives + } + + result + } + + pub fn is_empty(&self) -> bool { + self.enables.is_empty() && self.requires.is_empty() && self.diagnostics.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_wgsl_directives_empty() { + let directives = WgslDirectives::default(); + assert!(directives.is_empty()); + assert_eq!(directives.to_wgsl_string(), ""); + } + + #[test] + fn test_wgsl_directives_to_string() { + let mut directives = WgslDirectives::default(); + directives.enables.push(EnableDirective { + extensions: vec!["f16".to_string(), "subgroups".to_string()], + source_location: 0, + }); + directives.requires.push(RequiresDirective { + extensions: vec!["readonly_and_readwrite_storage_textures".to_string()], + source_location: 0, + }); + directives.diagnostics.push(DiagnosticDirective { + severity: "warn".to_string(), + rule: "derivative_uniformity".to_string(), + source_location: 0, + }); + + let result = directives.to_wgsl_string(); + assert!(result.contains("enable f16, subgroups;")); + assert!(result.contains("requires readonly_and_readwrite_storage_textures;")); + assert!(result.contains("diagnostic(warn, derivative_uniformity);")); + } +}