Commit ae0cf63

The swirl_lm Authors authored and john-qingwang committed
Code update
PiperOrigin-RevId: 699277789
1 parent 1cfd834 commit ae0cf63

File tree: 86 files changed, +4520 −981 lines


swirl_lm/base/driver.py

+113 −55
@@ -28,6 +28,8 @@
 from swirl_lm.boundary_condition import nonreflecting_boundary
 from swirl_lm.core import simulation
 from swirl_lm.linalg import poisson_solver
+from swirl_lm.physics.lpt import lpt
+from swirl_lm.physics.lpt import lpt_manager
 from swirl_lm.physics.radiation import rrtmgp_common
 from swirl_lm.utility import common_ops
 from swirl_lm.utility import debug_output
@@ -94,6 +96,7 @@
 CKPT_DIR_FORMAT = '{filename_prefix}-ckpts/'
 COMPLETION_FILE = 'DONE'
 _MAX_UVW_CFL = 'max_uvw_cfl'
+SIMULATION_TIME = 'simulation_time'
 
 Array: TypeAlias = Any
 PerReplica: TypeAlias = tf.types.experimental.distributed.PerReplica
@@ -191,6 +194,9 @@ def _get_state_keys(params: parameters_lib.SwirlLMParameters):
   # Add additional keys required by the radiative transfer library.
   additional_keys += rrtmgp_common.required_keys(params.radiative_transfer)
 
+  # Add additional keys required by the lagrangian particle tracking library.
+  additional_keys += lpt.required_keys(params.lpt)
+
   # Check to make sure we don't have keys duplicating / overwriting each other.
   if len(set(essential_keys)) + len(set(additional_keys)) + len(
       set(helper_var_keys)
@@ -250,6 +256,11 @@ def init_fn(
         (params.num_steps, 4), dtype=types.TF_DTYPE
     )
 
+    states[SIMULATION_TIME] = tf.convert_to_tensor(0, dtype=tf.float64)
+
+    if params.lpt is not None:
+      states.update(lpt.init_fn(params))
+
     # Apply the user defined `init_fn` in the end to allow it to override
     # default initializations.
     if customized_init_fn is not None:
@@ -348,6 +359,20 @@ def _update_additional_states(
           )
       )
 
+  # Update lagrangian particle additional states.
+  with tf.name_scope('lpt_additional_states_update'):
+    lpt_field = lpt_manager.lpt_factory(params)
+    if lpt_field is not None:
+      updated_additional_states.update(
+          lpt_field.step(
+              replica_id=common_kwargs['replica_id'],
+              replicas=common_kwargs['replicas'],
+              states=essential_states,
+              additional_states=additional_states,
+              step_id=step_id,
+          ),
+      )
+
   if params.additional_states_update_fn is not None:
     with tf.name_scope('additional_states_update'):
       updated_additional_states = params.additional_states_update_fn(
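
The hook above stays a no-op for existing configurations because `lpt_manager.lpt_factory` returns None when no Lagrangian particle tracking is configured. A minimal sketch of that factory pattern, not part of this commit; `ParticleTracker` and `lpt_factory_sketch` are hypothetical stand-ins, not the swirl_lm lpt API:

# Sketch only: illustrates the "factory returns None when disabled" pattern.
from typing import Any, Optional


class ParticleTracker:
  """Placeholder for an lpt field object exposing a `step` method."""

  def step(self, replica_id, replicas, states, additional_states, step_id):
    # A real implementation would advect particles; this one just passes the
    # additional states through unchanged.
    return dict(additional_states)


def lpt_factory_sketch(params: Any) -> Optional[ParticleTracker]:
  # Returning None when the feature is not configured lets the caller skip
  # the update without branching on configuration details.
  if getattr(params, 'lpt', None) is None:
    return None
  return ParticleTracker()


print(lpt_factory_sketch(object()) is None)  # True: no `lpt` config present.
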
@@ -362,21 +387,24 @@
 
 def _state_has_nan_inf(state: dict[str, PerReplica], replicas: Array) -> bool:
   """Checks whether any field in the `state` contains `nan` or `inf`."""
-  has_nan_inf = False
+  local_has_nan_inf = False
   for _, v in state.items():
-    # We want to make sure every core is seeing the same problem so the
-    # termination of all cores is synchronized, thus a global reduce operation
-    # is needed.
-    if common_ops.global_reduce(
-        tf.math.logical_or(tf.math.is_nan(v), tf.math.is_inf(v)),
-        tf.reduce_any,
-        replicas.reshape([1, -1]),
-    ):
-      has_nan_inf = True
-    # For some reason, the graph tracing in this case doesn't allow early break
-    # of the loop. For now we will just check through all fields without early
-    # break. This should still be very efficient.
-  return has_nan_inf
+    if (v.dtype.is_floating and tf.reduce_any(
+        tf.math.logical_or(tf.math.is_nan(v), tf.math.is_inf(v)))):
+      local_has_nan_inf = True
+    # Graph compilation doesn't allow early break.
+
+  # We want to make sure every core is seeing the same problem so the
+  # termination of all cores is synchronized, thus a global reduce operation
+  # is needed.
+  if common_ops.global_reduce(
+      tf.convert_to_tensor(local_has_nan_inf),
+      tf.reduce_any,
+      replicas.reshape([1, -1]),
+  ):
+    return True
+  else:
+    return False
 
 
 def _compute_max_uvw_and_cfl(
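
The rewritten check evaluates NaN/Inf locally per field (skipping non-floating tensors, for which `tf.math.is_nan` is undefined) and then performs a single collective reduction, rather than one `global_reduce` per field. A self-contained sketch of the same structure, not part of this commit; the cross-replica OR is emulated with a stacked tensor, since `common_ops.global_reduce` itself is outside this diff:

# Sketch only: local per-field check followed by a single "global" OR.
import tensorflow as tf


def local_has_nan_inf(state):
  flag = tf.constant(False)
  for v in state.values():
    if v.dtype.is_floating:
      bad = tf.math.logical_or(tf.math.is_nan(v), tf.math.is_inf(v))
      flag = tf.logical_or(flag, tf.reduce_any(bad))
  return flag


# One boolean per replica; tf.reduce_any stands in for the collective reduce.
per_replica_flags = tf.stack([
    local_has_nan_inf({'u': tf.constant([1.0, 2.0]), 'step': tf.constant(3)}),
    local_has_nan_inf({'u': tf.constant([float('inf'), 0.0])}),
])
print(tf.reduce_any(per_replica_flags).numpy())  # True: every core terminates.
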
@@ -437,7 +465,8 @@ def _one_cycle(
     num_steps: Array,
     params: Union[parameters_lib.SwirlLMParameters, Any],
     model: Any,
-) -> tuple[dict[str, PerReplica], dict[str, PerReplica]]:
+) -> tuple[dict[str, PerReplica], PerReplica, dict[str, PerReplica],
+           PerReplica]:
   """Runs one cycle of the Navier-Stokes solver.
 
   Args:
@@ -452,8 +481,12 @@
       defined.
 
   Returns:
-    A 2-tuple of the final state at the end of the cycle and the completed
-    number of steps, both will be in the dict[str, PerReplica] format.
+    A (final_state, completed_steps, previous_state, has_non_finite) where
+    final_state is the state at the end of the cycle, completed_steps is the
+    number of steps completed in this cycle, previous_state is the state one
+    step before the final_state (which is useful in blow-ups), and
+    has_non_finite is a boolean indicating whether the final_state contains
+    non-finite values.
   """
   logging.info(
       'Tracing and compiling of _one_cycle starts. This can take up to 30 min.'
@@ -500,6 +533,8 @@ def step_fn(state):
       state[_MAX_UVW_CFL] = tf.zeros_like(state[_MAX_UVW_CFL])
 
     cycle_step_id = 0
+    prev_state = state
+    has_non_finite = False
    for _ in tf.range(num_steps):
      step_id = init_step_id + cycle_step_id
      # Split the state into essential states and additional states. Note that
@@ -596,22 +631,29 @@ def step_fn(state):
           ),
       )
 
-      if SAVE_LAST_VALID_STEP.value:
-        if _state_has_nan_inf(updated_state, logical_replicas):
-          # Detected nan/inf, skip the update of state by early-exiting from the
-          # for loop.
-          break
+      # Simulation time will accumulate precision errors with this approach.
+      # For now, the converter deals with this by rounding off to a fewer
+      # number of significant digits.
+      updated_state[SIMULATION_TIME] = state[SIMULATION_TIME] + params.dt64
+
+      prev_state = state
       # Some state keys such as `replica_id` may not lie in either of the three
       # categories. Just pass them through.
       state = _stateless_update_if_present(state, updated_state)
       cycle_step_id += 1
+      if SAVE_LAST_VALID_STEP.value:
+        if _state_has_nan_inf(state, logical_replicas):
+          # Detected nan/inf, skip the update of state by early-exiting from the
+          # for loop.
+          has_non_finite = True
+          break
 
     if not params.use_3d_tf_tensor:
       # Unsplit the keys that were previously split.
       for key in keys_to_split:
         state[key] = tf.stack(state[key])
 
-    return state, cycle_step_id
+    return state, cycle_step_id, prev_state, has_non_finite
 
   return strategy.run(step_fn, args=(init_state,))
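
The comment added above refers to the rounding error that builds up when `dt` is added once per step; keeping `SIMULATION_TIME` in float64 and adding `params.dt64` (rather than the float32 `dt`) keeps that drift small, and the converter rounds the result to fewer significant digits. A small illustration of the drift, not taken from swirl_lm:

# Sketch only: step-by-step accumulation in float32 vs. float64.
import numpy as np

dt, n = 1e-3, 100_000  # nominal end time: 100.0

t32 = np.float32(0.0)
t64 = np.float64(0.0)
for _ in range(n):
  t32 += np.float32(dt)  # rounding error at every addition
  t64 += np.float64(dt)

print(float(t32) - 100.0)  # noticeable drift
print(float(t64) - 100.0)  # orders of magnitude smaller
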

@@ -740,10 +782,10 @@ def solver(
     num_steps, 2 * num_steps, ..., num_cycles * num_steps].
 
   If the simulation reaches a state where any variable has a non-finite value
-  (NaN or Inf), then the simulation will stop early and save the state one
-  step before the state that contains non-finite values. As a result, there
-  will be fewer steps saved than `num_cycles + ` and the final saved step
-  number will not necessarily be a multiple of num_steps.
+  (NaN or Inf), then the simulation will stop early and save both the
+  non-finite state and the one before it. As a result, there will most likely
+  be fewer steps saved than `num_cycles + 1` and the final saved step number
+  will not necessarily be a multiple of num_steps.
 
   In both of the these cases (`num_cycles` reached or non-finite value seen),
   the solver will write an empty `DONE` file to the output directory to
@@ -866,9 +908,19 @@ def write_state_and_sync(
       state: dict[str, PerReplica],
       step_id: Array,
       data_dump_filter: Optional[Sequence[str]] = None,
+      allow_non_finite_values: bool = False,
+      use_zeros_for_debug_values: bool = False,
   ):
-    debug_vars = debug_output.get_vars(strategy, state.keys())
+    if use_zeros_for_debug_values:
+      debug_vars = debug_output.zeros_like_vars(strategy, state.keys())
+    else:
+      debug_vars = debug_output.get_vars(strategy, state.keys())
+
     write_state = dict(state) | debug_vars
+    if allow_non_finite_values:
+      fields_allowing_non_finite_values = list(write_state.keys())
+    else:
+      fields_allowing_non_finite_values = list(debug_vars.keys())
     write_status = driver_tpu.distributed_write_state(
         strategy,
         _local_state_dict(strategy, write_state),
@@ -877,7 +929,7 @@
         filename_prefix=filename_prefix,
         step_id=step_id,
         data_dump_filter=data_dump_filter,
-        fields_allowing_non_finite_values=list(debug_vars.keys()),
+        fields_allowing_non_finite_values=fields_allowing_non_finite_values,
     )
 
     # This will block until all replicas are done writing.
@@ -957,7 +1009,7 @@ def write_state_and_sync(
     write_status = write_state_and_sync(state=state, step_id=step_id_value())
     logging.info(
         '`restoring-checkpoint-if-necessary` stage '
-        'done with writing initial steps. Write status are: %s',
+        'done with writing initial steps. Write status: %s',
         write_status,
     )
     # Only after the logging, which forces the `write_status` to be
@@ -1008,20 +1060,22 @@ def write_state_and_sync(
     cycle = (step_id_value() - params.start_step) // params.num_steps
     logging.info('Step %d (cycle %d) is starting.', step_id_value(), cycle)
     t0 = time.time()
-    state, num_steps_completed = _one_cycle(
+    state, num_steps_completed, prev_state, has_non_finite = _one_cycle(
         strategy=strategy,
         init_state=state,
         init_step_id=step_id_value(),
         num_steps=params.num_steps,
         params=params,
         model=model,
     )
-
-    # Completed number steps are guaranteed to be identical for all replicas, so
-    # we are just taking replica 0 value.
-    completed_steps = _local_state_value(
+    # num_steps_completed and has_non_finite are guaranteed to be identical for
+    # all replicas, so we are just taking replica 0 value.
+    num_steps_completed = _local_state_value(
        strategy, num_steps_completed)[0].numpy()
-    step_id.assign_add(completed_steps)
+    has_non_finite = _local_state_value(
+        strategy, has_non_finite)[0].numpy()
+
+    step_id.assign_add(num_steps_completed)
 
     if SAVE_MAX_UVW_AND_CFL.value:
       # CFL number is guaranteed to be identical for all replicas, so take
@@ -1031,7 +1085,7 @@ def write_state_and_sync(
           * _local_state_value(strategy, state[_MAX_UVW_CFL])[0].numpy()[:, 3]
       )
       max_cfl_number_from_cycle = tf.reduce_max(cfl_values)
-      cfl_number_from_last_step = cfl_values[completed_steps - 1]
+      cfl_number_from_last_step = cfl_values[num_steps_completed - 1]
      logging.info(
          'max CFL number from last cycle: %.3f. CFL number from last step:'
          ' %.3f',
@@ -1045,17 +1099,22 @@ def write_state_and_sync(
       debug_output.log_variable_use()
 
     # Check if we did not complete a full cycle.
-    if completed_steps < params.num_steps:
+    if has_non_finite:
       logging.info(
-          'Non-convergence detected. Early exit from cycle %d at step %d.'
-          'Starting dumping the last valid state at step %d',
-          cycle,
-          step_id_value() + 1,
-          step_id_value(),
-      )
-      write_status = write_state_and_sync(state, step_id=step_id_value())
+          'Non-convergence detected. Early exit from cycle %d at step %d.',
+          cycle, step_id_value())
+      if num_steps_completed > 1:
+        write_status = write_state_and_sync(prev_state,
+                                            step_id=step_id_value() - 1,
+                                            use_zeros_for_debug_values=True)
+        logging.info(
+            'Dumping last valid state at step %d done. Write status: %s',
+            step_id_value() - 1, write_status)
+      write_status = write_state_and_sync(state, step_id=step_id_value(),
+                                          allow_non_finite_values=True)
       logging.info(
-          'Dumping full state done. Write status are: %s', write_status
+          'Dumping final non-finite state done. Write status: %s',
+          write_status
       )
       # Save checkpoint to update the completed step.
       # Note: Only after the logging, which forces the `write_status` to be
@@ -1068,28 +1127,27 @@ def write_state_and_sync(
       _write_completion_file(output_dir)
       raise _NonRecoverableError(
           f'Non-convergence detected. Early exit from cycle {cycle} at step '
-          f'{step_id_value() + 1}. The last valid state at step '
-          f'{step_id_value()} has been saved in the specified output path.'
+          f'{step_id_value()}. The last valid state at step '
+          f'{step_id_value() - 1} has been saved in the specified output path.'
      )
 
+    # Consider explicitly deleting prev_state here to free its memory because
+    # after its written to disk it is no longer needed.
+
    replica_id_values = []
    replica_id_values.extend(_local_state_value(strategy, state['replica_id']))
    logging.info(
        'One cycle computation is done. Replicas are: %s',
        str([v.numpy() for v in replica_id_values]),
    )
    t1 = time.time()
-    # "Recover" float64 precision for dt by rounding the 32 bit value to 6
-    # significant digits and then converting back to float. The assumption is
-    # that dt is user specified with 6 or less significant digits.
-    dt64 = float(np.format_float_positional(params.dt, 6, fractional=False))
    logging.info(
        'Completed total %d steps (%d cycles, %s simulation time) so far. '
        'Took %s for the last cycle (%d steps).',
        step_id_value(),
        cycle + 1,
        text_util.seconds_to_string(
-            int(step_id_value()) * dt64, precision=dt64
+            int(step_id_value()) * params.dt64, precision=params.dt64
        ),
        text_util.seconds_to_string(t1 - t0),
        params.num_steps,
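
The deleted lines above show the trick that presumably now lives behind `params.dt64`: the float32 `dt` is printed with at most six significant digits and re-parsed, recovering the value the user actually wrote (the removed comment states that assumption). A standalone illustration, not part of this commit:

# Sketch only: recovering a "clean" float64 dt from a float32 value.
import numpy as np

dt32 = np.float32(0.0003)
print(float(dt32))  # not exactly 0.0003 once widened to float64
dt64 = float(np.format_float_positional(dt32, 6, fractional=False))
print(dt64)  # 0.0003
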
@@ -1101,7 +1159,7 @@ def write_state_and_sync(
    if (step_id_value() - params.start_step) % checkpoint_interval == 0:
      write_status = write_state_and_sync(state=state, step_id=step_id_value())
      logging.info(
-          '`Post cycle writing full state done. Write status are: %s',
+          '`Post cycle writing full state done. Write status: %s',
          write_status,
      )
      # Only after the logging, which forces the `write_status` to be
@@ -1118,7 +1176,7 @@ def write_state_and_sync(
          data_dump_filter=data_dump_filter,
      )
      logging.info(
-          '`Post cycle writing filtered state done. Write status are: %s',
+          '`Post cycle writing filtered state done. Write status: %s',
          write_status,
      )
    t2 = time.time()

swirl_lm/base/initializer.py

+9 −4
@@ -167,10 +167,15 @@ def get_slice_in_dim(core_n, length, num_cores, core_id, provided_mesh):
         'smaller than total number of core partitioning in z direction.')
 
   if pad_mode == 'PHYSICAL':
-    xs, ys, zs = [
-        params.grid_local_with_coord(coordinate, dim, True)
-        for dim in range(3)
-    ]
+    xs = common_ops.get_local_slice_of_1d_array(
+        params.global_xyz_with_halos[0], coordinate[0], core_nx, nx
+    )
+    ys = common_ops.get_local_slice_of_1d_array(
+        params.global_xyz_with_halos[1], coordinate[1], core_ny, ny
+    )
+    zs = common_ops.get_local_slice_of_1d_array(
+        params.global_xyz_with_halos[2], coordinate[2], core_nz, nz
+    )
   else:
     xs = get_slice_in_dim(core_nx, lx, cx, gx, params.x)
     ys = get_slice_in_dim(core_ny, ly, cy, gy, params.y)
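
The call sites above suggest `common_ops.get_local_slice_of_1d_array` takes the global 1D coordinate array with halos, the core's coordinate along that dimension, the per-core interior point count, and the local size; its actual implementation is not part of this diff. A purely hypothetical sketch of such a slice, for intuition only:

# Hypothetical sketch only; the real common_ops helper may differ.
import numpy as np

HALO = 2  # assumed halo width for this illustration


def local_slice_sketch(global_1d_with_halos, core_idx, core_n, n_local):
  # Each core owns `core_n` interior points; the local array of length
  # `n_local` keeps halo points on both sides of that interior block.
  start = core_idx * core_n
  return global_1d_with_halos[start:start + n_local]


xs_global = np.linspace(0.0, 1.0, 2 * 8 + 2 * HALO)  # 2 cores, 8 points each
print(local_slice_sketch(xs_global, core_idx=1, core_n=8, n_local=8 + 2 * HALO))
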
