Skip to content

Commit

Permalink
[wgsl-in] Allow sign() to take int argument
Browse files Browse the repository at this point in the history
  • Loading branch information
fornwall committed Sep 5, 2023
1 parent 5329aa2 commit c69e5c1
Show file tree
Hide file tree
Showing 50 changed files with 706 additions and 105 deletions.
1 change: 1 addition & 0 deletions src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,5 @@ pub const RESERVED: &[&str] = &[
"clamped_lod_e",
super::writer::FREXP_FUNCTION,
super::writer::MODF_FUNCTION,
super::writer::NAGA_ISIGN_FUNCTION,
];
41 changes: 37 additions & 4 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const NAGA_ISIGN_FUNCTION: &str = "naga_isign";

/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
///
Expand Down Expand Up @@ -1647,8 +1648,9 @@ impl<W: Write> Writer<W> {
} => {
use crate::MathFunction as Mf;

let scalar_argument = match *context.resolve_type(arg) {
crate::TypeInner::Scalar { .. } => true,
let arg_type = context.resolve_type(arg);
let scalar_argument = match arg_type {
&crate::TypeInner::Scalar { .. } => true,
_ => false,
};

Expand Down Expand Up @@ -1713,7 +1715,17 @@ impl<W: Write> Writer<W> {
Mf::Reflect => "reflect",
Mf::Refract => "refract",
// computational
Mf::Sign => "sign",
Mf::Sign => match arg_type {
&crate::TypeInner::Scalar {
kind: crate::ScalarKind::Sint,
..
}
| crate::TypeInner::Vector {
kind: crate::ScalarKind::Sint,
..
} => NAGA_ISIGN_FUNCTION,
_ => "sign",
},
Mf::Fma => "fma",
Mf::Mix => "mix",
Mf::Step => "step",
Expand Down Expand Up @@ -1816,7 +1828,7 @@ impl<W: Write> Writer<W> {
write!(self.out, "((")?;
self.put_expression(arg, context, false)?;
write!(self.out, ") * 57.295779513082322865)")?;
} else if fun == Mf::Modf || fun == Mf::Frexp {
} else if fun == Mf::Modf || fun == Mf::Frexp || fun_name == NAGA_ISIGN_FUNCTION {
write!(self.out, "{fun_name}")?;
self.put_call_parameters(iter::once(arg), context)?;
} else {
Expand Down Expand Up @@ -3091,6 +3103,7 @@ impl<W: Write> Writer<W> {

self.write_type_defs(module)?;
self.write_global_constants(module)?;
self.write_polyfills()?;
self.write_functions(module, info, options, pipeline_options)
}

Expand Down Expand Up @@ -4103,6 +4116,26 @@ impl<W: Write> Writer<W> {
Ok(info)
}

fn write_polyfills(&mut self) -> Result<(), Error> {
writeln!(self.out)?;
for size in 1..5 {
let tmp;
let type_name = if size == 1 {
"int"
} else {
tmp = format!("int{size}");
&tmp
};
writeln!(
self.out,
"{type_name} {NAGA_ISIGN_FUNCTION}({type_name} arg) {{
return {NAMESPACE}::select({NAMESPACE}::select({type_name}(-1), {type_name}(1), (arg > 0)), 0, (arg == 0));
}}"
)?;
}
Ok(())
}

fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
// Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
// so we try to avoid it here.
Expand Down
17 changes: 16 additions & 1 deletion src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,6 @@ impl super::Validator {
| Mf::Log
| Mf::Log2
| Mf::Length
| Mf::Sign
| Mf::Sqrt
| Mf::InverseSqrt => {
if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
Expand All @@ -992,6 +991,22 @@ impl super::Validator {
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::Sign => {
if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar {
kind: Sk::Float | Sk::Sint,
..
}
| Ti::Vector {
kind: Sk::Float | Sk::Sint,
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), None, None) => ty1,
Expand Down
1 change: 1 addition & 0 deletions tests/in/math-functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ fn main() {
let d = radians(v);
let e = saturate(v);
let g = refract(v, v, f);
let h = sign(-1);
let const_dot = dot(vec2<i32>(), vec2<i32>());
let first_leading_bit_abs = firstLeadingBit(abs(0u));
let flb_a = firstLeadingBit(-1);
Expand Down
5 changes: 3 additions & 2 deletions tests/out/glsl/math-functions.main.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void main() {
vec4 d = radians(v);
vec4 e = clamp(v, vec4(0.0), vec4(1.0));
vec4 g = refract(v, v, 1.0);
int h = sign(-1);
int const_dot = ( + ivec2(0).x * ivec2(0).x + ivec2(0).y * ivec2(0).y);
uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u)))));
int flb_a = findMSB(-1);
Expand All @@ -81,8 +82,8 @@ void main() {
ivec2 ctz_h = ivec2(min(uvec2(findLSB(ivec2(1))), uvec2(32u)));
int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1));
uint clz_b = uint(31 - findMSB(1u));
ivec2 _e58 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e58), ivec2(0), lessThan(_e58, ivec2(0)));
ivec2 _e60 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e60), ivec2(0), lessThan(_e60, ivec2(0)));
uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u)));
float lde_a = ldexp(1.0, 2);
vec2 lde_b = ldexp(vec2(1.0, 2.0), ivec2(3, 4));
Expand Down
5 changes: 3 additions & 2 deletions tests/out/hlsl/math-functions.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ void main()
float4 d = radians(v);
float4 e = saturate(v);
float4 g = refract(v, v, 1.0);
int h = sign(-1);
int const_dot = dot((int2)0, (int2)0);
uint first_leading_bit_abs = firstbithigh(abs(0u));
int flb_a = asint(firstbithigh(-1));
Expand All @@ -91,8 +92,8 @@ void main()
int2 ctz_h = asint(min((32u).xx, firstbitlow((1).xx)));
int clz_a = (-1 < 0 ? 0 : 31 - asint(firstbithigh(-1)));
uint clz_b = (31u - firstbithigh(1u));
int2 _expr58 = (-1).xx;
int2 clz_c = (_expr58 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr58)));
int2 _expr60 = (-1).xx;
int2 clz_c = (_expr60 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr60)));
uint2 clz_d = ((31u).xx - firstbithigh((1u).xx));
float lde_a = ldexp(1.0, 2);
float2 lde_b = ldexp(float2(1.0, 2.0), int2(3, 4));
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/access.msl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ struct type_26 {
metal::float4 inner[2];
};

int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

void test_matrix_within_struct_accesses(
constant Baz& baz
) {
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/array-in-ctor.msl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@ struct Ah {
type_1 inner;
};

int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

kernel void cs_main(
device Ah const& ah [[user(fake0)]]
) {
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/array-in-function-return-type.msl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@ struct type_1 {
float inner[2];
};

int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

type_1 ret_array(
) {
return type_1 {1.0, 2.0};
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/atomicOps.msl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ struct Struct {
type_2 atomic_arr;
};

int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

struct cs_mainInput {
};
kernel void cs_main(
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/binding-arrays.msl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ struct FragmentIn {
uint index;
};

int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

struct main_Input {
uint index [[user(loc0), flat]];
};
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/bitcast.msl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@
using metal::uint;


int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

kernel void main_(
) {
metal::int2 i2_ = {};
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/bits.msl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@
using metal::uint;


int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

kernel void main_(
) {
int i = {};
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/boids.msl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ struct Particles {
};
constant uint NUM_PARTICLES = 1500u;

int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

struct main_Input {
};
kernel void main_(
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/bounds-check-image-restrict.msl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@
using metal::uint;


int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

metal::float4 test_textureLoad_1d(
int coords,
int level,
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/bounds-check-image-rzsw.msl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@ struct DefaultConstructible {
};


int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

metal::float4 test_textureLoad_1d(
int coords,
int level,
Expand Down
13 changes: 13 additions & 0 deletions tests/out/msl/bounds-check-restrict.msl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ struct Globals {
type_4 d;
};

int naga_isign(int arg) {
return metal::select(metal::select(int(-1), int(1), (arg > 0)), 0, (arg == 0));
}
int2 naga_isign(int2 arg) {
return metal::select(metal::select(int2(-1), int2(1), (arg > 0)), 0, (arg == 0));
}
int3 naga_isign(int3 arg) {
return metal::select(metal::select(int3(-1), int3(1), (arg > 0)), 0, (arg == 0));
}
int4 naga_isign(int4 arg) {
return metal::select(metal::select(int4(-1), int4(1), (arg > 0)), 0, (arg == 0));
}

float index_array(
int i,
device Globals const& globals,
Expand Down
Loading

0 comments on commit c69e5c1

Please sign in to comment.