Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
remove class template param from Iterate
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 committed Apr 18, 2022
1 parent 99ec7c0 commit 34e465b
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions cub/block/block_discontinuity.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ private:
};

/// Templated unrolling of item comparison (inductive case)
template <int ITERATION, int MAX_ITERATIONS>
struct Iterate
{
// Head flags
Expand All @@ -183,7 +182,7 @@ private:
FlagOp flag_op) ///< [in] Binary boolean flag predicate
{
#pragma unroll
for (int i=ITERATION; i < MAX_ITERATIONS; ++i) {
for (int i = 1; i < ITEMS_PER_THREAD; ++i) {
preds[i] = input[i - 1];
flags[i] = ApplyOp<FlagOp>::FlagT(
flag_op,
Expand All @@ -205,7 +204,7 @@ private:
FlagOp flag_op) ///< [in] Binary boolean flag predicate
{
#pragma unroll
for (int i=ITERATION; i < MAX_ITERATIONS; ++i) {
for (int i = 0; i < ITEMS_PER_THREAD - 1; ++i) {
flags[i] = ApplyOp<FlagOp>::FlagT(
flag_op,
input[i],
Expand Down Expand Up @@ -295,7 +294,7 @@ public:
}

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
}

template <
Expand All @@ -322,7 +321,7 @@ public:
head_flags[0] = ApplyOp<FlagOp>::FlagT(flag_op, preds[0], input[0], linear_tid * ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
}

#endif // DOXYGEN_SHOULD_SKIP_THIS
Expand Down Expand Up @@ -543,7 +542,7 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -630,7 +629,7 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -745,10 +744,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -863,10 +862,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -981,10 +980,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down Expand Up @@ -1103,10 +1102,10 @@ public:
(linear_tid * ITEMS_PER_THREAD) + ITEMS_PER_THREAD);

// Set head_flags for remaining items
Iterate<1, ITEMS_PER_THREAD>::FlagHeads(linear_tid, head_flags, input, preds, flag_op);
Iterate::FlagHeads(linear_tid, head_flags, input, preds, flag_op);

// Set tail_flags for remaining items
Iterate<0, ITEMS_PER_THREAD - 1>::FlagTails(linear_tid, tail_flags, input, flag_op);
Iterate::FlagTails(linear_tid, tail_flags, input, flag_op);
}


Expand Down

0 comments on commit 34e465b

Please sign in to comment.