diff --git a/pyphare/pyphare/pharein/__init__.py b/pyphare/pyphare/pharein/__init__.py index 5704efb49..71ad6c5e0 100644 --- a/pyphare/pyphare/pharein/__init__.py +++ b/pyphare/pyphare/pharein/__init__.py @@ -365,7 +365,7 @@ def as_paths(rb): if "dir" in restart_options: restart_file_path = restart_options["dir"] - if "restart_time" in restart_options and restart_options["restart_time"] > 0: + if "restart_time" in restart_options: from pyphare.cpp import cpp_etc_lib restart_time = restart_options["restart_time"] diff --git a/pyphare/pyphare/pharein/simulation.py b/pyphare/pyphare/pharein/simulation.py index b0db542b3..6e71cd0b2 100644 --- a/pyphare/pyphare/pharein/simulation.py +++ b/pyphare/pyphare/pharein/simulation.py @@ -531,9 +531,10 @@ def check_restart_options(**kwargs): "restart_time", # number or "auto" "keep_last", # delete obsolete ] - restart_options = kwargs.get("restart_options", None) - if restart_options is not None: + restart_options = kwargs.get("restart_options", {}) + + if "restart_options" in kwargs: for key in restart_options.keys(): if key not in valid_keys: raise ValueError( @@ -553,8 +554,11 @@ def check_restart_options(**kwargs): f"Invalid restart mode {mode}, valid modes are {valid_modes}" ) - if restart_time := restarts.restart_time(restart_options): + restart_time = restarts.restart_time(restart_options) + if restart_time is not None: restart_options["restart_time"] = restart_time + elif "restart_time" in restart_options: + restart_options.pop("restart_time") # auto with no existing file to use return restart_options @@ -650,7 +654,7 @@ def check_clustering(**kwargs): def checker(func): - def wrapper(simulation_object, **kwargs): + def wrapper(simulation_object, **kwargs_in): accepted_keywords = [ "domain_size", "cells", @@ -684,6 +688,7 @@ def wrapper(simulation_object, **kwargs): "write_reports", ] + kwargs = deepcopy(kwargs_in) # local copy - dictionaries are weird accepted_keywords += check_optional_keywords(**kwargs) wrong_kwds = phare_utilities.not_in_keywords_list(accepted_keywords, **kwargs) @@ -702,6 +707,7 @@ def wrapper(simulation_object, **kwargs): kwargs["clustering"] = check_clustering(**kwargs) + kwargs["restart_options"] = check_restart_options(**kwargs) time_step_nbr, time_step, final_time = check_time(**kwargs) kwargs["time_step_nbr"] = time_step_nbr kwargs["time_step"] = time_step @@ -716,7 +722,6 @@ def wrapper(simulation_object, **kwargs): ndim = compute_dimension(cells) kwargs["diag_options"] = check_diag_options(**kwargs) - kwargs["restart_options"] = check_restart_options(**kwargs) kwargs["boundary_types"] = check_boundaries(ndim, **kwargs) @@ -1022,9 +1027,9 @@ def start_time(self): return 0 def is_from_restart(self): - return ( - self.restart_options is not None and "restart_time" in self.restart_options - ) + if self.restart_options is not None and "restart_time" in self.restart_options: + return self.restart_options["restart_time"] is not None + return False def __getattr__( self, name diff --git a/src/simulator/simulator.hpp b/src/simulator/simulator.hpp index 5883d4912..120fa5fef 100644 --- a/src/simulator/simulator.hpp +++ b/src/simulator/simulator.hpp @@ -359,6 +359,7 @@ Simulator::Simulator(PHARE::initializer::PHAREDict const& dict, { resman_ptr = std::make_shared(); currentTime_ = restart_time(dict); + finalTime_ += currentTime_; // final time is from timestep * timestep_nbr! if (dict["simulation"].contains("restarts")) rMan = restarts::RestartsManagerResolver::make_unique(*hierarchy_, *resman_ptr, diff --git a/tests/simulator/test_restarts.py b/tests/simulator/test_restarts.py index e3097553b..a60018387 100644 --- a/tests/simulator/test_restarts.py +++ b/tests/simulator/test_restarts.py @@ -431,7 +431,7 @@ def test_advanced_restarts_options(self): dup( dict( cells=10, - time_step_nbr=10, + time_step_nbr=7, max_nbr_levels=1, refinement="tagging", ) @@ -440,13 +440,13 @@ def test_advanced_restarts_options(self): simput["interp_order"] = interp time_step = simput["time_step"] - time_step_nbr = simput["time_step_nbr"] - timestamps = time_step * np.arange(time_step_nbr + 1) + timestamps = time_step * np.arange(simput["time_step_nbr"] + 1) local_out = self.unique_diag_dir_for_test_case(f"{out}/test", ndim, interp) simput["restart_options"]["dir"] = local_out simput["restart_options"]["keep_last"] = 3 simput["restart_options"]["timestamps"] = timestamps + simput["restart_options"]["restart_time"] = "auto" ph.global_vars.sim = None ph.global_vars.sim = ph.Simulation(**simput) @@ -454,8 +454,16 @@ def test_advanced_restarts_options(self): Simulator(ph.global_vars.sim).run().reset() self.register_diag_dir_for_cleanup(local_out) + # restarted + timestamps = time_step * np.arange(7, 11) + simput["time_step_nbr"] = 3 simput["restart_options"]["restart_time"] = "auto" - self.assertEqual(0.01, ph.restarts.restart_time(simput["restart_options"])) + simput["restart_options"]["timestamps"] = timestamps + self.assertEqual(0.007, ph.restarts.restart_time(simput["restart_options"])) + ph.global_vars.sim = None + ph.global_vars.sim = ph.Simulation(**simput) + setup_model() + Simulator(ph.global_vars.sim).run().reset() dirs = [] for path_object in Path(local_out).iterdir(): @@ -469,6 +477,8 @@ def test_advanced_restarts_options(self): for i, idx in enumerate(range(8, 11)): self.assertAlmostEqual(dirs[i], time_step * idx) + self.assertEqual(0.01, ph.restarts.restart_time(simput["restart_options"])) + if __name__ == "__main__": unittest.main()