Skip to content

Commit 3a4a65b

Browse files
committed
Pass policy into param functions by value
This makes these calls use normal overload resolution instead of routing everything through the RAJA::detail namespace. Theoretically this should allow a user to make their own param type.
1 parent 30a8e7b commit 3a4a65b

25 files changed

+177
-135
lines changed

include/RAJA/pattern/params/forall.hpp

+23-20
Original file line numberDiff line numberDiff line change
@@ -43,45 +43,45 @@ struct ForallParamPack
4343
private:
4444
// Init
4545
template<typename EXEC_POL, camp::idx_t... Seq, typename... Args>
46-
static constexpr void params_init(EXEC_POL,
46+
static constexpr void params_init(EXEC_POL const& pol,
4747
camp::idx_seq<Seq...>,
4848
ForallParamPack& f_params,
4949
Args&&... args)
5050
{
51-
CAMP_EXPAND(param_init<EXEC_POL>(camp::get<Seq>(f_params.param_tup),
52-
std::forward<Args>(args)...));
51+
CAMP_EXPAND(param_init(pol, camp::get<Seq>(f_params.param_tup),
52+
std::forward<Args>(args)...));
5353
}
5454

5555
// Combine
5656
template<typename EXEC_POL, camp::idx_t... Seq>
5757
RAJA_HOST_DEVICE static constexpr void params_combine(
58-
EXEC_POL,
58+
EXEC_POL const& pol,
5959
camp::idx_seq<Seq...>,
6060
ForallParamPack& out,
6161
const ForallParamPack& in)
6262
{
63-
CAMP_EXPAND(param_combine<EXEC_POL>(camp::get<Seq>(out.param_tup),
64-
camp::get<Seq>(in.param_tup)));
63+
CAMP_EXPAND(param_combine(pol, camp::get<Seq>(out.param_tup),
64+
camp::get<Seq>(in.param_tup)));
6565
}
6666

6767
template<typename EXEC_POL, camp::idx_t... Seq>
6868
RAJA_HOST_DEVICE static constexpr void params_combine(
69-
EXEC_POL,
69+
EXEC_POL const& pol,
7070
camp::idx_seq<Seq...>,
7171
ForallParamPack& f_params)
7272
{
73-
CAMP_EXPAND(param_combine<EXEC_POL>(camp::get<Seq>(f_params.param_tup)));
73+
CAMP_EXPAND(param_combine(pol, camp::get<Seq>(f_params.param_tup)));
7474
}
7575

7676
// Resolve
7777
template<typename EXEC_POL, camp::idx_t... Seq, typename... Args>
78-
static constexpr void params_resolve(EXEC_POL,
78+
static constexpr void params_resolve(EXEC_POL const& pol,
7979
camp::idx_seq<Seq...>,
8080
ForallParamPack& f_params,
8181
Args&&... args)
8282
{
83-
CAMP_EXPAND(param_resolve<EXEC_POL>(camp::get<Seq>(f_params.param_tup),
84-
std::forward<Args>(args)...));
83+
CAMP_EXPAND(param_resolve(pol, camp::get<Seq>(f_params.param_tup),
84+
std::forward<Args>(args)...));
8585
}
8686

8787
// Used to construct the argument TYPES that will be invoked with the lambda.
@@ -155,32 +155,35 @@ struct ParamMultiplexer
155155
typename... Params,
156156
typename... Args,
157157
typename FP = ForallParamPack<Params...>>
158-
static void constexpr params_init(ForallParamPack<Params...>& f_params,
159-
Args&&... args)
158+
static void constexpr params_init(EXEC_POL const& pol,
159+
ForallParamPack<Params...>& f_params,
160+
Args&&... args)
160161
{
161-
FP::params_init(EXEC_POL(), typename FP::params_seq(), f_params,
162+
FP::params_init(pol, typename FP::params_seq(), f_params,
162163
std::forward<Args>(args)...);
163164
}
164165

165166
template<typename EXEC_POL,
166167
typename... Params,
167168
typename... Args,
168169
typename FP = ForallParamPack<Params...>>
169-
static void constexpr params_combine(ForallParamPack<Params...>& f_params,
170-
Args&&... args)
170+
static void constexpr params_combine(EXEC_POL const& pol,
171+
ForallParamPack<Params...>& f_params,
172+
Args&&... args)
171173
{
172-
FP::params_combine(EXEC_POL(), typename FP::params_seq(), f_params,
174+
FP::params_combine(pol, typename FP::params_seq(), f_params,
173175
std::forward<Args>(args)...);
174176
}
175177

176178
template<typename EXEC_POL,
177179
typename... Params,
178180
typename... Args,
179181
typename FP = ForallParamPack<Params...>>
180-
static void constexpr params_resolve(ForallParamPack<Params...>& f_params,
181-
Args&&... args)
182+
static void constexpr params_resolve(EXEC_POL const& pol,
183+
ForallParamPack<Params...>& f_params,
184+
Args&&... args)
182185
{
183-
FP::params_resolve(EXEC_POL(), typename FP::params_seq(), f_params,
186+
FP::params_resolve(pol, typename FP::params_seq(), f_params,
184187
std::forward<Args>(args)...);
185188
}
186189
};

include/RAJA/policy/cuda/forall.hpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ __launch_bounds__(BlockSize, BlocksPerSM) __global__
445445
{
446446
RAJA::expt::invoke_body(f_params, body, idx[ii]);
447447
}
448-
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
448+
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
449449
}
450450

451451
///
@@ -474,7 +474,7 @@ __global__ void forallp_cuda_kernel(LOOP_BODY loop_body,
474474
{
475475
RAJA::expt::invoke_body(f_params, body, idx[ii]);
476476
}
477-
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
477+
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
478478
}
479479

480480
template<
@@ -565,7 +565,7 @@ __launch_bounds__(BlockSize, BlocksPerSM) __global__
565565
{
566566
RAJA::expt::invoke_body(f_params, body, idx[ii]);
567567
}
568-
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
568+
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
569569
}
570570

571571
///
@@ -597,7 +597,7 @@ __global__ void forallp_cuda_kernel(LOOP_BODY loop_body,
597597
{
598598
RAJA::expt::invoke_body(f_params, body, idx[ii]);
599599
}
600-
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
600+
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
601601
}
602602

603603
} // namespace impl
@@ -712,7 +712,7 @@ forall_impl(resources::Cuda cuda_res,
712712
IterationGetter,
713713
Concretizer,
714714
BlocksPerSM,
715-
Async> const&,
715+
Async> const& pol,
716716
Iterable&& iter,
717717
LoopBody&& loop_body,
718718
ForallParam f_params)
@@ -764,7 +764,7 @@ forall_impl(resources::Cuda cuda_res,
764764
launch_info.res = cuda_res;
765765

766766
{
767-
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(f_params, launch_info);
767+
RAJA::expt::ParamMultiplexer::params_init(pol, f_params, launch_info);
768768

769769
//
770770
// Privatize the loop_body, using make_launch_body to setup reductions
@@ -781,7 +781,7 @@ forall_impl(resources::Cuda cuda_res,
781781
RAJA::cuda::launch(func, dims.blocks, dims.threads, args, shmem, cuda_res,
782782
Async);
783783

784-
RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(f_params, launch_info);
784+
RAJA::expt::ParamMultiplexer::params_resolve(pol, f_params, launch_info);
785785
}
786786

787787
RAJA_FT_END;

include/RAJA/policy/cuda/launch.hpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ __global__ void launch_new_reduce_global_fcn(BODY body_in,
6161
RAJA::expt::invoke_body(reduce_params, body, ctx);
6262

6363
// Using a flatten global policy as we may use all dimensions
64-
RAJA::expt::ParamMultiplexer::params_combine<RAJA::cuda_flatten_global_xyz_direct>(
64+
RAJA::expt::ParamMultiplexer::params_combine(RAJA::cuda_flatten_global_xyz_direct{},
6565
reduce_params);
6666
}
6767

@@ -186,7 +186,7 @@ struct LaunchExecute<
186186
{
187187
using EXEC_POL = RAJA::policy::cuda::cuda_launch_explicit_t<
188188
async, named_usage::unspecified, named_usage::unspecified>;
189-
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(launch_reducers,
189+
RAJA::expt::ParamMultiplexer::params_init(EXEC_POL{}, launch_reducers,
190190
launch_info);
191191

192192

@@ -204,7 +204,7 @@ struct LaunchExecute<
204204
RAJA::cuda::launch(func, gridSize, blockSize, args, shared_mem_size,
205205
cuda_res, async, kernel_name);
206206

207-
RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(launch_reducers,
207+
RAJA::expt::ParamMultiplexer::params_resolve(EXEC_POL{}, launch_reducers,
208208
launch_info);
209209
}
210210

@@ -253,7 +253,7 @@ __launch_bounds__(num_threads, BLOCKS_PER_SM) __global__
253253
RAJA::expt::invoke_body(reduce_params, body, ctx);
254254

255255
// Using a flatten global policy as we may use all dimensions
256-
RAJA::expt::ParamMultiplexer::params_combine<RAJA::cuda_flatten_global_xyz_direct>(
256+
RAJA::expt::ParamMultiplexer::params_combine(RAJA::cuda_flatten_global_xyz_direct{},
257257
reduce_params);
258258
}
259259

@@ -378,11 +378,11 @@ struct LaunchExecute<
378378
launch_info.dynamic_smem = &shared_mem_size;
379379
launch_info.res = cuda_res;
380380
{
381-
381+
// Use a generic block size policy here to match that used in params_combine
382382
using EXEC_POL =
383-
RAJA::policy::cuda::cuda_launch_explicit_t<async, nthreads,
384-
BLOCKS_PER_SM>;
385-
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(launch_reducers,
383+
RAJA::policy::cuda::cuda_launch_explicit_t<
384+
async, named_usage::unspecified, named_usage::unspecified>;
385+
RAJA::expt::ParamMultiplexer::params_init(EXEC_POL{}, launch_reducers,
386386
launch_info);
387387

388388
//
@@ -399,7 +399,7 @@ struct LaunchExecute<
399399
RAJA::cuda::launch(func, gridSize, blockSize, args, shared_mem_size,
400400
cuda_res, async, kernel_name);
401401

402-
RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(launch_reducers,
402+
RAJA::expt::ParamMultiplexer::params_resolve(EXEC_POL{}, launch_reducers,
403403
launch_info);
404404
}
405405

include/RAJA/policy/cuda/params/reduce.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace detail
2020
// Init
2121
template<typename EXEC_POL, typename OP, typename T, typename VOp>
2222
camp::concepts::enable_if<type_traits::is_cuda_policy<EXEC_POL>> param_init(
23+
EXEC_POL const&,
2324
Reducer<OP, T, VOp>& red,
2425
RAJA::cuda::detail::cudaInfo& ci)
2526
{
@@ -34,7 +35,8 @@ camp::concepts::enable_if<type_traits::is_cuda_policy<EXEC_POL>> param_init(
3435
template<typename EXEC_POL, typename OP, typename T, typename VOp>
3536
RAJA_HOST_DEVICE camp::concepts::enable_if<
3637
type_traits::is_cuda_policy<EXEC_POL>>
37-
param_combine(Reducer<OP, T, VOp>& red)
38+
param_combine(EXEC_POL const&,
39+
Reducer<OP, T, VOp>& red)
3840
{
3941
RAJA::cuda::impl::expt::grid_reduce<typename EXEC_POL::IterationGetter, OP>(
4042
red.devicetarget, red.getVal(), red.device_mem, red.device_count);
@@ -43,6 +45,7 @@ param_combine(Reducer<OP, T, VOp>& red)
4345
// Resolve
4446
template<typename EXEC_POL, typename OP, typename T, typename VOp>
4547
camp::concepts::enable_if<type_traits::is_cuda_policy<EXEC_POL>> param_resolve(
48+
EXEC_POL const&,
4649
Reducer<OP, T, VOp>& red,
4750
RAJA::cuda::detail::cudaInfo& ci)
4851
{

include/RAJA/policy/hip/forall.hpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ __launch_bounds__(BlockSize, 1) __global__
443443
{
444444
RAJA::expt::invoke_body(f_params, body, idx[ii]);
445445
}
446-
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
446+
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
447447
}
448448

449449
///
@@ -471,7 +471,7 @@ __global__ void forallp_hip_kernel(LOOP_BODY loop_body,
471471
{
472472
RAJA::expt::invoke_body(f_params, body, idx[ii]);
473473
}
474-
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
474+
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
475475
}
476476

477477
template<
@@ -559,7 +559,7 @@ __launch_bounds__(BlockSize, 1) __global__
559559
{
560560
RAJA::expt::invoke_body(f_params, body, idx[ii]);
561561
}
562-
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
562+
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
563563
}
564564

565565
///
@@ -590,7 +590,7 @@ __global__ void forallp_hip_kernel(LOOP_BODY loop_body,
590590
{
591591
RAJA::expt::invoke_body(f_params, body, idx[ii]);
592592
}
593-
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
593+
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
594594
}
595595

596596
} // namespace impl
@@ -699,7 +699,7 @@ RAJA_INLINE concepts::enable_if_t<
699699
forall_impl(
700700
resources::Hip hip_res,
701701
::RAJA::policy::hip::
702-
hip_exec<IterationMapping, IterationGetter, Concretizer, Async> const&,
702+
hip_exec<IterationMapping, IterationGetter, Concretizer, Async> const& pol,
703703
Iterable&& iter,
704704
LoopBody&& loop_body,
705705
ForallParam f_params)
@@ -751,7 +751,7 @@ forall_impl(
751751
launch_info.res = hip_res;
752752

753753
{
754-
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(f_params, launch_info);
754+
RAJA::expt::ParamMultiplexer::params_init(pol, f_params, launch_info);
755755

756756
//
757757
// Privatize the loop_body, using make_launch_body to setup reductions
@@ -768,7 +768,7 @@ forall_impl(
768768
RAJA::hip::launch(func, dims.blocks, dims.threads, args, shmem, hip_res,
769769
Async);
770770

771-
RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(f_params, launch_info);
771+
RAJA::expt::ParamMultiplexer::params_resolve(pol, f_params, launch_info);
772772
}
773773

774774
RAJA_FT_END;

include/RAJA/policy/hip/launch.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ __global__ void launch_new_reduce_global_fcn(BODY body_in,
6161
RAJA::expt::invoke_body(reduce_params, body, ctx);
6262

6363
// Using a flatten global policy as we may use all dimensions
64-
RAJA::expt::ParamMultiplexer::params_combine<RAJA::hip_flatten_global_xyz_direct>(
64+
RAJA::expt::ParamMultiplexer::params_combine(RAJA::hip_flatten_global_xyz_direct{},
6565
reduce_params);
6666
}
6767

@@ -184,7 +184,7 @@ struct LaunchExecute<
184184
{
185185
using EXEC_POL =
186186
RAJA::policy::hip::hip_launch_t<async, named_usage::unspecified>;
187-
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(launch_reducers,
187+
RAJA::expt::ParamMultiplexer::params_init(EXEC_POL{}, launch_reducers,
188188
launch_info);
189189

190190
//
@@ -201,7 +201,7 @@ struct LaunchExecute<
201201
RAJA::hip::launch(func, gridSize, blockSize, args, shared_mem_size,
202202
hip_res, async, kernel_name);
203203

204-
RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(launch_reducers,
204+
RAJA::expt::ParamMultiplexer::params_resolve(EXEC_POL{}, launch_reducers,
205205
launch_info);
206206
}
207207

@@ -247,7 +247,7 @@ __launch_bounds__(num_threads, 1) __global__
247247
RAJA::expt::invoke_body(reduce_params, body, ctx);
248248

249249
// Using a flatten global policy as we may use all dimensions
250-
RAJA::expt::ParamMultiplexer::params_combine<RAJA::hip_flatten_global_xyz_direct>(
250+
RAJA::expt::ParamMultiplexer::params_combine(RAJA::hip_flatten_global_xyz_direct{},
251251
reduce_params);
252252
}
253253

@@ -370,7 +370,7 @@ struct LaunchExecute<RAJA::policy::hip::hip_launch_t<async, nthreads>>
370370
{
371371
using EXEC_POL =
372372
RAJA::policy::hip::hip_launch_t<async, named_usage::unspecified>;
373-
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(launch_reducers,
373+
RAJA::expt::ParamMultiplexer::params_init(EXEC_POL{}, launch_reducers,
374374
launch_info);
375375

376376
//
@@ -387,7 +387,7 @@ struct LaunchExecute<RAJA::policy::hip::hip_launch_t<async, nthreads>>
387387
RAJA::hip::launch(func, gridSize, blockSize, args, shared_mem_size,
388388
hip_res, async, kernel_name);
389389

390-
RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(launch_reducers,
390+
RAJA::expt::ParamMultiplexer::params_resolve(EXEC_POL{}, launch_reducers,
391391
launch_info);
392392
}
393393

include/RAJA/policy/hip/params/kernel_name.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace detail
2121
// Init
2222
template<typename EXEC_POL>
2323
camp::concepts::enable_if<type_traits::is_hip_policy<EXEC_POL>> param_init(
24+
EXEC_POL const&,
2425
KernelName& kn,
2526
const RAJA::hip::detail::hipInfo&)
2627
{
@@ -34,12 +35,14 @@ camp::concepts::enable_if<type_traits::is_hip_policy<EXEC_POL>> param_init(
3435
// Combine
3536
template<typename EXEC_POL>
3637
RAJA_HOST_DEVICE camp::concepts::enable_if<type_traits::is_hip_policy<EXEC_POL>>
37-
param_combine(KernelName&)
38+
param_combine(EXEC_POL const&,
39+
KernelName&)
3840
{}
3941

4042
// Resolve
4143
template<typename EXEC_POL>
4244
camp::concepts::enable_if<type_traits::is_hip_policy<EXEC_POL>> param_resolve(
45+
EXEC_POL const&,
4346
KernelName&,
4447
const RAJA::hip::detail::hipInfo&)
4548
{

0 commit comments

Comments
 (0)