diff --git a/kharma/reductions/reductions.cpp b/kharma/reductions/reductions.cpp index ebf754b8..f69fa8b8 100644 --- a/kharma/reductions/reductions.cpp +++ b/kharma/reductions/reductions.cpp @@ -111,32 +111,39 @@ std::vector Reductions::CountFlags(MeshData *md, std::string field_na IndexRange kb = md->GetBoundsK(domain); IndexRange block = IndexRange{0, flag.GetDim(5) - 1}; + // Man, moving arrays is clunky. Oh well. const int n_of_flags = flag_values.size(); - int flag_val_list[MAX_NFLAGS]; - int f=0; + ParArray1D flag_val_list("flag_values", MAX_NFLAGS); + auto flag_val_list_h = flag_val_list.GetHostMirror(); + int f=1; for (auto &flag : flag_values) { - flag_val_list[f] = flag.first; + flag_val_list_h[f] = flag.first; f++; } + flag_val_list.DeepCopy(flag_val_list_h); + Kokkos::fence(); // Count all nonzero (technically, >0) values, - // and all values of each + // and all values which match each flag. // This works for pflags or fflags, so long as they're separate // We don't count negative pflags as they denote zones that shouldn't be fixed Reductions::array_type flag_reducer; pmb0->par_reduce("count_flags", block.s, block.e, kb.s, kb.e, jb.s, jb.e, ib.s, ib.e, KOKKOS_LAMBDA (const int &b, const int &k, const int &j, const int &i, Reductions::array_type &local_result) { - if ((int) flag(b, 0, k, j, i) > 0) ++local_result.my_array[0]; - for (int f=0; f(flag(b, 0, k, j, i)) & flag_val_list[f]) || - (!is_bitflag && static_cast(flag(b, 0, k, j, i)) == flag_val_list[f])) - ++local_result.my_array[f+1]; + const int flag_int = static_cast(flag(b, 0, k, j, i)); + // First element is total count + if (flag_int > 0) ++local_result.my_array[0]; + // The rest of the list is individual flags + for (int f=1; f < n_of_flags; f++) + if ((is_bitflag && flag_int & flag_val_list(f)) || + (!is_bitflag && flag_int == flag_val_list(f))) + ++local_result.my_array[f]; } - , Reductions::ArraySum(flag_reducer)); - + , Reductions::ArraySum(flag_reducer)); + std::vector n_each_flag; - for (int f=0; f *md, UserHistoryOperation op, int zone) return result; } -#define INSIDE (x[1] > startx[0] && x[2] > startx[1] && x[3] > startx[2]) && \ - (trivial[0] ? x[1] < startx[0] + G.Dxc<1>(i) : x[1] < stopx[0]) && \ - (trivial[1] ? x[2] < startx[1] + G.Dxc<2>(j) : x[2] < stopx[1]) && \ - (trivial[2] ? x[3] < startx[2] + G.Dxc<3>(k) : x[3] < stopx[2]) +#define INSIDE (x[1] > startx1 && x[2] > startx2 && x[3] > startx3) && \ + (trivial1 ? x[1] < startx1 + G.Dxc<1>(i) : x[1] < stopx1) && \ + (trivial2 ? x[2] < startx2 + G.Dxc<2>(j) : x[2] < stopx2) && \ + (trivial3 ? x[3] < startx3 + G.Dxc<3>(k) : x[3] < stopx3) // TODO additionally template on return type to avoid counting flags with Reals template @@ -226,7 +226,17 @@ T Reductions::DomainReduction(MeshData *md, UserHistoryOperation op, const VLOOP if(startx[v] == stopx[v]) { trivial_tmp[v] = true; } - const bool trivial[3] = {trivial_tmp[0], trivial_tmp[1], trivial_tmp[2]}; + + // Pull values to pass to device, because passing views is cumbersome + const bool trivial1 = trivial_tmp[0]; + const bool trivial2 = trivial_tmp[1]; + const bool trivial3 = trivial_tmp[2]; + const GReal startx1 = startx[0]; + const GReal startx2 = startx[1]; + const GReal startx3 = startx[2]; + const GReal stopx1 = stopx[0]; + const GReal stopx2 = stopx[1]; + const GReal stopx3 = stopx[2]; T result = 0.; MPI_Op mop; @@ -240,7 +250,7 @@ T Reductions::DomainReduction(MeshData *md, UserHistoryOperation op, const G.coord_embed(k, j, i, Loci::center, x); if(INSIDE) { local_result += reduction_var(REDUCE_FUNCTION_CALL) * - (!trivial[2]) * G.Dxc<3>(k) * (!trivial[1]) * G.Dxc<2>(j) * (!trivial[0]) * G.Dxc<1>(i); + (!trivial3) * G.Dxc<3>(k) * (!trivial2) * G.Dxc<2>(j) * (!trivial1) * G.Dxc<1>(i); } } , sum_reducer); @@ -256,7 +266,7 @@ T Reductions::DomainReduction(MeshData *md, UserHistoryOperation op, const G.coord_embed(k, j, i, Loci::center, x); if(INSIDE) { const Real val = reduction_var(REDUCE_FUNCTION_CALL) * - (!trivial[2]) * G.Dxc<3>(k) * (!trivial[1]) * G.Dxc<2>(j) * (!trivial[0]) * G.Dxc<1>(i); + (!trivial3) * G.Dxc<3>(k) * (!trivial2) * G.Dxc<2>(j) * (!trivial1) * G.Dxc<1>(i); if (val > local_result) local_result = val; } } @@ -273,7 +283,7 @@ T Reductions::DomainReduction(MeshData *md, UserHistoryOperation op, const G.coord_embed(k, j, i, Loci::center, x); if(INSIDE) { const Real val = reduction_var(REDUCE_FUNCTION_CALL) * - (!trivial[2]) * G.Dxc<3>(k) * (!trivial[1]) * G.Dxc<2>(j) * (!trivial[0]) * G.Dxc<1>(i); + (!trivial3) * G.Dxc<3>(k) * (!trivial2) * G.Dxc<2>(j) * (!trivial1) * G.Dxc<1>(i); if (val < local_result) local_result = val; } }