Skip to content

Commit 1cfd834

Browse files
The swirl_lm Authorsjohn-qingwang
The swirl_lm Authors
authored andcommitted
Code update
PiperOrigin-RevId: 677869063
1 parent eb7468a commit 1cfd834

File tree

5 files changed

+293
-14
lines changed

5 files changed

+293
-14
lines changed

swirl_lm/communication/send_recv.py

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright 2024 The swirl_lm Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A library for communicating information across replicas.
16+
17+
In an example of 4 replicas, with each replica has data with different sizes as:
18+
replica 0: data = tf.constant([])
19+
replica 1: data = tf.constant([1])
20+
replica 2: data = tf.constant([2, 2])
21+
replica 3: data = tf.constant([3, 3, 3])
22+
If data is shared in an order of 0 -> 1 -> 2 -> 3, the corresponding
23+
`source_dest_pairs` is [[0, 1], [1, 2], [2, 3]]. With a buffer size `n_max = 3`,
24+
calling `send_recv(data, source_dest_pairs, n_max)` provides the following:
25+
replica 0: tf.constant([0, 0, 0])
26+
replica 1: tf.constant([])
27+
replica 2: tf.constant([1])
28+
replica 3: tf.constant([2, 2]).
29+
30+
Note that in the example above, the `source_dest_pairs` can be obtained by
31+
calling `source_dest_pairs_along_dim(np.array([[[0]], [[1]], [[2]], [[3]]]), 0,
32+
True, False)`,
33+
or `source_dest_pairs_along_dim(np.array([[[0]], [[1]], [[2]], [[3]]])
34+
*parse_dim('+x'))`.
35+
"""
36+
37+
import re
38+
39+
import numpy as np
40+
import tensorflow as tf
41+
42+
43+
def parse_dim(dim_info: str) -> tuple[int, bool, bool]:
44+
"""Parses a dimension string into a tuple (dim, forward, periodic).
45+
46+
Args:
47+
dim_info: A string that has a structure '[-+][xyz]p?$'. The first character
48+
is '-' or '+', which indicates the negative or positive direction,
49+
respectively. The second character is one of 'x', 'y', and 'z', which
50+
corresponds to dimension 0, 1, and 2, respectively. The optional last
51+
character is 'p', which suggests the dimension is periodic if present.
52+
53+
Returns:
54+
A 3-element tuple, with the first element being the dimension, the second
55+
indicating whether the dimension is along the positive direction, and the
56+
third indicating whether the dimension is periodic.
57+
58+
Raises:
59+
ValueError if `dim_info` does not match '[-+][xyz]p?$'.
60+
"""
61+
m = re.fullmatch(r'([-+])([xyz])(p?)', dim_info)
62+
if m is None:
63+
raise ValueError(
64+
f'{dim_info} does not conform with the string structure for dimension'
65+
' info ("[-+][xyz]p?$").'
66+
)
67+
68+
dim = 'xyz'.index(m.group(2))
69+
forward = m.group(1) == '+'
70+
periodic = m.group(3) == 'p'
71+
72+
return dim, forward, periodic
73+
74+
75+
def source_dest_pairs_along_dim(
76+
replicas: np.ndarray, dim: int, forward: bool, periodic: bool
77+
) -> np.ndarray:
78+
"""Generates a 2-D array of source-target pairs along `dim` in the topology.
79+
80+
Args:
81+
replicas: A 3-D tensor representing the topology of the partitions.
82+
dim: The dimension of communication. Should be one of 0, 1, and 2.
83+
forward: A boolean argument that indicates sending data from replicas with
84+
lower indices to higher indices along the positive direction of the
85+
topology. If it is `False`, communication in performed the opposite
86+
direction, i.e. from the higher indices to lower indices.
87+
periodic: An indicator of whether the topology is periodic. When using the
88+
`source_dest_pairs` generated with this function, if `periodic` is
89+
`True`, data from the last replica along `dim` will be send to the first
90+
replica; otherwise the first replica returns all zeros with the same size
91+
as the input. The first and last replica follows the direction specified
92+
in `dim`.
93+
94+
Returns:
95+
A 2-D array of size `[num_pairs, 2]`, with the columns being the
96+
`replica_id` of the senders and the receivers, respectively.
97+
"""
98+
rolled = np.roll(replicas, -1 if forward else 1, axis=dim)
99+
trim = slice(
100+
None if periodic or forward else 1,
101+
None if periodic or not forward else -1,
102+
)
103+
stacked = np.moveaxis(np.stack([replicas, rolled]), dim + 1, 1)[:, trim]
104+
return np.reshape(stacked, (2, -1)).T
105+
106+
107+
def send_recv(
108+
data: tf.Tensor, source_dest_pairs: np.ndarray, n_max: int
109+
) -> tf.Tensor:
110+
"""Exchanges N-D `tf.Tensor`s across a list of (sender, receiver) pairs.
111+
112+
Args:
113+
data: The n-dimensional tensor to be sent to a different replica. Dimension
114+
0 of this tensor can have different sizes across replicas
115+
source_dest_pairs: A 2-D numpy array of shape `[num_replicas, 2]`, with the
116+
first column being the senders' `replica_id`, and the second one being the
117+
receiver's `replica_id`.
118+
n_max: The buffer size for the communication. It has to be greater or equal
119+
to the maximum number of `data.shape[0]` across all replicas, otherwise a
120+
runtime error will occur while padding the buffer for communication.
121+
122+
Returns:
123+
An N-D tensor received from the sender replica specified in
124+
`source_dest_pairs`.
125+
"""
126+
# Because `CollectivePermute` permits transferring data that has the same
127+
# shape across all replicas only, we need to pad the input data to satisfy
128+
# this condition.
129+
static_shape = data.get_shape()
130+
u = tf.scatter_nd(
131+
tf.range(tf.shape(data)[0])[:, tf.newaxis],
132+
data,
133+
(n_max, *static_shape[1:]),
134+
)
135+
136+
n_received = tf.raw_ops.CollectivePermute(
137+
input=tf.shape(data)[0], source_target_pairs=source_dest_pairs
138+
)
139+
w = tf.raw_ops.CollectivePermute(
140+
input=u, source_target_pairs=source_dest_pairs
141+
)
142+
# Here we trim the padded data back to its original size.
143+
return tf.gather_nd(w, tf.where(tf.range(n_max) < n_received))

swirl_lm/equations/scalars.proto

+3-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ message TotalEnergy {
126126
}
127127

128128
// Defines configurations for total humidity.
129-
// Next id: 6
129+
// Next id: 7
130130
message Humidity {
131131
// An option of whether the effect of subsidence velocity is included in the
132132
// total humidity equation.
@@ -137,6 +137,8 @@ message Humidity {
137137
// An option of whether condensation is included in the equations for
138138
// humidity.
139139
optional bool include_condensation = 3;
140+
// An option for including sedimentation in the convective flux.
141+
optional bool include_sedimentation = 6;
140142
// Specifies the option of microphysics, which will be used if
141143
// `include_precipitation` is true.
142144
oneof microphysics {

swirl_lm/equations/source_function/humidity.py

+80-11
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def __init__(
7070
self._scalar_params.HasField('humidity') and
7171
self._scalar_params.humidity.include_condensation)
7272

73+
self._include_sedimentation = (
74+
self._scalar_params.HasField('humidity') and
75+
self._scalar_params.humidity.include_sedimentation)
76+
7377
if scalar_name in ('q_r', 'q_s'):
7478
assert self._include_precipitation, (
7579
'Calculating q_r without setting include_precipitation to True is not'
@@ -80,6 +84,7 @@ def __init__(
8084
if (
8185
self._include_precipitation
8286
or self._include_condensation
87+
or (self._include_sedimentation and scalar_name == 'q_t')
8388
or scalar_name in ('q_r', 'q_s')
8489
):
8590
assert self._scalar_params.humidity.HasField('microphysics'), (
@@ -332,18 +337,82 @@ def source_fn(
332337
)
333338

334339
# Compute source terms
335-
if self._scalar_name == 'q_t' and self._include_subsidence:
336-
subsidence_source = eq_utils.source_by_subsidence_velocity(
337-
self._deriv_lib,
338-
states[common.KEY_RHO],
339-
thermo_states['zz'],
340-
thermo_states['q_c'],
341-
self._g_dim,
342-
additional_states,
343-
)
340+
if self._scalar_name == 'q_t':
341+
if self._include_subsidence:
342+
subsidence_source = eq_utils.source_by_subsidence_velocity(
343+
self._deriv_lib,
344+
states[common.KEY_RHO],
345+
thermo_states['zz'],
346+
thermo_states['q_c'],
347+
self._g_dim,
348+
additional_states,
349+
)
350+
source = tf.nest.map_structure(tf.math.add, source, subsidence_source)
344351

345-
# Add external source, e.g. sponge forcing and subsidence.
346-
source = tf.nest.map_structure(tf.math.add, source, subsidence_source)
352+
if self._include_sedimentation:
353+
assert self._microphysics is not None, (
354+
'Terminal velocity for q_t requires a microphysics model but is'
355+
' undefined.'
356+
)
357+
w_l = self._microphysics.terminal_velocity(
358+
'q_r',
359+
{'rho_thermal': states['rho_thermal'], 'q_r': thermo_states['q_l']},
360+
additional_states,
361+
)
362+
if 'q_r' in states:
363+
# The sedimentation velocity of liquid-phase cloud cannot be larger
364+
# than that of rain.
365+
w_r = self._microphysics.terminal_velocity(
366+
'q_r',
367+
{
368+
'rho_thermal': states['rho_thermal'],
369+
'q_r': states['q_r'],
370+
},
371+
additional_states,
372+
)
373+
w_l = tf.nest.map_structure(
374+
lambda w_l, w_r: tf.where(tf.greater(w_l, w_r), w_r, w_l),
375+
w_l,
376+
w_r,
377+
)
378+
w_i = self._microphysics.terminal_velocity(
379+
'q_s',
380+
{'rho_thermal': states['rho_thermal'], 'q_s': thermo_states['q_i']},
381+
additional_states,
382+
)
383+
if 'q_s' in states:
384+
# The sedimentation velocity of ice-phase cloud cannot be larger than
385+
# that of snow.
386+
w_s = self._microphysics.terminal_velocity(
387+
'q_s',
388+
{
389+
'rho_thermal': states['rho_thermal'],
390+
'q_s': states['q_s'],
391+
},
392+
additional_states,
393+
)
394+
w_i = tf.nest.map_structure(
395+
lambda w_i, w_s: tf.where(tf.greater(w_i, w_s), w_s, w_i),
396+
w_i,
397+
w_s,
398+
)
399+
sedimentation_flux_fn = lambda rho, q_l, q_i, w_l, w_i: rho * (
400+
q_l * w_l + q_i * w_i
401+
)
402+
sedimentation_flux = tf.nest.map_structure(
403+
sedimentation_flux_fn,
404+
states['rho'],
405+
thermo_states['q_l'],
406+
thermo_states['q_i'],
407+
w_l,
408+
w_i,
409+
)
410+
sedimentation_source = self._deriv_lib.deriv_centered(
411+
sedimentation_flux, self._g_dim, additional_states
412+
)
413+
source = tf.nest.map_structure(
414+
tf.math.add, source, sedimentation_source
415+
)
347416

348417
if self._include_condensation:
349418
assert isinstance(self._thermodynamics.model, water.Water), (

swirl_lm/numerics/interpolation.py

+63
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
import enum
1818
import functools
19+
import itertools
1920
from typing import Sequence, Tuple, TypeAlias
2021

22+
from swirl_lm.utility import common_ops
2123
from swirl_lm.utility import get_kernel_fn
2224
from swirl_lm.utility import types
2325
import tensorflow as tf
@@ -440,3 +442,64 @@ def muscl(r: FlowFieldVal) -> FlowFieldVal:
440442
diff1,
441443
)
442444
return v_neg, v_pos
445+
446+
447+
def trilinear_interpolation(
448+
field_data: tf.Tensor,
449+
points: tf.Tensor,
450+
grid_spacing: tf.Tensor,
451+
domain_min_pt: tuple[float, float, float] = (0.0, 0.0, 0.0),
452+
) -> tf.Tensor:
453+
"""Linear interpolation on a 3-D orthogonal, uniform grid.
454+
455+
Performs trilinear interpolation by calculating the local point coordinate
456+
indices and then interpolating the surrounding field data at the point.
457+
Points provided outside of the core domain give invalid solutions.
458+
459+
Note that coordinates in this function follow the order of the dimensions of
460+
`field_data` as a 3D tensor instead of the physical-coordinates orientation in
461+
Swirl-LM. For instance, the first element in `grid_spacing` and
462+
`domain_min_pt`, as well as the first column in `points`, are associated with
463+
the 0th dimensions of `field_data`, instead of the 'x' axis in Swirl-LM.
464+
465+
Args:
466+
field_data: A 3D tensor of field scalars without halos.
467+
points: An 2D tensor (n, 3) of n coordinate points in 3D space to
468+
interpolate at. Points must fall within the range of coordinates
469+
associated with the core. If points are outside of this domain, the
470+
function will return invalid results.
471+
grid_spacing: A three element tensor defining the grid spacing along the
472+
three dimensions.
473+
domain_min_pt: A three element tuple defining the minimum coordinate
474+
location within the entire physical domain encompassing all cores, default
475+
is (0, 0, 0).
476+
477+
Returns:
478+
An n element tensor containing interpolated data values at the n supplied
479+
points.
480+
"""
481+
core_spacing = grid_spacing * (
482+
tf.cast(field_data.shape, dtype=tf.float32) - 1
483+
)
484+
485+
# Normalizes points to be within the range (0, 0, 0,) to (nx, ny, nz) for nx,
486+
# ny, and nz nodes in the core.
487+
points_norm = (
488+
(points - tf.constant(domain_min_pt)) % core_spacing
489+
) / grid_spacing
490+
ijk = tf.floor(points_norm)
491+
points_norm -= ijk
492+
ijk = tf.cast(ijk, dtype=tf.int32)
493+
i, j, k = ijk[:, 0], ijk[:, 1], ijk[:, 2]
494+
x0, x1, x2 = points_norm[:, 0], points_norm[:, 1], points_norm[:, 2]
495+
496+
values = tf.zeros(points.shape[0], dtype=field_data.dtype)
497+
for p, q, l in itertools.product(range(2), range(2), range(2)):
498+
v = common_ops.gather(field_data, tf.stack([i + p, j + q, k + l], axis=-1))
499+
values += v * (
500+
((1 - p) + (2 * p - 1) * x0)
501+
* ((1 - q) + (2 * q - 1) * x1)
502+
* ((1 - l) + (2 * l - 1) * x2)
503+
)
504+
505+
return values

swirl_lm/physics/combustion/wood.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,11 @@ def _radiative_emission(
227227
emissivity (< 1).
228228
229229
Returns:
230-
The radiation source term due to emission.
230+
The radiation source term due to emission. If `t` is less than `t_ambient`,
231+
the radiation term is 0, i.e., radiation energy can only be lost to
232+
ambient conditions.
231233
"""
232-
return _SIGMA * k / l * (t**4 - t_ambient**4)
234+
return tf.maximum(_SIGMA * k / l * (t**4 - t_ambient**4), 0.0)
233235

234236

235237
def _evaporation(

0 commit comments

Comments
 (0)