Skip to content

Commit 1e77c13

Browse files
authored
Merge pull request #1187 from parthenon-hpc-lab/lroberts36/add-more-careful-container-checking
Small: Changes to `DataCollection::Add` and `MeshBlockData::Intialize`
2 parents a9b7858 + 3a15cfe commit 1e77c13

10 files changed

+83
-15
lines changed

.github/workflows/docs.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ jobs:
2323
run: export DEBIAN_FRONTEND=noninteractive
2424
- name: install dependencies
2525
run: |
26-
pip install sphinx
27-
pip install sphinx-rtd-theme
28-
pip install sphinx-multiversion
26+
pip install --break-system-packages sphinx
27+
pip install --break-system-packages sphinx-rtd-theme
28+
pip install --break-system-packages sphinx-multiversion
2929
- name: build docs
3030
run: |
3131
echo "Repo = ${GITHUB_REPOSITORY}"

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- [[PR 1161]](https://github.com/parthenon-hpc-lab/parthenon/pull/1161) Make flux field Metadata accessible, add Metadata::CellMemAligned flag, small perfomance upgrades
1212

1313
### Changed (changing behavior/API/variables/...)
14+
- [[PR 1187]](https://github.com/parthenon-hpc-lab/parthenon/pull/1187) Make DataCollection::Add safer and generalize MeshBlockData::Initialize
1415
- [[PR 1186]](https://github.com/parthenon-hpc-lab/parthenon/pull/1186) Bump Kokkos submodule to 4.4.1
1516
- [[PR 1171]](https://github.com/parthenon-hpc-lab/parthenon/pull/1171) Add PARTHENON_USE_SYSTEM_PACKAGES build option
1617
- [[PR 1172]](https://github.com/parthenon-hpc-lab/parthenon/pull/1172) Make parthenon manager robust against external MPI init and finalize calls

src/interface/data_collection.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class DataCollection {
5858
auto key = GetKey(name, src);
5959
auto it = containers_.find(key);
6060
if (it != containers_.end()) {
61-
if (fields.size() && !(it->second)->Contains(fields)) {
61+
if (fields.size() && !(it->second)->CreatedFrom(fields)) {
6262
PARTHENON_THROW(key + " already exists in collection but fields do not match.");
6363
}
6464
return it->second;

src/interface/mesh_data.hpp

+8
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,14 @@ class MeshData {
476476
[this, vars](const auto &b) { return b->ContainsExactly(vars); });
477477
}
478478

479+
// Checks that the same set of variables was requested to create this container
480+
// (which may be different than the set of variables in the container because of fluxes)
481+
template <typename Vars_t>
482+
bool CreatedFrom(const Vars_t &vars) const noexcept {
483+
return std::all_of(block_data_.begin(), block_data_.end(),
484+
[this, vars](const auto &b) { return b->CreatedFrom(vars); });
485+
}
486+
479487
std::shared_ptr<SwarmContainer> GetSwarmData(int n) {
480488
PARTHENON_REQUIRE(n >= 0 && n < block_data_.size(),
481489
"MeshData::GetSwarmData requires n within [0, block_data_.size()]");

src/interface/meshblock_data.hpp

+43-3
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,17 @@ class MeshBlockData {
141141
resolved_packages = resolved_packages_in;
142142
is_shallow_ = shallow_copy;
143143

144+
// Store the list of variables used to create this container
145+
// so we can compare to it when searching the cache
146+
varUidIn_.clear();
147+
if constexpr (std::is_same_v<ID_t, std::string>) {
148+
for (const auto &var : vars)
149+
varUidIn_.insert(Variable<Real>::GetUniqueID(var));
150+
} else {
151+
for (const auto &var : vars)
152+
varUidIn_.insert(var);
153+
}
154+
144155
// clear all variables, maps, and pack caches
145156
varVector_.clear();
146157
varMap_.clear();
@@ -185,9 +196,24 @@ class MeshBlockData {
185196
if (!found) add_var(src->GetVarPtr(flx_name));
186197
}
187198
}
188-
} else {
189-
PARTHENON_FAIL(
190-
"Variable subset selection not yet implemented for MeshBlock input.");
199+
} else if constexpr (std::is_same_v<SRC_t, MeshBlock>) {
200+
for (const auto &v : vars) {
201+
const auto &vid = resolved_packages->GetFieldVarID(v);
202+
const auto &md = resolved_packages->GetFieldMetadata(v);
203+
AddField(vid.base_name, md, vid.sparse_id);
204+
// Add the associated flux as well if not explicitly
205+
// asked for
206+
if (md.IsSet(Metadata::WithFluxes)) {
207+
auto flx_vid = resolved_packages->GetFieldVarID(md.GetFluxName());
208+
bool found = false;
209+
for (const auto &v2 : vars)
210+
if (resolved_packages->GetFieldVarID(v2) == flx_vid) found = true;
211+
if (!found) {
212+
const auto &flx_md = resolved_packages->GetFieldMetadata(flx_vid);
213+
AddField(flx_vid.base_name, flx_md, flx_vid.sparse_id);
214+
}
215+
}
216+
}
191217
}
192218
}
193219

@@ -525,6 +551,18 @@ class MeshBlockData {
525551
return Contains(vars) && (vars.size() == varVector_.size());
526552
}
527553

554+
bool CreatedFrom(const std::vector<Uid_t> &vars) {
555+
return (vars.size() == varUidIn_.size()) &&
556+
std::all_of(vars.begin(), vars.end(),
557+
[this](const auto &v) { return this->varUidIn_.count(v); });
558+
}
559+
bool CreatedFrom(const std::vector<std::string> &vars) {
560+
return (vars.size() == varUidIn_.size()) &&
561+
std::all_of(vars.begin(), vars.end(), [this](const auto &v) {
562+
return this->varUidIn_.count(Variable<Real>::GetUniqueID(v));
563+
});
564+
}
565+
528566
void SetAllVariablesToInitialized() {
529567
std::for_each(varVector_.begin(), varVector_.end(),
530568
[](auto &sp_var) { sp_var->data.initialized = true; });
@@ -561,6 +599,8 @@ class MeshBlockData {
561599

562600
VariableVector<T> varVector_; ///< the saved variable array
563601
std::map<Uid_t, std::shared_ptr<Variable<T>>> varUidMap_;
602+
std::set<Uid_t> varUidIn_; // Uid list from which this MeshBlockData was created,
603+
// empty implies all variables were included
564604

565605
MapToVars<T> varMap_;
566606
MetadataFlagToVariableMap<T> flagsToVars_;

src/interface/state_descriptor.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ bool StateDescriptor::AddFieldImpl_(const VarID &vid, const Metadata &m_in,
316316
AddFieldImpl_(fId, *(m.GetSPtrFluxMetadata()), control_vid);
317317
m.SetFluxName(fId.label());
318318
}
319+
labelToVidMap_.insert({vid.label(), vid});
319320
metadataMap_.insert({vid, m});
320321
refinementFuncMaps_.Register(m, vid.label());
321322
allocControllerReverseMap_.insert({vid, control_vid});

src/interface/state_descriptor.hpp

+15
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,20 @@ class StateDescriptor {
203203
// retrieve all swarm names
204204
std::vector<std::string> Swarms() noexcept;
205205

206+
const auto GetFieldVarID(const VarID &id) const {
207+
PARTHENON_REQUIRE_THROWS(
208+
metadataMap_.count(id),
209+
"Asking for a variable that is not in this StateDescriptor.");
210+
return id;
211+
}
212+
213+
const auto &GetFieldVarID(const std::string &label) const {
214+
return labelToVidMap_.at(label);
215+
}
216+
const auto &GetFieldMetadata(const std::string &label) const {
217+
return metadataMap_.at(labelToVidMap_.at(label));
218+
}
219+
const auto &GetFieldMetadata(const VarID &id) const { return metadataMap_.at(id); }
206220
const auto &AllFields() const noexcept { return metadataMap_; }
207221
const auto &AllSparsePools() const noexcept { return sparsePoolMap_; }
208222
const auto &AllSwarms() const noexcept { return swarmMetadataMap_; }
@@ -397,6 +411,7 @@ class StateDescriptor {
397411
const std::string label_;
398412

399413
// for each variable label (full label for sparse variables) hold metadata
414+
std::unordered_map<std::string, VarID> labelToVidMap_;
400415
std::unordered_map<VarID, Metadata, VarIDHasher> metadataMap_;
401416
std::unordered_map<VarID, VarID, VarIDHasher> allocControllerReverseMap_;
402417
std::unordered_map<std::string, std::vector<std::string>> allocControllerMap_;

src/mesh/mesh-amr_loadbalance.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -979,8 +979,9 @@ void Mesh::RedistributeAndRefineMeshBlocks(ParameterInput *pin, ApplicationInput
979979
auto &md_noncc = mesh_data.AddShallow(noncc, md, noncc_names);
980980
}
981981

982-
CommunicateBoundaries(noncc); // Called to make sure shared values are correct,
983-
// ghosts of non-cell centered vars may get some junk
982+
CommunicateBoundaries(
983+
noncc, noncc_names); // Called to make sure shared values are correct,
984+
// ghosts of non-cell centered vars may get some junk
984985
// Now there is the correct data for prolongating on un-shared topological elements
985986
// on the new fine blocks
986987
if (nprolong > 0) {

src/mesh/mesh.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,8 @@ void Mesh::BuildTagMapAndBoundaryBuffers() {
644644
}
645645
}
646646

647-
void Mesh::CommunicateBoundaries(std::string md_name) {
647+
void Mesh::CommunicateBoundaries(std::string md_name,
648+
const std::vector<std::string> &fields) {
648649
const int num_partitions = DefaultNumPartitions();
649650
const int nmb = GetNumMeshBlocksThisRank(Globals::my_rank);
650651
constexpr std::int64_t max_it = 1e10;
@@ -656,7 +657,7 @@ void Mesh::CommunicateBoundaries(std::string md_name) {
656657
do {
657658
all_sent = true;
658659
for (int i = 0; i < partitions.size(); ++i) {
659-
auto &md = mesh_data.Add(md_name, partitions[i]);
660+
auto &md = mesh_data.Add(md_name, partitions[i], fields);
660661
if (!sent[i]) {
661662
if (SendBoundaryBuffers(md) != TaskStatus::complete) {
662663
all_sent = false;
@@ -680,7 +681,7 @@ void Mesh::CommunicateBoundaries(std::string md_name) {
680681
do {
681682
all_received = true;
682683
for (int i = 0; i < partitions.size(); ++i) {
683-
auto &md = mesh_data.Add(md_name, partitions[i]);
684+
auto &md = mesh_data.Add(md_name, partitions[i], fields);
684685
if (!received[i]) {
685686
if (ReceiveBoundaryBuffers(md) != TaskStatus::complete) {
686687
all_received = false;
@@ -696,14 +697,14 @@ void Mesh::CommunicateBoundaries(std::string md_name) {
696697
"Too many iterations waiting to receive boundary communication buffers.");
697698

698699
for (auto &partition : partitions) {
699-
auto &md = mesh_data.Add(md_name, partition);
700+
auto &md = mesh_data.Add(md_name, partition, fields);
700701
// unpack FillGhost variables
701702
SetBoundaries(md);
702703
}
703704

704705
// Now do prolongation, compute primitives, apply BCs
705706
for (auto &partition : partitions) {
706-
auto &md = mesh_data.Add(md_name, partition);
707+
auto &md = mesh_data.Add(md_name, partition, fields);
707708
if (multilevel) {
708709
ApplyBoundaryConditionsOnCoarseOrFineMD(md, true);
709710
ProlongateBoundaries(md);

src/mesh/mesh.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,8 @@ class Mesh {
329329

330330
void SetupMPIComms();
331331
void BuildTagMapAndBoundaryBuffers();
332-
void CommunicateBoundaries(std::string md_name = "base");
332+
void CommunicateBoundaries(std::string md_name = "base",
333+
const std::vector<std::string> &fields = {});
333334
void PreCommFillDerived();
334335
void FillDerived();
335336

0 commit comments

Comments
 (0)