@@ -30,6 +30,14 @@ namespace oneapi {
3030
3131namespace detail {
3232
33+ template <class FunctorTy >
34+ event withAuxHandler (std::shared_ptr<detail::queue_impl> Queue, bool IsHost,
35+ FunctorTy Func) {
36+ handler AuxHandler (Queue, IsHost);
37+ Func (AuxHandler);
38+ return AuxHandler.finalize ();
39+ }
40+
3341using cl::sycl::detail::bool_constant;
3442using cl::sycl::detail::enable_if_t ;
3543using cl::sycl::detail::queue_impl;
@@ -2434,6 +2442,7 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
24342442
24352443 bool Pow2WG = (WGSize & (WGSize - 1 )) == 0 ;
24362444 bool IsOneWG = NWorkGroups == 1 ;
2445+ bool HasUniformWG = Pow2WG && (NWorkGroups * WGSize == NWorkItems);
24372446
24382447 // Like reduCGFuncImpl, we also have to split out scalar and array reductions
24392448 IsScalarReduction ScalarPredicate;
@@ -2442,28 +2451,27 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
24422451 IsArrayReduction ArrayPredicate;
24432452 auto ArrayIs = filterSequence<Reductions...>(ArrayPredicate, ReduIndices);
24442453
2454+ size_t LocalAccSize = WGSize + (HasUniformWG ? 0 : 1 );
2455+ auto LocalAccsTuple =
2456+ createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
2457+ auto InAccsTuple =
2458+ getReadAccsToPreviousPartialReds (CGH, ReduTuple, ReduIndices);
2459+
2460+ auto IdentitiesTuple = getReduIdentities (ReduTuple, ReduIndices);
2461+ auto BOPsTuple = getReduBOPs (ReduTuple, ReduIndices);
2462+ auto InitToIdentityProps =
2463+ getInitToIdentityProperties (ReduTuple, ReduIndices);
2464+
24452465 // Predicate/OutAccsTuple below have different type depending on us having
24462466 // just a single WG or multiple WGs. Use this lambda to avoid code
24472467 // duplication.
24482468 auto Rest = [&](auto Predicate, auto OutAccsTuple) {
24492469 auto AccReduIndices = filterSequence<Reductions...>(Predicate, ReduIndices);
24502470 associateReduAccsWithHandler (CGH, ReduTuple, AccReduIndices);
2451-
2452- size_t LocalAccSize = WGSize + (Pow2WG ? 0 : 1 );
2453- auto LocalAccsTuple =
2454- createReduLocalAccs<Reductions...>(LocalAccSize, CGH, ReduIndices);
2455- auto InAccsTuple =
2456- getReadAccsToPreviousPartialReds (CGH, ReduTuple, ReduIndices);
2457-
2458- auto IdentitiesTuple = getReduIdentities (ReduTuple, ReduIndices);
2459- auto BOPsTuple = getReduBOPs (ReduTuple, ReduIndices);
2460- auto InitToIdentityProps =
2461- getInitToIdentityProperties (ReduTuple, ReduIndices);
2462-
24632471 using Name = __sycl_reduction_kernel<reduction::aux_krn::Multi, KernelName,
24642472 decltype (OutAccsTuple)>;
24652473 // TODO: Opportunity to parallelize across number of elements
2466- range<1 > GlobalRange = {Pow2WG ? NWorkItems : NWorkGroups * WGSize};
2474+ range<1 > GlobalRange = {HasUniformWG ? NWorkItems : NWorkGroups * WGSize};
24672475 nd_range<1 > Range{GlobalRange, range<1 >(WGSize)};
24682476 CGH.parallel_for <Name>(Range, [=](nd_item<1 > NDIt) {
24692477 size_t WGSize = NDIt.get_local_range ().size ();
@@ -2472,12 +2480,12 @@ size_t reduAuxCGFunc(handler &CGH, size_t NWorkItems, size_t MaxWGSize,
24722480
24732481 // Handle scalar and array reductions
24742482 reduAuxCGFuncImplScalar<Reductions...>(
2475- Pow2WG , IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, LocalAccsTuple ,
2476- InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
2483+ HasUniformWG , IsOneWG, NDIt, LID, GID, NWorkItems, WGSize,
2484+ LocalAccsTuple, InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
24772485 InitToIdentityProps, ScalarIs);
24782486 reduAuxCGFuncImplArray<Reductions...>(
2479- Pow2WG , IsOneWG, NDIt, LID, GID, NWorkItems, WGSize, LocalAccsTuple ,
2480- InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
2487+ HasUniformWG , IsOneWG, NDIt, LID, GID, NWorkItems, WGSize,
2488+ LocalAccsTuple, InAccsTuple, OutAccsTuple, IdentitiesTuple, BOPsTuple,
24812489 InitToIdentityProps, ArrayIs);
24822490 });
24832491 };
@@ -2504,7 +2512,7 @@ void reduSaveFinalResultToUserMemHelper(
25042512 if constexpr (!Reduction::is_usm) {
25052513 if (Redu.hasUserDiscardWriteAccessor ()) {
25062514 event CopyEvent =
2507- handler:: withAuxHandler (Queue, IsHost, [&](handler &CopyHandler) {
2515+ withAuxHandler (Queue, IsHost, [&](handler &CopyHandler) {
25082516 auto InAcc = Redu.getReadAccToPreviousPartialReds (CopyHandler);
25092517 auto OutAcc = Redu.getUserDiscardWriteAccessor ();
25102518 Redu.associateWithHandler (CopyHandler);
0 commit comments