Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 103 additions & 2 deletions naga/src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ use core::fmt::Write;
use super::{
super::FunctionCtx,
writer::{
ABS_FUNCTION, DIV_FUNCTION, EXTRACT_BITS_FUNCTION, INSERT_BITS_FUNCTION, MOD_FUNCTION,
NEG_FUNCTION,
ABS_FUNCTION, DIV_FUNCTION, EXTRACT_BITS_FUNCTION, F2I32_FUNCTION, F2I64_FUNCTION,
F2U32_FUNCTION, F2U64_FUNCTION, INSERT_BITS_FUNCTION, MOD_FUNCTION, NEG_FUNCTION,
},
BackendResult, WrappedType,
};
Expand Down Expand Up @@ -97,6 +97,15 @@ pub(super) struct WrappedBinaryOp {
pub(super) right_ty: (Option<crate::VectorSize>, crate::Scalar),
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedCast {
// This can only represent scalar or vector types. If we ever need to wrap
// casts with other types, we'll need a better representation.
pub(super) vector_size: Option<crate::VectorSize>,
pub(super) src_scalar: crate::Scalar,
pub(super) dst_scalar: crate::Scalar,
}

/// HLSL backend requires its own `ImageQuery` enum.
///
/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
Expand Down Expand Up @@ -1355,6 +1364,97 @@ impl<W: Write> super::Writer<'_, W> {
Ok(())
}

fn write_wrapped_cast_functions(
&mut self,
module: &crate::Module,
func_ctx: &FunctionCtx,
) -> BackendResult {
for (_, expression) in func_ctx.expressions.iter() {
if let crate::Expression::As {
expr,
kind,
convert: Some(width),
} = *expression
{
// Avoid undefined behaviour when casting from a float to integer
// when the value is out of range for the target type. Additionally
// ensure we clamp to the correct value as per the WGSL spec.
//
// https://www.w3.org/TR/WGSL/#floating-point-conversion:
// * If X is exactly representable in the target type T, then the
// result is that value.
// * Otherwise, the result is the value in T closest to
// truncate(X) and also exactly representable in the original
// floating point type.
let src_ty = func_ctx.resolve_type(expr, &module.types);
let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
continue;
};
let dst_scalar = crate::Scalar { kind, width };
if src_scalar.kind != ScalarKind::Float
|| (dst_scalar.kind != ScalarKind::Sint && dst_scalar.kind != ScalarKind::Uint)
{
continue;
}

let wrapped = WrappedCast {
src_scalar,
vector_size,
dst_scalar,
};
if !self.wrapped.insert(WrappedType::Cast(wrapped)) {
continue;
}

let (src_ty, dst_ty) = match vector_size {
None => (
crate::TypeInner::Scalar(src_scalar),
crate::TypeInner::Scalar(dst_scalar),
),
Some(vector_size) => (
crate::TypeInner::Vector {
scalar: src_scalar,
size: vector_size,
},
crate::TypeInner::Vector {
scalar: dst_scalar,
size: vector_size,
},
),
};
let (min, max) =
crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
let cast_str = format!(
"{}{}",
dst_scalar.to_hlsl_str()?,
vector_size
.map(crate::common::vector_size_str)
.unwrap_or(""),
);
let fun_name = match dst_scalar {
crate::Scalar::I32 => F2I32_FUNCTION,
crate::Scalar::U32 => F2U32_FUNCTION,
crate::Scalar::I64 => F2I64_FUNCTION,
crate::Scalar::U64 => F2U64_FUNCTION,
_ => unreachable!(),
};
self.write_value_type(module, &dst_ty)?;
write!(self.out, " {fun_name}(")?;
self.write_value_type(module, &src_ty)?;
writeln!(self.out, " value) {{")?;
let level = crate::back::Level(1);
write!(self.out, "{level}return {cast_str}(clamp(value, ")?;
self.write_literal(min)?;
write!(self.out, ", ")?;
self.write_literal(max)?;
writeln!(self.out, "));",)?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}
}
Ok(())
}

/// Helper function that writes various wrapped functions
pub(super) fn write_wrapped_functions(
&mut self,
Expand All @@ -1366,6 +1466,7 @@ impl<W: Write> super::Writer<'_, W> {
self.write_wrapped_binary_ops(module, func_ctx)?;
self.write_wrapped_expression_functions(module, func_ctx.expressions, Some(func_ctx))?;
self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;
self.write_wrapped_cast_functions(module, func_ctx)?;

for (handle, _) in func_ctx.expressions.iter() {
match func_ctx.expressions[handle] {
Expand Down
4 changes: 4 additions & 0 deletions naga/src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,10 @@ pub const RESERVED: &[&str] = &[
super::writer::DIV_FUNCTION,
super::writer::MOD_FUNCTION,
super::writer::NEG_FUNCTION,
super::writer::F2I32_FUNCTION,
super::writer::F2U32_FUNCTION,
super::writer::F2I64_FUNCTION,
super::writer::F2U64_FUNCTION,
];

// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ enum WrappedType {
Math(help::WrappedMath),
UnaryOp(help::WrappedUnaryOp),
BinaryOp(help::WrappedBinaryOp),
Cast(help::WrappedCast),
}

#[derive(Default)]
Expand Down
156 changes: 92 additions & 64 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";

struct EpStructMember {
name: String,
Expand Down Expand Up @@ -2612,6 +2616,28 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
})
}

pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
match literal {
crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
// HLSL has no suffix for explicit i32 literals, but not using any suffix
// makes the type ambiguous which prevents overload resolution from
// working. So we explicitly use the int() constructor syntax.
crate::Literal::I32(value) => write!(self.out, "int({value})")?,
crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
crate::Literal::I64(value) => write!(self.out, "{value}L")?,
crate::Literal::Bool(value) => write!(self.out, "{value}")?,
crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
return Err(Error::Custom(
"Abstract types should not appear in IR presented to backends".into(),
));
}
}
Ok(())
}

fn write_possibly_const_expression<E>(
&mut self,
module: &Module,
Expand All @@ -2625,26 +2651,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
use crate::Expression;

match expressions[expr] {
Expression::Literal(literal) => match literal {
// Floats are written using `Debug` instead of `Display` because it always appends the
// decimal part even it's zero
crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
// HLSL has no suffix for explicit i32 literals, but not using any suffix
// makes the type ambiguous which prevents overload resolution from
// working. So we explicitly use the int() constructor syntax.
crate::Literal::I32(value) => write!(self.out, "int({value})")?,
crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
crate::Literal::I64(value) => write!(self.out, "{value}L")?,
crate::Literal::Bool(value) => write!(self.out, "{value}")?,
crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
return Err(Error::Custom(
"Abstract types should not appear in IR presented to backends".into(),
));
}
},
Expression::Literal(literal) => {
self.write_literal(literal)?;
}
Expression::Constant(handle) => {
let constant = &module.constants[handle];
if constant.name.is_some() {
Expand Down Expand Up @@ -3320,53 +3329,72 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
convert,
} => {
let inner = func_ctx.resolve_type(expr, &module.types);
let close_paren = match convert {
Some(dst_width) => {
let scalar = Scalar {
kind,
width: dst_width,
};
match *inner {
TypeInner::Vector { size, .. } => {
write!(
self.out,
"{}{}(",
scalar.to_hlsl_str()?,
common::vector_size_str(size)
)?;
}
TypeInner::Scalar(_) => {
write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
}
TypeInner::Matrix { columns, rows, .. } => {
write!(
self.out,
"{}{}x{}(",
scalar.to_hlsl_str()?,
common::vector_size_str(columns),
common::vector_size_str(rows)
)?;
}
_ => {
return Err(Error::Unimplemented(format!(
"write_expr expression::as {inner:?}"
)));
}
};
true
}
None => {
if inner.scalar_width() == Some(8) {
false
} else {
write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
if inner.scalar_kind() == Some(ScalarKind::Float)
&& (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
&& convert.is_some()
{
// Use helper functions for float to int casts in order to
// avoid undefined behaviour when value is out of range for
// the target type.
let fun_name = match (kind, convert) {
(ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
(ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
(ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
(ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
_ => unreachable!(),
};
write!(self.out, "{fun_name}(")?;
self.write_expr(module, expr, func_ctx)?;
write!(self.out, ")")?;
} else {
let close_paren = match convert {
Some(dst_width) => {
let scalar = Scalar {
kind,
width: dst_width,
};
match *inner {
TypeInner::Vector { size, .. } => {
write!(
self.out,
"{}{}(",
scalar.to_hlsl_str()?,
common::vector_size_str(size)
)?;
}
TypeInner::Scalar(_) => {
write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
}
TypeInner::Matrix { columns, rows, .. } => {
write!(
self.out,
"{}{}x{}(",
scalar.to_hlsl_str()?,
common::vector_size_str(columns),
common::vector_size_str(rows)
)?;
}
_ => {
return Err(Error::Unimplemented(format!(
"write_expr expression::as {inner:?}"
)));
}
};
true
}
None => {
if inner.scalar_width() == Some(8) {
false
} else {
write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
true
}
}
};
self.write_expr(module, expr, func_ctx)?;
if close_paren {
write!(self.out, ")")?;
}
};
self.write_expr(module, expr, func_ctx)?;
if close_paren {
write!(self.out, ")")?;
}
}
Expression::Math {
Expand Down
4 changes: 4 additions & 0 deletions naga/src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@ pub const RESERVED: &[&str] = &[
super::writer::DIV_FUNCTION,
super::writer::MOD_FUNCTION,
super::writer::NEG_FUNCTION,
super::writer::F2I32_FUNCTION,
super::writer::F2U32_FUNCTION,
super::writer::F2I64_FUNCTION,
super::writer::F2U64_FUNCTION,
super::writer::ARGUMENT_BUFFER_WRAPPER_STRUCT,
];

Expand Down
Loading