Skip to content

Commit

Permalink
[wgsl-in] Handle modf and frexp
Browse files Browse the repository at this point in the history
  • Loading branch information
fornwall committed Aug 23, 2023
1 parent 3da9355 commit 15f9694
Show file tree
Hide file tree
Showing 26 changed files with 478 additions and 140 deletions.
3 changes: 3 additions & 0 deletions src/back/glsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,4 +477,7 @@ pub const RESERVED_KEYWORDS: &[&str] = &[
// entry point name (should not be shadowed)
//
"main",
// Naga utilities:
super::MODF_FUNCTION,
super::FREXP_FUNCTION,
];
34 changes: 32 additions & 2 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320];
/// of detail for bounds checking in `ImageLoad`
const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod";

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

/// Mapping between resources and bindings.
pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>;

Expand Down Expand Up @@ -625,6 +628,33 @@ impl<'a, W: Write> Writer<'a, W> {
}
}

// Write functions to create special types.
if let Some(ty_struct) = self.module.special_types.modf_result {
let struct_name = &self.names[&NameKey::Type(ty_struct)];
writeln!(self.out)?;
writeln!(
self.out,
"{} {MODF_FUNCTION}(float arg) {{
float whole;
float fract = modf(arg, whole);
return {}(fract, whole);
}}",
struct_name, struct_name
)?;
}
if let Some(ty_struct) = self.module.special_types.frexp_result {
let struct_name = &self.names[&NameKey::Type(ty_struct)];
writeln!(self.out)?;
writeln!(
self.out,
"{struct_name} {FREXP_FUNCTION}(float arg) {{
int exp;
float fract = frexp(arg, exp);
return {struct_name}(fract, exp);
}}"
)?;
}

// Write all named constants
let mut constants = self
.module
Expand Down Expand Up @@ -2985,8 +3015,8 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Round => "roundEven",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Modf => MODF_FUNCTION,
Mf::Frexp => FREXP_FUNCTION,
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Expand Down
36 changes: 36 additions & 0 deletions src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,42 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}

pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
if let Some(ty_struct) = module.special_types.modf_result {
let struct_name = &self.names[&NameKey::Type(ty_struct)];
writeln!(
self.out,
"{struct_name} {}(in float arg) {{
float whole;
float fract = modf(arg, whole);
{struct_name} result;
result.whole = whole;
result.fract = fract;
return result;
}}",
super::writer::MODF_FUNCTION,
)?;
writeln!(self.out)?;
}
if let Some(ty_struct) = module.special_types.frexp_result {
let struct_name = &self.names[&NameKey::Type(ty_struct)];
writeln!(
self.out,
"{struct_name} {}(in float arg) {{
float exp_;
float fract = frexp(arg, exp_);
{struct_name} result;
result.exp_ = exp_;
result.fract = fract;
return result;
}}",
super::writer::FREXP_FUNCTION
)?;
writeln!(self.out)?;
}
Ok(())
}

/// Helper function that writes compose wrapped functions
pub(super) fn write_wrapped_compose_functions(
&mut self,
Expand Down
3 changes: 3 additions & 0 deletions src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,9 @@ pub const RESERVED: &[&str] = &[
"TextureBuffer",
"ConstantBuffer",
"RayQuery",
// Naga utilities
super::writer::FREXP_FUNCTION,
super::writer::MODF_FUNCTION,
];

// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
Expand Down
9 changes: 7 additions & 2 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ const SPECIAL_BASE_VERTEX: &str = "base_vertex";
const SPECIAL_BASE_INSTANCE: &str = "base_instance";
const SPECIAL_OTHER: &str = "other";

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

struct EpStructMember {
name: String,
ty: Handle<crate::Type>,
Expand Down Expand Up @@ -244,6 +247,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

self.write_special_functions(module)?;

self.write_wrapped_compose_functions(module, &module.const_expressions)?;

// Write all named constants
Expand Down Expand Up @@ -2665,8 +2670,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Round => Function::Regular("round"),
Mf::Fract => Function::Regular("frac"),
Mf::Trunc => Function::Regular("trunc"),
Mf::Modf => Function::Regular("modf"),
Mf::Frexp => Function::Regular("frexp"),
Mf::Modf => Function::Regular(MODF_FUNCTION),
Mf::Frexp => Function::Regular(FREXP_FUNCTION),
Mf::Ldexp => Function::Regular("ldexp"),
// exponent
Mf::Exp => Function::Regular("exp"),
Expand Down
2 changes: 2 additions & 0 deletions src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,6 @@ pub const RESERVED: &[&str] = &[
// Naga utilities
"DefaultConstructible",
"clamped_lod_e",
super::writer::FREXP_FUNCTION,
super::writer::MODF_FUNCTION,
];
35 changes: 33 additions & 2 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
///
/// The `sizes` slice determines whether this function writes a
Expand Down Expand Up @@ -1678,8 +1681,8 @@ impl<W: Write> Writer<W> {
Mf::Round => "rint",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Modf => MODF_FUNCTION,
Mf::Frexp => FREXP_FUNCTION,
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Expand Down Expand Up @@ -1813,6 +1816,9 @@ impl<W: Write> Writer<W> {
write!(self.out, "((")?;
self.put_expression(arg, context, false)?;
write!(self.out, ") * 57.295779513082322865)")?;
} else if fun == Mf::Modf || fun == Mf::Frexp {
write!(self.out, "{fun_name}")?;
self.put_call_parameters(iter::once(arg), context)?;
} else {
write!(self.out, "{NAMESPACE}::{fun_name}")?;
self.put_call_parameters(
Expand Down Expand Up @@ -3236,6 +3242,31 @@ impl<W: Write> Writer<W> {
}
}
}

if let Some(struct_ty) = module.special_types.modf_result {
let struct_name = &self.names[&NameKey::Type(struct_ty)];
writeln!(
self.out,
"struct {struct_name} {MODF_FUNCTION}(float arg) {{
float whole;
float fract = {NAMESPACE}::modf(arg, whole);
return {struct_name}{{ fract, whole }};
}};"
)?;
}

if let Some(struct_ty) = module.special_types.frexp_result {
let struct_name = &self.names[&NameKey::Type(struct_ty)];
writeln!(
self.out,
"struct {struct_name} {FREXP_FUNCTION}(float arg) {{
int exp;
float fract = {NAMESPACE}::frexp(arg, exp);
return {struct_name}{{ fract, exp }};
}};"
)?;
}

Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,8 +787,8 @@ impl<'w> BlockContext<'w> {
Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
Mf::Modf => MathOp::Ext(spirv::GLOp::Modf),
Mf::Frexp => MathOp::Ext(spirv::GLOp::Frexp),
Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
// geometry
Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
Expand Down
21 changes: 14 additions & 7 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ impl<W: Write> Writer<W> {
self.ep_results.clear();
}

fn is_builtin_wgsl_struct(&self, ty_struct: &crate::Type) -> bool {
ty_struct
.name
.as_ref()
.map_or(false, |name| name.starts_with("__"))
}

pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
self.reset(module);

Expand All @@ -108,13 +115,13 @@ impl<W: Write> Writer<W> {

// Write all structs
for (handle, ty) in module.types.iter() {
if let TypeInner::Struct {
ref members,
span: _,
} = ty.inner
{
self.write_struct(module, handle, members)?;
writeln!(self.out)?;
if let TypeInner::Struct { ref members, .. } = ty.inner {
{
if !self.is_builtin_wgsl_struct(ty) {
self.write_struct(module, handle, members)?;
writeln!(self.out)?;
}
}
}
}

Expand Down
97 changes: 97 additions & 0 deletions src/front/type_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,4 +311,101 @@ impl crate::Module {
self.special_types.ray_intersection = Some(handle);
handle
}

/// Populate this module's [`crate::SpecialTypes::modf_result`] type.
pub fn generate_modf_result(&mut self) {
if self.special_types.modf_result.is_some() {
return;
}

let float_ty = self.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Float,
width: 4,
},
},
Span::UNDEFINED,
);

let handle = self.types.insert(
crate::Type {
name: Some("__modf_result_f32".to_string()),
inner: crate::TypeInner::Struct {
members: vec![
crate::StructMember {
name: Some("fract".to_string()),
ty: float_ty,
binding: None,
offset: 0,
},
crate::StructMember {
name: Some("whole".to_string()),
ty: float_ty,
binding: None,
offset: 4,
},
],
span: 8,
},
},
Span::UNDEFINED,
);
self.special_types.modf_result = Some(handle);
}

/// Populate this module's [`crate::SpecialTypes::frexp_result`] type.
pub fn generate_frexp_result(&mut self) {
if self.special_types.frexp_result.is_some() {
return;
}

let sint_ty = self.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Sint,
width: 4,
},
},
Span::UNDEFINED,
);

let float_ty = self.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Float,
width: 4,
},
},
Span::UNDEFINED,
);

let handle = self.types.insert(
crate::Type {
name: Some("__frexp_result_f32".to_string()),
inner: crate::TypeInner::Struct {
members: vec![
crate::StructMember {
name: Some("fract".to_string()),
ty: float_ty,
binding: None,
offset: 0,
},
crate::StructMember {
name: Some("exp".to_string()),
ty: sint_ty,
binding: None,
offset: 4,
},
],
span: 8,
},
},
Span::UNDEFINED,
);
self.special_types.frexp_result = Some(handle);
}
}
7 changes: 7 additions & 0 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut args = ctx.prepare_args(arguments, expected, span);

let arg = self.expression(args.next()?, ctx.reborrow())?;

let arg1 = args
.next()
.map(|x| self.expression(x, ctx.reborrow()))
Expand All @@ -1731,6 +1732,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

args.finish()?;

if fun == crate::MathFunction::Modf {
ctx.module.generate_modf_result();
} else if fun == crate::MathFunction::Frexp {
ctx.module.generate_frexp_result();
};

crate::Expression::Math {
fun,
arg,
Expand Down
5 changes: 3 additions & 2 deletions src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,10 +438,11 @@ fn binary_expression_mixed_scalar_and_vector_operands() {
#[test]
fn parse_pointers() {
parse_str(
"fn foo() {
"fn foo(a: ptr<private, f32>) -> f32 { return *a; }
fn bar() {
var x: f32 = 1.0;
let px = &x;
let py = frexp(0.5, px);
let py = foo(px);
}",
)
.unwrap();
Expand Down
Loading

0 comments on commit 15f9694

Please sign in to comment.