diff --git a/ext/drjit-core b/ext/drjit-core index 3018bb5c7..15296bfda 160000 --- a/ext/drjit-core +++ b/ext/drjit-core @@ -1 +1 @@ -Subproject commit 3018bb5c7ff007e8637da17d824fa557ccb4abda +Subproject commit 15296bfdaf2348b6004d57a745a0bc1506e4475f diff --git a/include/drjit/array_router.h b/include/drjit/array_router.h index 7cba18d79..c15f74921 100644 --- a/include/drjit/array_router.h +++ b/include/drjit/array_router.h @@ -1164,6 +1164,21 @@ void scatter_reduce(ReduceOp op, Target &&target, const Value &value, } } +template +void scatter_reduce_kahan(Target &&target_1, Target &&target_2, + const Value &value, const Index &index, + const mask_t &mask = true) { + static_assert( + is_jit_v && + is_jit_v && + is_jit_v && + array_depth_v == array_depth_v && + array_depth_v == 1, + "Only flat JIT arrays are supported at the moment"); + + value.scatter_reduce_kahan_(target_1, target_2, index, mask); +} + template decltype(auto) migrate(const T &value, TargetType target) { static_assert(std::is_enum_v); diff --git a/include/drjit/autodiff.h b/include/drjit/autodiff.h index 203af9097..858bc0764 100644 --- a/include/drjit/autodiff.h +++ b/include/drjit/autodiff.h @@ -1409,6 +1409,27 @@ struct DiffArray : ArrayBase, is_mask_v, DiffArray> } } + void scatter_reduce_kahan_(DiffArray &dst_1, DiffArray &dst_2, const IndexType &offset, + const MaskType &mask = true) const { + if constexpr (std::is_scalar_v) { + (void) dst_1; (void) dst_2; (void) offset; (void) mask; + drjit_raise("Array scatter_reduce operation not supported for scalar array type."); + } else { + scatter_reduce_kahan(dst_1.m_value, dst_2.m_value, m_value, + offset.m_value, mask.m_value); + if constexpr (IsEnabled) { + if (m_index) { // safe to ignore dst_1.m_index in the case of scatter_reduce + uint32_t index = detail::ad_new_scatter( + "scatter_reduce_kahan", width(dst_1), ReduceOp::Add, + m_index, dst_1.m_index, offset.m_value, mask.m_value, + false); + detail::ad_dec_ref(dst_1.m_index); + dst_1.m_index = index; + } + } + } + } + template static DiffArray gather_(const void *src, const IndexType &offset, const MaskType &mask = true) { diff --git a/include/drjit/jit.h b/include/drjit/jit.h index 3212536ac..98dad9a37 100644 --- a/include/drjit/jit.h +++ b/include/drjit/jit.h @@ -501,6 +501,15 @@ struct JitArray : ArrayBase, Derived_> { mask.index(), op)); } + template + void scatter_reduce_kahan_(Derived &dst_1, Derived &dst_2, + const Index &index, const Mask &mask) const { + static_assert( + std::is_same_v, detached_t>>); + jit_var_scatter_reduce_kahan(dst_1.index_ptr(), dst_2.index_ptr(), + m_index, index.index(), mask.index()); + } + //! @} // -----------------------------------------------------------------------