Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
vedantdhruv96 committed Sep 22, 2023
1 parent bcb304b commit b86820a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 20 deletions.
31 changes: 19 additions & 12 deletions kharma/reductions/reductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,32 +111,39 @@ std::vector<int> Reductions::CountFlags(MeshData<Real> *md, std::string field_na
IndexRange kb = md->GetBoundsK(domain);
IndexRange block = IndexRange{0, flag.GetDim(5) - 1};

// Man, moving arrays is clunky. Oh well.
const int n_of_flags = flag_values.size();
int flag_val_list[MAX_NFLAGS];
int f=0;
ParArray1D<int> flag_val_list("flag_values", MAX_NFLAGS);
auto flag_val_list_h = flag_val_list.GetHostMirror();
int f=1;
for (auto &flag : flag_values) {
flag_val_list[f] = flag.first;
flag_val_list_h[f] = flag.first;
f++;
}
flag_val_list.DeepCopy(flag_val_list_h);
Kokkos::fence();

// Count all nonzero (technically, >0) values,
// and all values of each
// and all values which match each flag.
// This works for pflags or fflags, so long as they're separate
// We don't count negative pflags as they denote zones that shouldn't be fixed
Reductions::array_type<int, MAX_NFLAGS> flag_reducer;
pmb0->par_reduce("count_flags", block.s, block.e, kb.s, kb.e, jb.s, jb.e, ib.s, ib.e,
KOKKOS_LAMBDA (const int &b, const int &k, const int &j, const int &i,
Reductions::array_type<int, MAX_NFLAGS> &local_result) {
if ((int) flag(b, 0, k, j, i) > 0) ++local_result.my_array[0];
for (int f=0; f<n_of_flags; f++)
if ((is_bitflag && static_cast<int>(flag(b, 0, k, j, i)) & flag_val_list[f]) ||
(!is_bitflag && static_cast<int>(flag(b, 0, k, j, i)) == flag_val_list[f]))
++local_result.my_array[f+1];
const int flag_int = static_cast<int>(flag(b, 0, k, j, i));
// First element is total count
if (flag_int > 0) ++local_result.my_array[0];
// The rest of the list is individual flags
for (int f=1; f < n_of_flags; f++)
if ((is_bitflag && flag_int & flag_val_list(f)) ||
(!is_bitflag && flag_int == flag_val_list(f)))
++local_result.my_array[f];
}
, Reductions::ArraySum<int, DevExecSpace, MAX_NFLAGS>(flag_reducer));
, Reductions::ArraySum<int, HostExecSpace, MAX_NFLAGS>(flag_reducer));

std::vector<int> n_each_flag;
for (int f=0; f<n_of_flags+1; f++)
for (int f=0; f < n_of_flags+1; f++)
n_each_flag.push_back(flag_reducer.my_array[f]);

EndFlag();
Expand Down
26 changes: 18 additions & 8 deletions kharma/reductions/reductions_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,10 @@ T Reductions::EHReduction(MeshData<Real> *md, UserHistoryOperation op, int zone)
return result;
}

#define INSIDE (x[1] > startx[0] && x[2] > startx[1] && x[3] > startx[2]) && \
(trivial[0] ? x[1] < startx[0] + G.Dxc<1>(i) : x[1] < stopx[0]) && \
(trivial[1] ? x[2] < startx[1] + G.Dxc<2>(j) : x[2] < stopx[1]) && \
(trivial[2] ? x[3] < startx[2] + G.Dxc<3>(k) : x[3] < stopx[2])
#define INSIDE (x[1] > startx1 && x[2] > startx2 && x[3] > startx3) && \
(trivial1 ? x[1] < startx1 + G.Dxc<1>(i) : x[1] < stopx1) && \
(trivial2 ? x[2] < startx2 + G.Dxc<2>(j) : x[2] < stopx2) && \
(trivial3 ? x[3] < startx3 + G.Dxc<3>(k) : x[3] < stopx3)

// TODO additionally template on return type to avoid counting flags with Reals
template<Reductions::Var var, typename T>
Expand Down Expand Up @@ -226,7 +226,17 @@ T Reductions::DomainReduction(MeshData<Real> *md, UserHistoryOperation op, const
VLOOP if(startx[v] == stopx[v]) {
trivial_tmp[v] = true;
}
const bool trivial[3] = {trivial_tmp[0], trivial_tmp[1], trivial_tmp[2]};

// Pull values to pass to device, because passing views is cumbersome
const bool trivial1 = trivial_tmp[0];
const bool trivial2 = trivial_tmp[1];
const bool trivial3 = trivial_tmp[2];
const GReal startx1 = startx[0];
const GReal startx2 = startx[1];
const GReal startx3 = startx[2];
const GReal stopx1 = stopx[0];
const GReal stopx2 = stopx[1];
const GReal stopx3 = stopx[2];

T result = 0.;
MPI_Op mop;
Expand All @@ -240,7 +250,7 @@ T Reductions::DomainReduction(MeshData<Real> *md, UserHistoryOperation op, const
G.coord_embed(k, j, i, Loci::center, x);
if(INSIDE) {
local_result += reduction_var<var>(REDUCE_FUNCTION_CALL) *
(!trivial[2]) * G.Dxc<3>(k) * (!trivial[1]) * G.Dxc<2>(j) * (!trivial[0]) * G.Dxc<1>(i);
(!trivial3) * G.Dxc<3>(k) * (!trivial2) * G.Dxc<2>(j) * (!trivial1) * G.Dxc<1>(i);
}
}
, sum_reducer);
Expand All @@ -256,7 +266,7 @@ T Reductions::DomainReduction(MeshData<Real> *md, UserHistoryOperation op, const
G.coord_embed(k, j, i, Loci::center, x);
if(INSIDE) {
const Real val = reduction_var<var>(REDUCE_FUNCTION_CALL) *
(!trivial[2]) * G.Dxc<3>(k) * (!trivial[1]) * G.Dxc<2>(j) * (!trivial[0]) * G.Dxc<1>(i);
(!trivial3) * G.Dxc<3>(k) * (!trivial2) * G.Dxc<2>(j) * (!trivial1) * G.Dxc<1>(i);
if (val > local_result) local_result = val;
}
}
Expand All @@ -273,7 +283,7 @@ T Reductions::DomainReduction(MeshData<Real> *md, UserHistoryOperation op, const
G.coord_embed(k, j, i, Loci::center, x);
if(INSIDE) {
const Real val = reduction_var<var>(REDUCE_FUNCTION_CALL) *
(!trivial[2]) * G.Dxc<3>(k) * (!trivial[1]) * G.Dxc<2>(j) * (!trivial[0]) * G.Dxc<1>(i);
(!trivial3) * G.Dxc<3>(k) * (!trivial2) * G.Dxc<2>(j) * (!trivial1) * G.Dxc<1>(i);
if (val < local_result) local_result = val;
}
}
Expand Down

0 comments on commit b86820a

Please sign in to comment.