Skip to content

Commit

Permalink
msl: ray query support
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Feb 26, 2023
1 parent c163b9d commit 7586a9c
Show file tree
Hide file tree
Showing 13 changed files with 295 additions and 30 deletions.
23 changes: 23 additions & 0 deletions src/back/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,26 @@ impl crate::Statement {
}
}
}

bitflags::bitflags! {
/// Ray flags.
#[derive(Default)]
pub struct RayFlag: u32 {
const OPAQUE = 0x01;
const NO_OPAQUE = 0x02;
const TERMINATE_ON_FIRST_HIT = 0x04;
const SKIP_CLOSEST_HIT_SHADER = 0x08;
const CULL_FRONT_FACING = 0x10;
const CULL_BACK_FACING = 0x20;
const CULL_OPAQUE = 0x40;
const CULL_NO_OPAQUE = 0x80;
const SKIP_TRIANGLES = 0x100;
const SKIP_AABBS = 0x200;
}
}

#[repr(u32)]
enum RayIntersectionType {
Triangle = 1,
BoundingBox = 4,
}
10 changes: 2 additions & 8 deletions src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,7 @@ impl Options {
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
texture: None,
sampler: None,
binding_array_size: None,
mutable: false,
..Default::default()
})),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
Expand All @@ -338,10 +335,7 @@ impl Options {
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
texture: None,
sampler: None,
binding_array_size: None,
mutable: false,
..Default::default()
})),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
Expand Down
199 changes: 186 additions & 13 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ const WRAPPED_ARRAY_FIELD: &str = "inner";
// Some more general handling of pointers is needed to be implemented here.
const ATOMIC_REFERENCE: &str = "&";

const RT_NAMESPACE: &str = "metal::raytracing";
const RAY_QUERY_TYPE: &str = "_RayQuery";
const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
///
/// The `sizes` slice determines whether this function writes a
Expand Down Expand Up @@ -194,8 +201,11 @@ impl<'a> Display for TypeContext<'a> {
crate::TypeInner::Sampler { comparison: _ } => {
write!(out, "{NAMESPACE}::sampler")
}
crate::TypeInner::AccelerationStructure | crate::TypeInner::RayQuery => {
unreachable!("Ray queries are not supported yet");
crate::TypeInner::AccelerationStructure => {
write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
}
crate::TypeInner::RayQuery => {
write!(out, "{RAY_QUERY_TYPE}")
}
crate::TypeInner::BindingArray { base, size } => {
let base_tyname = Self {
Expand Down Expand Up @@ -1864,8 +1874,39 @@ impl<W: Write> Writer<W> {
write!(self.out, ")")?;
}
}
// hot supported yet
crate::Expression::RayQueryGetIntersection { .. } => unreachable!(),
crate::Expression::RayQueryGetIntersection { query, committed } => {
if !committed {
unimplemented!()
}
let ty = context.module.special_types.ray_intersection.unwrap();
let type_name = &self.names[&NameKey::Type(ty)];
write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
self.put_expression(query, context, true)?;
write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
let fields = [
"distance",
"user_instance_id",
"instance_id",
"", // SBT offset
"geometry_id",
"primitive_id",
"triangle_barycentric_coord",
"triangle_front_facing",
"", // padding
"object_to_world_transform",
"world_to_object_transform",
];
for field in fields {
write!(self.out, ", ")?;
if field.is_empty() {
write!(self.out, "{{}}")?;
} else {
self.put_expression(query, context, true)?;
write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
}
}
write!(self.out, "}}")?;
}
}
Ok(())
}
Expand Down Expand Up @@ -2324,13 +2365,24 @@ impl<W: Write> Writer<W> {
) {
use crate::Expression;
self.need_bake_expressions.clear();

for (expr_handle, expr) in func.expressions.iter() {
// Expressions whose reference count is above the
// threshold should always be stored in temporaries.
let expr_info = &info[expr_handle];
let min_ref_count = func.expressions[expr_handle].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(expr_handle);
} else {
match expr_info.ty {
// force ray desc to be baked: it's used multiple times internally
TypeResolution::Handle(h)
if Some(h) == context.module.special_types.ray_desc =>
{
self.need_bake_expressions.insert(expr_handle);
}
_ => {}
}
}

if let Expression::Math { fun, arg, arg1, .. } = *expr {
Expand All @@ -2342,11 +2394,11 @@ impl<W: Write> Writer<W> {
// times, once for each component (see `put_dot_product`), so to
// avoid duplicated evaluation, we must bake integer operands.

use crate::TypeInner;
// check what kind of product this is depending
// on the resolve type of the Dot function itself
let inner = context.resolve_type(expr_handle);
if let TypeInner::Scalar { kind, .. } = *inner {
if let crate::TypeInner::Scalar { kind, .. } =
*context.resolve_type(expr_handle)
{
match kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
Expand Down Expand Up @@ -2771,7 +2823,100 @@ impl<W: Write> Writer<W> {
// done
writeln!(self.out, ";")?;
}
crate::Statement::RayQuery { .. } => unreachable!(),
crate::Statement::RayQuery { query, ref fun } => {
match *fun {
crate::RayQueryFunction::Initialize {
acceleration_structure,
descriptor,
} => {
//TODO: how to deal with winding?
write!(self.out, "{level}")?;
self.put_expression(query, &context.expression, true)?;
writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
{
let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
write!(self.out, "{level}")?;
self.put_expression(query, &context.expression, true)?;
write!(
self.out,
".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
)?;
self.put_expression(descriptor, &context.expression, true)?;
write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
self.put_expression(descriptor, &context.expression, true)?;
write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
}
{
let f_opaque = back::RayFlag::OPAQUE.bits();
let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
write!(self.out, "{level}")?;
self.put_expression(query, &context.expression, true)?;
write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
self.put_expression(descriptor, &context.expression, true)?;
write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
self.put_expression(descriptor, &context.expression, true)?;
write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
}
{
let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
write!(self.out, "{level}")?;
self.put_expression(query, &context.expression, true)?;
write!(
self.out,
".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
)?;
self.put_expression(descriptor, &context.expression, true)?;
writeln!(self.out, ".flags & {flag}) != 0);")?;
}

write!(self.out, "{level}")?;
self.put_expression(query, &context.expression, true)?;
write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
self.put_expression(query, &context.expression, true)?;
write!(
self.out,
".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
)?;
self.put_expression(descriptor, &context.expression, true)?;
write!(self.out, ".origin, ")?;
self.put_expression(descriptor, &context.expression, true)?;
write!(self.out, ".dir, ")?;
self.put_expression(descriptor, &context.expression, true)?;
write!(self.out, ".tmin, ")?;
self.put_expression(descriptor, &context.expression, true)?;
write!(self.out, ".tmax), ")?;
self.put_expression(acceleration_structure, &context.expression, true)?;
write!(self.out, ", ")?;
self.put_expression(descriptor, &context.expression, true)?;
write!(self.out, ".cull_mask);")?;

write!(self.out, "{level}")?;
self.put_expression(query, &context.expression, true)?;
writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
}
crate::RayQueryFunction::Proceed { result } => {
write!(self.out, "{level}")?;
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_baking_expression(result, &context.expression, &name)?;
self.named_expressions.insert(result, name);
self.put_expression(query, &context.expression, true)?;
writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
//TODO: actually proceed?

write!(self.out, "{level}")?;
self.put_expression(query, &context.expression, true)?;
writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
}
crate::RayQueryFunction::Terminate => {
write!(self.out, "{level}")?;
self.put_expression(query, &context.expression, true)?;
writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?;
}
}
}
}
}

Expand Down Expand Up @@ -2883,14 +3028,41 @@ impl<W: Write> Writer<W> {
writeln!(self.out)?;
// Work around Metal bug where `uint` is not available by default
writeln!(self.out, "using {NAMESPACE}::uint;")?;
writeln!(self.out)?;

if module.types.iter().any(|(_, t)| match t.inner {
crate::TypeInner::RayQuery => true,
_ => false,
}) {
let tab = back::INDENT;
writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
writeln!(
self.out,
"{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
)?;
writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
writeln!(self.out, "}};")?;
writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
let v_triangle = back::RayIntersectionType::Triangle as u32;
let v_bbox = back::RayIntersectionType::BoundingBox as u32;
writeln!(
self.out,
"{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
)?;
writeln!(
self.out,
"{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
)?;
writeln!(self.out, "}}")?;
}
if options
.bounds_check_policies
.contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
{
self.put_default_constructible()?;
}
writeln!(self.out)?;

{
let mut indices = vec![];
Expand Down Expand Up @@ -2932,11 +3104,12 @@ impl<W: Write> Writer<W> {
///
/// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
fn put_default_constructible(&mut self) -> BackendResult {
let tab = back::INDENT;
writeln!(self.out, "struct DefaultConstructible {{")?;
writeln!(self.out, " template<typename T>")?;
writeln!(self.out, " operator T() && {{")?;
writeln!(self.out, " return T {{}};")?;
writeln!(self.out, " }}")?;
writeln!(self.out, "{tab}template<typename T>")?;
writeln!(self.out, "{tab}operator T() && {{")?;
writeln!(self.out, "{tab}{tab}return T {{}};")?;
writeln!(self.out, "{tab}}}")?;
writeln!(self.out, "}};")?;
Ok(())
}
Expand Down
8 changes: 8 additions & 0 deletions tests/in/ray-query.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,12 @@
spv: (
version: (1, 4),
),
msl: (
lang_version: (2, 4),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
zero_initialize_workgroup_memory: false,
per_entry_point_map: {},
inline_samplers: [],
),
)
13 changes: 11 additions & 2 deletions tests/in/ray-query.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,17 @@
var acc_struct: acceleration_structure;

/*
let RAY_FLAG_NONE = 0u;
let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 4u;
let RAY_FLAG_NONE = 0x00u;
let RAY_FLAG_OPAQUE = 0x01u;
let RAY_FLAG_NO_OPAQUE = 0x02u;
let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04u;
let RAY_FLAG_SKIP_CLOSEST_HIT_SHADER = 0x08u;
let RAY_FLAG_CULL_FRONT_FACING = 0x10u;
let RAY_FLAG_CULL_BACK_FACING = 0x20u;
let RAY_FLAG_CULL_OPAQUE = 0x40u;
let RAY_FLAG_CULL_NO_OPAQUE = 0x80u;
let RAY_FLAG_SKIP_TRIANGLES = 0x100u;
let RAY_FLAG_SKIP_AABBS = 0x200u;

let RAY_QUERY_INTERSECTION_NONE = 0u;
let RAY_QUERY_INTERSECTION_TRIANGLE = 1u;
Expand Down
2 changes: 1 addition & 1 deletion tests/out/msl/binding-arrays.msl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
#include <simd/simd.h>

using metal::uint;

struct DefaultConstructible {
template<typename T>
operator T() && {
return T {};
}
};

struct UniformIndex {
uint index;
};
Expand Down
2 changes: 1 addition & 1 deletion tests/out/msl/bounds-check-image-rzsw.msl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
#include <simd/simd.h>

using metal::uint;

struct DefaultConstructible {
template<typename T>
operator T() && {
return T {};
}
};

constant metal::int2 const_type_4_ = {0, 0};
constant metal::int3 const_type_7_ = {0, 0, 0};
constant metal::float4 const_type_2_ = {0.0, 0.0, 0.0, 0.0};
Expand Down
Loading

0 comments on commit 7586a9c

Please sign in to comment.