Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hook for UserWorkBeforeLoop #971

Merged
merged 6 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Current develop

### Added (new features/APIs/variables/...)
- [[PR 971]](https://github.com/parthenon-hpc-lab/parthenon/pull/971) Add UserWorkBeforeLoop
- [[PR 907]](https://github.com/parthenon-hpc-lab/parthenon/pull/907) PEP1: Allow subclassing StateDescriptor
- [[PR 932]](https://github.com/parthenon-hpc-lab/parthenon/pull/932) Add GetOrAddFlag to metadata
- [[PR 931]](https://github.com/parthenon-hpc-lab/parthenon/pull/931) Allow SparsePacks with subsets of blocks
Expand Down
7 changes: 7 additions & 0 deletions doc/sphinx/src/interface/state.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ several useful features and functions.
deletgates to the ``std::function`` member ``PostStepDiagnosticsMesh``
if set (defaults to ``nullptr`` an therefore a no-op) to print
diagnostics after the time-integration advance
- ``void UserWorkBeforeLoopMesh(Mesh *, ParameterInput *pin, SimTime
pdmullen marked this conversation as resolved.
Show resolved Hide resolved
&tm)`` performs a per-package, mesh-wide calculation after the mesh
has been generated, and problem generators called, but before any
time evolution. This work is done both on first initialization and
on restart. If you would like to avoid doing the work upon restart,
you can check for the const ``is_restart`` member field of the ``Mesh``
object.

The reasoning for providing ``FillDerived*`` and ``EstimateTimestep*``
function pointers appropriate for usage with both ``MeshData`` and
Expand Down
13 changes: 12 additions & 1 deletion example/advection/advection_package.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//========================================================================================
// (C) (or copyright) 2020-2021. Triad National Security, LLC. All rights reserved.
// (C) (or copyright) 2020-2023. Triad National Security, LLC. All rights reserved.
//
// This program was produced under U.S. Government contract 89233218CNA000001 for Los
// Alamos National Laboratory (LANL), which is operated by Triad National Security, LLC
Expand All @@ -13,12 +13,14 @@

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include <coordinates/coordinates.hpp>
#include <globals.hpp>
#include <parthenon/package.hpp>

#include "advection_package.hpp"
Expand Down Expand Up @@ -215,10 +217,19 @@ std::shared_ptr<StateDescriptor> Initialize(ParameterInput *pin) {
}
pkg->CheckRefinementBlock = CheckRefinement;
pkg->EstimateTimestepBlock = EstimateTimestepBlock;
pkg->UserWorkBeforeLoopMesh = AdvectionGreetings;

return pkg;
}

void AdvectionGreetings(Mesh *pmesh, ParameterInput *pin, parthenon::SimTime &tm) {
if (parthenon::Globals::my_rank == 0) {
std::cout << "Hello from the advection package in the advection example!\n"
<< "This run is a restart: " << pmesh->is_restart << "\n"
<< std::endl;
}
}

AmrTag CheckRefinement(MeshBlockData<Real> *rc) {
// refine on advected, for example. could also be a derived quantity
auto pmb = rc->GetBlockPointer();
Expand Down
3 changes: 2 additions & 1 deletion example/advection/advection_package.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//========================================================================================
// (C) (or copyright) 2020. Triad National Security, LLC. All rights reserved.
// (C) (or copyright) 2020-2023. Triad National Security, LLC. All rights reserved.
//
// This program was produced under U.S. Government contract 89233218CNA000001 for Los
// Alamos National Laboratory (LANL), which is operated by Triad National Security, LLC
Expand All @@ -21,6 +21,7 @@ namespace advection_package {
using namespace parthenon::package::prelude;

std::shared_ptr<StateDescriptor> Initialize(ParameterInput *pin);
void AdvectionGreetings(Mesh *pmes, ParameterInput *pin, parthenon::SimTime &tm);
AmrTag CheckRefinement(MeshBlockData<Real> *rc);
void PreFill(MeshBlockData<Real> *rc);
void SquareIt(MeshBlockData<Real> *rc);
Expand Down
1 change: 1 addition & 0 deletions src/application_input.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct ApplicationInput {
PostStepDiagnosticsInLoop = nullptr;

std::function<void(Mesh *, ParameterInput *, SimTime &)> UserWorkAfterLoop = nullptr;
std::function<void(Mesh *, ParameterInput *, SimTime &)> UserWorkBeforeLoop = nullptr;
BValFunc boundary_conditions[BOUNDARY_NFACES] = {nullptr};
SBValFunc swarm_boundary_conditions[BOUNDARY_NFACES] = {nullptr};

Expand Down
10 changes: 10 additions & 0 deletions src/driver/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ DriverStatus EvolutionDriver::Execute() {
// Defaults must be set across all ranks
DumpInputParameters();

// Before loop do work
Yurlungur marked this conversation as resolved.
Show resolved Hide resolved
// App input version
if (app_input->UserWorkBeforeLoop != nullptr) {
app_input->UserWorkBeforeLoop(pmesh, pinput, tm);
}
// packages version
for (auto &[name, pkg] : pmesh->packages.AllPackages()) {
pkg->UserWorkBeforeLoop(pmesh, pinput, tm);
}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many of the older function hooks store the function pointers in the mesh. I think this is a holdover from Athena++ and in parthenon, it's more natural to just query the app input or the state descriptor directly.


Kokkos::Profiling::pushRegion("Driver_Main");
while (tm.KeepGoing()) {
if (Globals::my_rank == 0) OutputCycleDiagnostics();
Expand Down
6 changes: 6 additions & 0 deletions src/interface/state_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ class StateDescriptor {
if (InitNewlyAllocatedVarsBlock != nullptr) return InitNewlyAllocatedVarsBlock(rc);
}

void UserWorkBeforeLoop(Mesh *pmesh, ParameterInput *pin, SimTime &tm) const {
if (UserWorkBeforeLoopMesh != nullptr) return UserWorkBeforeLoopMesh(pmesh, pin, tm);
}
Comment on lines +409 to +411
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This copies the design of the other function hooks in state descriptor, but we could cut out the middle and just check for nullptr in the function pointer itself.


std::vector<std::shared_ptr<AMRCriteria>> amr_criteria;

std::function<void(MeshBlockData<Real> *rc)> PreCommFillDerivedBlock = nullptr;
Expand All @@ -416,6 +420,8 @@ class StateDescriptor {
std::function<void(MeshData<Real> *rc)> PostFillDerivedMesh = nullptr;
std::function<void(MeshBlockData<Real> *rc)> FillDerivedBlock = nullptr;
std::function<void(MeshData<Real> *rc)> FillDerivedMesh = nullptr;
std::function<void(Mesh *, ParameterInput *, SimTime &)> UserWorkBeforeLoopMesh =
nullptr;

std::function<void(SimTime const &simtime, MeshData<Real> *rc)> PreStepDiagnosticsMesh =
nullptr;
Expand Down
4 changes: 2 additions & 2 deletions src/mesh/mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ namespace parthenon {
Mesh::Mesh(ParameterInput *pin, ApplicationInput *app_in, Packages_t &packages,
int mesh_test)
: // public members:
modified(true),
modified(true), is_restart(false),
// aggregate initialization of RegionSize struct:
mesh_size({pin->GetReal("parthenon/mesh", "x1min"),
pin->GetReal("parthenon/mesh", "x2min"),
Expand Down Expand Up @@ -484,7 +484,7 @@ Mesh::Mesh(ParameterInput *pin, ApplicationInput *app_in, RestartReader &rr,
: // public members:
// aggregate initialization of RegionSize struct:
// (will be overwritten by memcpy from restart file, in this case)
modified(true),
modified(true), is_restart(true),
// aggregate initialization of RegionSize struct:
mesh_size({pin->GetReal("parthenon/mesh", "x1min"),
pin->GetReal("parthenon/mesh", "x2min"),
Expand Down
1 change: 1 addition & 0 deletions src/mesh/mesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class Mesh {

// data
bool modified;
const bool is_restart;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted people to be able to easily do different things upon restart vs first initialization, but without having to register two separate, maybe almost identical functions. There might be a better way than storing restart state in the mesh pointer, but it also seemed natural as something inherited from Athena++.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What we've typically been doing is to check for tm.cycle != 0, which should also be possible in the new callbacks given that the SimTime object is passed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point---is_restart feels cleaner/less ambiguous to me, but I'm fine to remove this and tell people to use tm.cycle if preferred.

RegionSize mesh_size;
BoundaryFlag mesh_bcs[BOUNDARY_NFACES];
const int ndim; // number of dimensions
Expand Down
Loading