Skip to content

Commit

Permalink
added error-compensated atomic scatter-reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Jan 18, 2023
1 parent 794ecba commit c7225f6
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ext/drjit-core
15 changes: 15 additions & 0 deletions include/drjit/array_router.h
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,21 @@ void scatter_reduce(ReduceOp op, Target &&target, const Value &value,
}
}

template <typename Target, typename Value, typename Index>
void scatter_reduce_kahan(Target &&target_1, Target &&target_2,
const Value &value, const Index &index,
const mask_t<Value> &mask = true) {
static_assert(
is_jit_v<Target> &&
is_jit_v<Value> &&
is_jit_v<Index> &&
array_depth_v<Value> == array_depth_v<Index> &&
array_depth_v<Value> == 1,
"Only flat JIT arrays are supported at the moment");

value.scatter_reduce_kahan_(target_1, target_2, index, mask);
}

template <typename T, typename TargetType>
decltype(auto) migrate(const T &value, TargetType target) {
static_assert(std::is_enum_v<TargetType>);
Expand Down
21 changes: 21 additions & 0 deletions include/drjit/autodiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,27 @@ struct DiffArray : ArrayBase<value_t<Type_>, is_mask_v<Type_>, DiffArray<Type_>>
}
}

void scatter_reduce_kahan_(DiffArray &dst_1, DiffArray &dst_2, const IndexType &offset,
const MaskType &mask = true) const {
if constexpr (std::is_scalar_v<Type>) {
(void) dst_1; (void) dst_2; (void) offset; (void) mask;
drjit_raise("Array scatter_reduce operation not supported for scalar array type.");
} else {
scatter_reduce_kahan(dst_1.m_value, dst_2.m_value, m_value,
offset.m_value, mask.m_value);
if constexpr (IsEnabled) {
if (m_index) { // safe to ignore dst_1.m_index in the case of scatter_reduce
uint32_t index = detail::ad_new_scatter<Type>(
"scatter_reduce_kahan", width(dst_1), ReduceOp::Add,
m_index, dst_1.m_index, offset.m_value, mask.m_value,
false);
detail::ad_dec_ref<Type>(dst_1.m_index);
dst_1.m_index = index;
}
}
}
}

template <bool>
static DiffArray gather_(const void *src, const IndexType &offset,
const MaskType &mask = true) {
Expand Down
9 changes: 9 additions & 0 deletions include/drjit/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,15 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
mask.index(), op));
}

template <typename Index, typename Mask>
void scatter_reduce_kahan_(Derived &dst_1, Derived &dst_2,
const Index &index, const Mask &mask) const {
static_assert(
std::is_same_v<detached_t<Mask>, detached_t<mask_t<Derived>>>);
jit_var_scatter_reduce_kahan(dst_1.index_ptr(), dst_2.index_ptr(),
m_index, index.index(), mask.index());
}

//! @}
// -----------------------------------------------------------------------

Expand Down

0 comments on commit c7225f6

Please sign in to comment.