Skip to content

Commit 7a3923d

Browse files
committed
[naga hlsl-out msl-out spv-out] Avoid undefined behaviour when casting floats to integers
Currently we generate code to convert floating point values to integers using constructor-style casts in HLSL, static_cast in MSL, and OpConvertFToS/OpConvertFToU instructions in SPV. Unfortunately the behaviour of these operations is undefined when the original value is outside of the range of the target type. This patch avoids undefined behaviour by first clamping the value to be inside the target type's range, then performing the cast. Additionally, we specifically clamp to the minimum and maximum values that are exactly representable in both the original and the target type, as per the WGSL spec[1]. Note that these may not be the same as the minimum and maximum values of the target type. We additionally must ensure we clamp in the same manner for conversions during const evaluation. Lastly, although not part of the WGSL spec, we do the same for casting from F64 and/or to I64 or U64. [1] https://www.w3.org/TR/WGSL/#floating-point-conversion
1 parent e037f4b commit 7a3923d

34 files changed

+3549
-1834
lines changed

naga/src/back/hlsl/help.rs

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ use core::fmt::Write;
3232
use super::{
3333
super::FunctionCtx,
3434
writer::{
35-
ABS_FUNCTION, DIV_FUNCTION, EXTRACT_BITS_FUNCTION, INSERT_BITS_FUNCTION, MOD_FUNCTION,
36-
NEG_FUNCTION,
35+
ABS_FUNCTION, DIV_FUNCTION, EXTRACT_BITS_FUNCTION, F2I32_FUNCTION, F2I64_FUNCTION,
36+
F2U32_FUNCTION, F2U64_FUNCTION, INSERT_BITS_FUNCTION, MOD_FUNCTION, NEG_FUNCTION,
3737
},
3838
BackendResult, WrappedType,
3939
};
@@ -97,6 +97,15 @@ pub(super) struct WrappedBinaryOp {
9797
pub(super) right_ty: (Option<crate::VectorSize>, crate::Scalar),
9898
}
9999

100+
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
101+
pub(super) struct WrappedCast {
102+
// This can only represent scalar or vector types. If we ever need to wrap
103+
// casts with other types, we'll need a better representation.
104+
pub(super) vector_size: Option<crate::VectorSize>,
105+
pub(super) src_scalar: crate::Scalar,
106+
pub(super) dst_scalar: crate::Scalar,
107+
}
108+
100109
/// HLSL backend requires its own `ImageQuery` enum.
101110
///
102111
/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
@@ -1355,6 +1364,97 @@ impl<W: Write> super::Writer<'_, W> {
13551364
Ok(())
13561365
}
13571366

1367+
fn write_wrapped_cast_functions(
1368+
&mut self,
1369+
module: &crate::Module,
1370+
func_ctx: &FunctionCtx,
1371+
) -> BackendResult {
1372+
for (_, expression) in func_ctx.expressions.iter() {
1373+
if let crate::Expression::As {
1374+
expr,
1375+
kind,
1376+
convert: Some(width),
1377+
} = *expression
1378+
{
1379+
// Avoid undefined behaviour when casting from a float to integer
1380+
// when the value is out of range for the target type. Additionally
1381+
// ensure we clamp to the correct value as per the WGSL spec.
1382+
//
1383+
// https://www.w3.org/TR/WGSL/#floating-point-conversion:
1384+
// * If X is exactly representable in the target type T, then the
1385+
// result is that value.
1386+
// * Otherwise, the result is the value in T closest to
1387+
// truncate(X) and also exactly representable in the original
1388+
// floating point type.
1389+
let src_ty = func_ctx.resolve_type(expr, &module.types);
1390+
let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
1391+
continue;
1392+
};
1393+
let dst_scalar = crate::Scalar { kind, width };
1394+
if src_scalar.kind != ScalarKind::Float
1395+
|| (dst_scalar.kind != ScalarKind::Sint && dst_scalar.kind != ScalarKind::Uint)
1396+
{
1397+
continue;
1398+
}
1399+
1400+
let wrapped = WrappedCast {
1401+
src_scalar,
1402+
vector_size,
1403+
dst_scalar,
1404+
};
1405+
if !self.wrapped.insert(WrappedType::Cast(wrapped)) {
1406+
continue;
1407+
}
1408+
1409+
let (src_ty, dst_ty) = match vector_size {
1410+
None => (
1411+
crate::TypeInner::Scalar(src_scalar),
1412+
crate::TypeInner::Scalar(dst_scalar),
1413+
),
1414+
Some(vector_size) => (
1415+
crate::TypeInner::Vector {
1416+
scalar: src_scalar,
1417+
size: vector_size,
1418+
},
1419+
crate::TypeInner::Vector {
1420+
scalar: dst_scalar,
1421+
size: vector_size,
1422+
},
1423+
),
1424+
};
1425+
let (min, max) =
1426+
crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
1427+
let cast_str = format!(
1428+
"{}{}",
1429+
dst_scalar.to_hlsl_str()?,
1430+
vector_size
1431+
.map(crate::common::vector_size_str)
1432+
.unwrap_or(""),
1433+
);
1434+
let fun_name = match dst_scalar {
1435+
crate::Scalar::I32 => F2I32_FUNCTION,
1436+
crate::Scalar::U32 => F2U32_FUNCTION,
1437+
crate::Scalar::I64 => F2I64_FUNCTION,
1438+
crate::Scalar::U64 => F2U64_FUNCTION,
1439+
_ => unreachable!(),
1440+
};
1441+
self.write_value_type(module, &dst_ty)?;
1442+
write!(self.out, " {fun_name}(")?;
1443+
self.write_value_type(module, &src_ty)?;
1444+
writeln!(self.out, " value) {{")?;
1445+
let level = crate::back::Level(1);
1446+
write!(self.out, "{level}return {cast_str}(clamp(value, ")?;
1447+
self.write_literal(min)?;
1448+
write!(self.out, ", ")?;
1449+
self.write_literal(max)?;
1450+
writeln!(self.out, "));",)?;
1451+
writeln!(self.out, "}}")?;
1452+
writeln!(self.out)?;
1453+
}
1454+
}
1455+
Ok(())
1456+
}
1457+
13581458
/// Helper function that writes various wrapped functions
13591459
pub(super) fn write_wrapped_functions(
13601460
&mut self,
@@ -1366,6 +1466,7 @@ impl<W: Write> super::Writer<'_, W> {
13661466
self.write_wrapped_binary_ops(module, func_ctx)?;
13671467
self.write_wrapped_expression_functions(module, func_ctx.expressions, Some(func_ctx))?;
13681468
self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;
1469+
self.write_wrapped_cast_functions(module, func_ctx)?;
13691470

13701471
for (handle, _) in func_ctx.expressions.iter() {
13711472
match func_ctx.expressions[handle] {

naga/src/back/hlsl/keywords.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,10 @@ pub const RESERVED: &[&str] = &[
830830
super::writer::DIV_FUNCTION,
831831
super::writer::MOD_FUNCTION,
832832
super::writer::NEG_FUNCTION,
833+
super::writer::F2I32_FUNCTION,
834+
super::writer::F2U32_FUNCTION,
835+
super::writer::F2I64_FUNCTION,
836+
super::writer::F2U64_FUNCTION,
833837
];
834838

835839
// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254

naga/src/back/hlsl/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ enum WrappedType {
462462
Math(help::WrappedMath),
463463
UnaryOp(help::WrappedUnaryOp),
464464
BinaryOp(help::WrappedBinaryOp),
465+
Cast(help::WrappedCast),
465466
}
466467

467468
#[derive(Default)]

naga/src/back/hlsl/writer.rs

Lines changed: 92 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ pub(crate) const ABS_FUNCTION: &str = "naga_abs";
3838
pub(crate) const DIV_FUNCTION: &str = "naga_div";
3939
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
4040
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
41+
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
42+
pub(crate) const F2U32_FUNCTION: &str = "naga_f2u32";
43+
pub(crate) const F2I64_FUNCTION: &str = "naga_f2i64";
44+
pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
4145

4246
struct EpStructMember {
4347
name: String,
@@ -2612,6 +2616,28 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
26122616
})
26132617
}
26142618

2619+
pub(super) fn write_literal(&mut self, literal: crate::Literal) -> BackendResult {
2620+
match literal {
2621+
crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2622+
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2623+
crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2624+
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2625+
// HLSL has no suffix for explicit i32 literals, but not using any suffix
2626+
// makes the type ambiguous which prevents overload resolution from
2627+
// working. So we explicitly use the int() constructor syntax.
2628+
crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2629+
crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2630+
crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2631+
crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2632+
crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2633+
return Err(Error::Custom(
2634+
"Abstract types should not appear in IR presented to backends".into(),
2635+
));
2636+
}
2637+
}
2638+
Ok(())
2639+
}
2640+
26152641
fn write_possibly_const_expression<E>(
26162642
&mut self,
26172643
module: &Module,
@@ -2625,26 +2651,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
26252651
use crate::Expression;
26262652

26272653
match expressions[expr] {
2628-
Expression::Literal(literal) => match literal {
2629-
// Floats are written using `Debug` instead of `Display` because it always appends the
2630-
// decimal part even it's zero
2631-
crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
2632-
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
2633-
crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
2634-
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
2635-
// HLSL has no suffix for explicit i32 literals, but not using any suffix
2636-
// makes the type ambiguous which prevents overload resolution from
2637-
// working. So we explicitly use the int() constructor syntax.
2638-
crate::Literal::I32(value) => write!(self.out, "int({value})")?,
2639-
crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
2640-
crate::Literal::I64(value) => write!(self.out, "{value}L")?,
2641-
crate::Literal::Bool(value) => write!(self.out, "{value}")?,
2642-
crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
2643-
return Err(Error::Custom(
2644-
"Abstract types should not appear in IR presented to backends".into(),
2645-
));
2646-
}
2647-
},
2654+
Expression::Literal(literal) => {
2655+
self.write_literal(literal)?;
2656+
}
26482657
Expression::Constant(handle) => {
26492658
let constant = &module.constants[handle];
26502659
if constant.name.is_some() {
@@ -3320,53 +3329,72 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
33203329
convert,
33213330
} => {
33223331
let inner = func_ctx.resolve_type(expr, &module.types);
3323-
let close_paren = match convert {
3324-
Some(dst_width) => {
3325-
let scalar = Scalar {
3326-
kind,
3327-
width: dst_width,
3328-
};
3329-
match *inner {
3330-
TypeInner::Vector { size, .. } => {
3331-
write!(
3332-
self.out,
3333-
"{}{}(",
3334-
scalar.to_hlsl_str()?,
3335-
common::vector_size_str(size)
3336-
)?;
3337-
}
3338-
TypeInner::Scalar(_) => {
3339-
write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3340-
}
3341-
TypeInner::Matrix { columns, rows, .. } => {
3342-
write!(
3343-
self.out,
3344-
"{}{}x{}(",
3345-
scalar.to_hlsl_str()?,
3346-
common::vector_size_str(columns),
3347-
common::vector_size_str(rows)
3348-
)?;
3349-
}
3350-
_ => {
3351-
return Err(Error::Unimplemented(format!(
3352-
"write_expr expression::as {inner:?}"
3353-
)));
3354-
}
3355-
};
3356-
true
3357-
}
3358-
None => {
3359-
if inner.scalar_width() == Some(8) {
3360-
false
3361-
} else {
3362-
write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3332+
if inner.scalar_kind() == Some(ScalarKind::Float)
3333+
&& (kind == ScalarKind::Sint || kind == ScalarKind::Uint)
3334+
&& convert.is_some()
3335+
{
3336+
// Use helper functions for float to int casts in order to
3337+
// avoid undefined behaviour when value is out of range for
3338+
// the target type.
3339+
let fun_name = match (kind, convert) {
3340+
(ScalarKind::Sint, Some(4)) => F2I32_FUNCTION,
3341+
(ScalarKind::Uint, Some(4)) => F2U32_FUNCTION,
3342+
(ScalarKind::Sint, Some(8)) => F2I64_FUNCTION,
3343+
(ScalarKind::Uint, Some(8)) => F2U64_FUNCTION,
3344+
_ => unreachable!(),
3345+
};
3346+
write!(self.out, "{fun_name}(")?;
3347+
self.write_expr(module, expr, func_ctx)?;
3348+
write!(self.out, ")")?;
3349+
} else {
3350+
let close_paren = match convert {
3351+
Some(dst_width) => {
3352+
let scalar = Scalar {
3353+
kind,
3354+
width: dst_width,
3355+
};
3356+
match *inner {
3357+
TypeInner::Vector { size, .. } => {
3358+
write!(
3359+
self.out,
3360+
"{}{}(",
3361+
scalar.to_hlsl_str()?,
3362+
common::vector_size_str(size)
3363+
)?;
3364+
}
3365+
TypeInner::Scalar(_) => {
3366+
write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
3367+
}
3368+
TypeInner::Matrix { columns, rows, .. } => {
3369+
write!(
3370+
self.out,
3371+
"{}{}x{}(",
3372+
scalar.to_hlsl_str()?,
3373+
common::vector_size_str(columns),
3374+
common::vector_size_str(rows)
3375+
)?;
3376+
}
3377+
_ => {
3378+
return Err(Error::Unimplemented(format!(
3379+
"write_expr expression::as {inner:?}"
3380+
)));
3381+
}
3382+
};
33633383
true
33643384
}
3385+
None => {
3386+
if inner.scalar_width() == Some(8) {
3387+
false
3388+
} else {
3389+
write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
3390+
true
3391+
}
3392+
}
3393+
};
3394+
self.write_expr(module, expr, func_ctx)?;
3395+
if close_paren {
3396+
write!(self.out, ")")?;
33653397
}
3366-
};
3367-
self.write_expr(module, expr, func_ctx)?;
3368-
if close_paren {
3369-
write!(self.out, ")")?;
33703398
}
33713399
}
33723400
Expression::Math {

naga/src/back/msl/keywords.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,10 @@ pub const RESERVED: &[&str] = &[
349349
super::writer::DIV_FUNCTION,
350350
super::writer::MOD_FUNCTION,
351351
super::writer::NEG_FUNCTION,
352+
super::writer::F2I32_FUNCTION,
353+
super::writer::F2U32_FUNCTION,
354+
super::writer::F2I64_FUNCTION,
355+
super::writer::F2U64_FUNCTION,
352356
super::writer::ARGUMENT_BUFFER_WRAPPER_STRUCT,
353357
];
354358

0 commit comments

Comments
 (0)