28
28
from swirl_lm .boundary_condition import nonreflecting_boundary
29
29
from swirl_lm .core import simulation
30
30
from swirl_lm .linalg import poisson_solver
31
+ from swirl_lm .physics .lpt import lpt
32
+ from swirl_lm .physics .lpt import lpt_manager
31
33
from swirl_lm .physics .radiation import rrtmgp_common
32
34
from swirl_lm .utility import common_ops
33
35
from swirl_lm .utility import debug_output
94
96
CKPT_DIR_FORMAT = '{filename_prefix}-ckpts/'
95
97
COMPLETION_FILE = 'DONE'
96
98
_MAX_UVW_CFL = 'max_uvw_cfl'
99
+ SIMULATION_TIME = 'simulation_time'
97
100
98
101
Array : TypeAlias = Any
99
102
PerReplica : TypeAlias = tf .types .experimental .distributed .PerReplica
@@ -191,6 +194,9 @@ def _get_state_keys(params: parameters_lib.SwirlLMParameters):
191
194
# Add additional keys required by the radiative transfer library.
192
195
additional_keys += rrtmgp_common .required_keys (params .radiative_transfer )
193
196
197
+ # Add additional keys required by the lagrangian particle tracking library.
198
+ additional_keys += lpt .required_keys (params .lpt )
199
+
194
200
# Check to make sure we don't have keys duplicating / overwriting each other.
195
201
if len (set (essential_keys )) + len (set (additional_keys )) + len (
196
202
set (helper_var_keys )
@@ -250,6 +256,11 @@ def init_fn(
250
256
(params .num_steps , 4 ), dtype = types .TF_DTYPE
251
257
)
252
258
259
+ states [SIMULATION_TIME ] = tf .convert_to_tensor (0 , dtype = tf .float64 )
260
+
261
+ if params .lpt is not None :
262
+ states .update (lpt .init_fn (params ))
263
+
253
264
# Apply the user defined `init_fn` in the end to allow it to override
254
265
# default initializations.
255
266
if customized_init_fn is not None :
@@ -348,6 +359,20 @@ def _update_additional_states(
348
359
)
349
360
)
350
361
362
+ # Update lagrangian particle additional states.
363
+ with tf .name_scope ('lpt_additional_states_update' ):
364
+ lpt_field = lpt_manager .lpt_factory (params )
365
+ if lpt_field is not None :
366
+ updated_additional_states .update (
367
+ lpt_field .step (
368
+ replica_id = common_kwargs ['replica_id' ],
369
+ replicas = common_kwargs ['replicas' ],
370
+ states = essential_states ,
371
+ additional_states = additional_states ,
372
+ step_id = step_id ,
373
+ ),
374
+ )
375
+
351
376
if params .additional_states_update_fn is not None :
352
377
with tf .name_scope ('additional_states_update' ):
353
378
updated_additional_states = params .additional_states_update_fn (
@@ -362,21 +387,24 @@ def _update_additional_states(
362
387
363
388
def _state_has_nan_inf (state : dict [str , PerReplica ], replicas : Array ) -> bool :
364
389
"""Checks whether any field in the `state` contains `nan` or `inf`."""
365
- has_nan_inf = False
390
+ local_has_nan_inf = False
366
391
for _ , v in state .items ():
367
- # We want to make sure every core is seeing the same problem so the
368
- # termination of all cores is synchronized, thus a global reduce operation
369
- # is needed.
370
- if common_ops .global_reduce (
371
- tf .math .logical_or (tf .math .is_nan (v ), tf .math .is_inf (v )),
372
- tf .reduce_any ,
373
- replicas .reshape ([1 , - 1 ]),
374
- ):
375
- has_nan_inf = True
376
- # For some reason, the graph tracing in this case doesn't allow early break
377
- # of the loop. For now we will just check through all fields without early
378
- # break. This should still be very efficient.
379
- return has_nan_inf
392
+ if (v .dtype .is_floating and tf .reduce_any (
393
+ tf .math .logical_or (tf .math .is_nan (v ), tf .math .is_inf (v )))):
394
+ local_has_nan_inf = True
395
+ # Graph compilation doesn't allow early break.
396
+
397
+ # We want to make sure every core is seeing the same problem so the
398
+ # termination of all cores is synchronized, thus a global reduce operation
399
+ # is needed.
400
+ if common_ops .global_reduce (
401
+ tf .convert_to_tensor (local_has_nan_inf ),
402
+ tf .reduce_any ,
403
+ replicas .reshape ([1 , - 1 ]),
404
+ ):
405
+ return True
406
+ else :
407
+ return False
380
408
381
409
382
410
def _compute_max_uvw_and_cfl (
@@ -437,7 +465,8 @@ def _one_cycle(
437
465
num_steps : Array ,
438
466
params : Union [parameters_lib .SwirlLMParameters , Any ],
439
467
model : Any ,
440
- ) -> tuple [dict [str , PerReplica ], dict [str , PerReplica ]]:
468
+ ) -> tuple [dict [str , PerReplica ], PerReplica , dict [str , PerReplica ],
469
+ PerReplica ]:
441
470
"""Runs one cycle of the Navier-Stokes solver.
442
471
443
472
Args:
@@ -452,8 +481,12 @@ def _one_cycle(
452
481
defined.
453
482
454
483
Returns:
455
- A 2-tuple of the final state at the end of the cycle and the completed
456
- number of steps, both will be in the dict[str, PerReplica] format.
484
+ A 4-tuple (final_state, completed_steps, previous_state, has_non_finite) where
485
+ final_state is the state at the end of the cycle, completed_steps is the
486
+ number of steps completed in this cycle, previous_state is the state one
487
+ step before the final_state (which is useful in blow-ups), and
488
+ has_non_finite is a boolean indicating whether the final_state contains
489
+ non-finite values.
457
490
"""
458
491
logging .info (
459
492
'Tracing and compiling of _one_cycle starts. This can take up to 30 min.'
@@ -500,6 +533,8 @@ def step_fn(state):
500
533
state [_MAX_UVW_CFL ] = tf .zeros_like (state [_MAX_UVW_CFL ])
501
534
502
535
cycle_step_id = 0
536
+ prev_state = state
537
+ has_non_finite = False
503
538
for _ in tf .range (num_steps ):
504
539
step_id = init_step_id + cycle_step_id
505
540
# Split the state into essential states and additional states. Note that
@@ -596,22 +631,29 @@ def step_fn(state):
596
631
),
597
632
)
598
633
599
- if SAVE_LAST_VALID_STEP .value :
600
- if _state_has_nan_inf (updated_state , logical_replicas ):
601
- # Detected nan/inf, skip the update of state by early-exiting from the
602
- # for loop.
603
- break
634
+ # Simulation time will accumulate precision errors with this approach.
635
+ # For now, the converter deals with this by rounding off to fewer
636
+ # significant digits.
637
+ updated_state [SIMULATION_TIME ] = state [SIMULATION_TIME ] + params .dt64
638
+
639
+ prev_state = state
604
640
# Some state keys such as `replica_id` may not lie in either of the three
605
641
# categories. Just pass them through.
606
642
state = _stateless_update_if_present (state , updated_state )
607
643
cycle_step_id += 1
644
+ if SAVE_LAST_VALID_STEP .value :
645
+ if _state_has_nan_inf (state , logical_replicas ):
646
+ # Detected nan/inf in the already-updated state; stop stepping by
647
+ # early-exiting from the for loop.
648
+ has_non_finite = True
649
+ break
608
650
609
651
if not params .use_3d_tf_tensor :
610
652
# Unsplit the keys that were previously split.
611
653
for key in keys_to_split :
612
654
state [key ] = tf .stack (state [key ])
613
655
614
- return state , cycle_step_id
656
+ return state , cycle_step_id , prev_state , has_non_finite
615
657
616
658
return strategy .run (step_fn , args = (init_state ,))
617
659
@@ -740,10 +782,10 @@ def solver(
740
782
num_steps, 2 * num_steps, ..., num_cycles * num_steps].
741
783
742
784
If the simulation reaches a state where any variable has a non-finite value
743
- (NaN or Inf), then the simulation will stop early and save the state one
744
- step before the state that contains non-finite values . As a result, there
745
- will be fewer steps saved than `num_cycles + ` and the final saved step
746
- number will not necessarily be a multiple of num_steps.
785
+ (NaN or Inf), then the simulation will stop early and save both the
786
+ non-finite state and the one before it . As a result, there will most likely
787
+ be fewer steps saved than `num_cycles + 1 ` and the final saved step number
788
+ will not necessarily be a multiple of num_steps.
747
789
748
790
In both of these cases (`num_cycles` reached or non-finite value seen),
749
791
the solver will write an empty `DONE` file to the output directory to
@@ -866,9 +908,19 @@ def write_state_and_sync(
866
908
state : dict [str , PerReplica ],
867
909
step_id : Array ,
868
910
data_dump_filter : Optional [Sequence [str ]] = None ,
911
+ allow_non_finite_values : bool = False ,
912
+ use_zeros_for_debug_values : bool = False ,
869
913
):
870
- debug_vars = debug_output .get_vars (strategy , state .keys ())
914
+ if use_zeros_for_debug_values :
915
+ debug_vars = debug_output .zeros_like_vars (strategy , state .keys ())
916
+ else :
917
+ debug_vars = debug_output .get_vars (strategy , state .keys ())
918
+
871
919
write_state = dict (state ) | debug_vars
920
+ if allow_non_finite_values :
921
+ fields_allowing_non_finite_values = list (write_state .keys ())
922
+ else :
923
+ fields_allowing_non_finite_values = list (debug_vars .keys ())
872
924
write_status = driver_tpu .distributed_write_state (
873
925
strategy ,
874
926
_local_state_dict (strategy , write_state ),
@@ -877,7 +929,7 @@ def write_state_and_sync(
877
929
filename_prefix = filename_prefix ,
878
930
step_id = step_id ,
879
931
data_dump_filter = data_dump_filter ,
880
- fields_allowing_non_finite_values = list ( debug_vars . keys ()) ,
932
+ fields_allowing_non_finite_values = fields_allowing_non_finite_values ,
881
933
)
882
934
883
935
# This will block until all replicas are done writing.
@@ -957,7 +1009,7 @@ def write_state_and_sync(
957
1009
write_status = write_state_and_sync (state = state , step_id = step_id_value ())
958
1010
logging .info (
959
1011
'`restoring-checkpoint-if-necessary` stage '
960
- 'done with writing initial steps. Write status are : %s' ,
1012
+ 'done with writing initial steps. Write status: %s' ,
961
1013
write_status ,
962
1014
)
963
1015
# Only after the logging, which forces the `write_status` to be
@@ -1008,20 +1060,22 @@ def write_state_and_sync(
1008
1060
cycle = (step_id_value () - params .start_step ) // params .num_steps
1009
1061
logging .info ('Step %d (cycle %d) is starting.' , step_id_value (), cycle )
1010
1062
t0 = time .time ()
1011
- state , num_steps_completed = _one_cycle (
1063
+ state , num_steps_completed , prev_state , has_non_finite = _one_cycle (
1012
1064
strategy = strategy ,
1013
1065
init_state = state ,
1014
1066
init_step_id = step_id_value (),
1015
1067
num_steps = params .num_steps ,
1016
1068
params = params ,
1017
1069
model = model ,
1018
1070
)
1019
-
1020
- # Completed number steps are guaranteed to be identical for all replicas, so
1021
- # we are just taking replica 0 value.
1022
- completed_steps = _local_state_value (
1071
+ # num_steps_completed and has_non_finite are guaranteed to be identical for
1072
+ # all replicas, so we are just taking replica 0 value.
1073
+ num_steps_completed = _local_state_value (
1023
1074
strategy , num_steps_completed )[0 ].numpy ()
1024
- step_id .assign_add (completed_steps )
1075
+ has_non_finite = _local_state_value (
1076
+ strategy , has_non_finite )[0 ].numpy ()
1077
+
1078
+ step_id .assign_add (num_steps_completed )
1025
1079
1026
1080
if SAVE_MAX_UVW_AND_CFL .value :
1027
1081
# CFL number is guaranteed to be identical for all replicas, so take
@@ -1031,7 +1085,7 @@ def write_state_and_sync(
1031
1085
* _local_state_value (strategy , state [_MAX_UVW_CFL ])[0 ].numpy ()[:, 3 ]
1032
1086
)
1033
1087
max_cfl_number_from_cycle = tf .reduce_max (cfl_values )
1034
- cfl_number_from_last_step = cfl_values [completed_steps - 1 ]
1088
+ cfl_number_from_last_step = cfl_values [num_steps_completed - 1 ]
1035
1089
logging .info (
1036
1090
'max CFL number from last cycle: %.3f. CFL number from last step:'
1037
1091
' %.3f' ,
@@ -1045,17 +1099,22 @@ def write_state_and_sync(
1045
1099
debug_output .log_variable_use ()
1046
1100
1047
1101
# Check if non-finite values were detected during the cycle.
1048
- if completed_steps < params . num_steps :
1102
+ if has_non_finite :
1049
1103
logging .info (
1050
- 'Non-convergence detected. Early exit from cycle %d at step %d.'
1051
- 'Starting dumping the last valid state at step %d' ,
1052
- cycle ,
1053
- step_id_value () + 1 ,
1054
- step_id_value (),
1055
- )
1056
- write_status = write_state_and_sync (state , step_id = step_id_value ())
1104
+ 'Non-convergence detected. Early exit from cycle %d at step %d.' ,
1105
+ cycle , step_id_value ())
1106
+ if num_steps_completed > 1 :
1107
+ write_status = write_state_and_sync (prev_state ,
1108
+ step_id = step_id_value () - 1 ,
1109
+ use_zeros_for_debug_values = True )
1110
+ logging .info (
1111
+ 'Dumping last valid state at step %d done. Write status: %s' ,
1112
+ step_id_value () - 1 , write_status )
1113
+ write_status = write_state_and_sync (state , step_id = step_id_value (),
1114
+ allow_non_finite_values = True )
1057
1115
logging .info (
1058
- 'Dumping full state done. Write status are: %s' , write_status
1116
+ 'Dumping final non-finite state done. Write status: %s' ,
1117
+ write_status
1059
1118
)
1060
1119
# Save checkpoint to update the completed step.
1061
1120
# Note: Only after the logging, which forces the `write_status` to be
@@ -1068,28 +1127,27 @@ def write_state_and_sync(
1068
1127
_write_completion_file (output_dir )
1069
1128
raise _NonRecoverableError (
1070
1129
f'Non-convergence detected. Early exit from cycle { cycle } at step '
1071
- f'{ step_id_value () + 1 } . The last valid state at step '
1072
- f'{ step_id_value ()} has been saved in the specified output path.'
1130
+ f'{ step_id_value ()} . The last valid state at step '
1131
+ f'{ step_id_value () - 1 } has been saved in the specified output path.'
1073
1132
)
1074
1133
1134
+ # Consider explicitly deleting prev_state here to free its memory because
1135
+ # after it's written to disk it is no longer needed.
1136
+
1075
1137
replica_id_values = []
1076
1138
replica_id_values .extend (_local_state_value (strategy , state ['replica_id' ]))
1077
1139
logging .info (
1078
1140
'One cycle computation is done. Replicas are: %s' ,
1079
1141
str ([v .numpy () for v in replica_id_values ]),
1080
1142
)
1081
1143
t1 = time .time ()
1082
- # "Recover" float64 precision for dt by rounding the 32 bit value to 6
1083
- # significant digits and then converting back to float. The assumption is
1084
- # that dt is user specified with 6 or less significant digits.
1085
- dt64 = float (np .format_float_positional (params .dt , 6 , fractional = False ))
1086
1144
logging .info (
1087
1145
'Completed total %d steps (%d cycles, %s simulation time) so far. '
1088
1146
'Took %s for the last cycle (%d steps).' ,
1089
1147
step_id_value (),
1090
1148
cycle + 1 ,
1091
1149
text_util .seconds_to_string (
1092
- int (step_id_value ()) * dt64 , precision = dt64
1150
+ int (step_id_value ()) * params . dt64 , precision = params . dt64
1093
1151
),
1094
1152
text_util .seconds_to_string (t1 - t0 ),
1095
1153
params .num_steps ,
@@ -1101,7 +1159,7 @@ def write_state_and_sync(
1101
1159
if (step_id_value () - params .start_step ) % checkpoint_interval == 0 :
1102
1160
write_status = write_state_and_sync (state = state , step_id = step_id_value ())
1103
1161
logging .info (
1104
- '`Post cycle writing full state done. Write status are : %s' ,
1162
+ '`Post cycle writing full state done. Write status: %s' ,
1105
1163
write_status ,
1106
1164
)
1107
1165
# Only after the logging, which forces the `write_status` to be
@@ -1118,7 +1176,7 @@ def write_state_and_sync(
1118
1176
data_dump_filter = data_dump_filter ,
1119
1177
)
1120
1178
logging .info (
1121
- '`Post cycle writing filtered state done. Write status are : %s' ,
1179
+ '`Post cycle writing filtered state done. Write status: %s' ,
1122
1180
write_status ,
1123
1181
)
1124
1182
t2 = time .time ()
0 commit comments