Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #450 from kshitij12345/iterate/pragma-unroll
Browse files Browse the repository at this point in the history
update Iterate (BlockDiscontinuity) to use pragma unroll
  • Loading branch information
alliepiper authored Apr 25, 2022
2 parents 00869c8 + 34e465b commit d2c014c
Showing 1 changed file with 29 additions and 60 deletions.
89 changes: 29 additions & 60 deletions cub/block/block_discontinuity.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ private:
};

/// Templated unrolling of item comparison (inductive case)
template <int ITERATION, int MAX_ITERATIONS>
struct Iterate
{
// Head flags
Expand All @@ -182,15 +181,15 @@ private:
T (&preds)[ITEMS_PER_THREAD], ///< [out] Calling thread's predecessor items
FlagOp flag_op) ///< [in] Binary boolean flag predicate
{
preds[ITERATION] = input[ITERATION - 1];

flags[ITERATION] = ApplyOp<FlagOp>::FlagT(
flag_op,
preds[ITERATION],
input[ITERATION],
(linear_tid * ITEMS_PER_THREAD) + ITERATION);

Iterate<ITERATION + 1, MAX_ITERATIONS>::FlagHeads(linear_tid, flags, input, preds, flag_op);
#pragma unroll
for (int i = 1; i < ITEMS_PER_THREAD; ++i) {
preds[i] = input[i - 1];
flags[i] = ApplyOp<FlagOp>::FlagT(
flag_op,
preds[i],
input[i],
(linear_tid * ITEMS_PER_THREAD) + i);
}
}

// Tail flags
Expand All @@ -204,48 +203,18 @@ private:
T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
FlagOp flag_op) ///< [in] Binary boolean flag predicate
{
flags[ITERATION] = ApplyOp<FlagOp>::FlagT(
flag_op,
input[ITERATION],
input[ITERATION + 1],
(linear_tid * ITEMS_PER_THREAD) + ITERATION + 1);

Iterate<ITERATION + 1, MAX_ITERATIONS>::FlagTails(linear_tid, flags, input, flag_op);
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD - 1; ++i) {
flags[i] = ApplyOp<FlagOp>::FlagT(
flag_op,
input[i],
input[i + 1],
(linear_tid * ITEMS_PER_THREAD) + i + 1);
}
}

};

/// Templated unrolling of item comparison (termination case)
template <int MAX_ITERATIONS>
struct Iterate<MAX_ITERATIONS, MAX_ITERATIONS>
{
// Head flags
template <
int ITEMS_PER_THREAD,
typename FlagT,
typename FlagOp>
static __device__ __forceinline__ void FlagHeads(
int /*linear_tid*/,
FlagT (&/*flags*/)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags
T (&/*input*/)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
T (&/*preds*/)[ITEMS_PER_THREAD], ///< [out] Calling thread's predecessor items
FlagOp /*flag_op*/) ///< [in] Binary boolean flag predicate
{}

// Tail flags
template <
int ITEMS_PER_THREAD,
typename FlagT,
typename FlagOp>
static __device__ __forceinline__ void FlagTails(
int /*linear_tid*/,
FlagT (&/*flags*/)[ITEMS_PER_THREAD], ///< [out] Calling thread's discontinuity head_flags
T (&/*input*/)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
FlagOp /*flag_op*/) ///< [in] Binary boolean flag predicate
{}
};


/******************************************************************************
* Thread fields
******************************************************************************/
Expand Down Expand Up @@ -325,7 +294,7 @@ public:
}

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
}

template <
Expand All @@ -352,7 +321,7 @@ public:
head_flags[0] = ApplyOp<FlagOp>::FlagT(flag_op, preds[0], input[0], linear_tid * ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
}

#endif // DOXYGEN_SHOULD_SKIP_THIS
Expand Down Expand Up @@ -573,7 +542,7 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -660,7 +629,7 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -775,10 +744,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -893,10 +862,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -1011,10 +980,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -1133,10 +1102,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down

0 comments on commit d2c014c

Please sign in to comment.