Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 75 additions & 13 deletions jax_cfd/base/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ def stable_time_step(viscosity: float, grid: grids.Grid) -> float:
return dx ** 2 / (viscosity * 2 ** ndim)


def _subtract_linear_part_dirichlet(
def _add_or_subtract_linear_part_dirichlet(
c_data: Array,
grid: grids.Grid,
axis: int,
subtract: bool,
offset: Tuple[float, float],
bc_values: Tuple[float, float],
) -> Array:
Expand All @@ -74,6 +75,7 @@ def _subtract_linear_part_dirichlet(
c_data: right-hand-side of diffusion equation.
grid: grid object
axis: axis along which to impose boundary transformation
subtract: if the linear part needs to be added or subtracted.
offset: offset of the right-hand-side
bc_values: boundary values along axis

Expand All @@ -82,15 +84,16 @@ def _subtract_linear_part_dirichlet(
"""

def _update_rhs_along_axis(arr_1d, linear_part):
arr_1d = arr_1d - linear_part
if subtract:
linear_part = - linear_part
arr_1d = arr_1d + linear_part
return arr_1d

lower_value, upper_value = bc_values
y = grid.mesh(offset)[axis][0]
one_d_grid = grids.Grid((grid.shape[axis],), domain=(grid.domain[axis],))
y_boundary = boundaries.dirichlet_boundary_conditions(ndim=1)
y = y_boundary.trim_boundary(grids.GridArray(y, (offset[axis],),
one_d_grid)).data
if grid.ndim == 3:
y = y[0]
# TODO(ayyaalieva): coincide the two
domain_length = (grid.domain[axis][1] - grid.domain[axis][0])
domain_start = grid.domain[axis][0]
linear_part = lower_value + (upper_value - lower_value) * (
Expand All @@ -100,10 +103,67 @@ def _update_rhs_along_axis(arr_1d, linear_part):
return c_data


def _rhs_transform(
def _subtract_linear_part_dirichlet(
c_data: Array,
grid: grids.Grid,
axis: int,
offset: Tuple[float, float],
bc_values: Tuple[float, float],
) -> Array:
"""Wrapper for _add_or_subtract_linear_part_dirichlet.

Args:
c_data: right-hand-side of diffusion equation.
grid: grid object
axis: axis along which to impose boundary transformation
offset: offset of the right-hand-side
bc_values: boundary values along axis

Returns:
transformed right-hand-side
"""
return _add_or_subtract_linear_part_dirichlet(
c_data,
grid,
axis,
True,
offset,
bc_values)


def _add_linear_part_dirichlet(
c_data: Array,
grid: grids.Grid,
axis: int,
offset: Tuple[float, float],
bc_values: Tuple[float, float],
) -> Array:
"""Wrapper for _add_or_subtract_linear_part_dirichlet.

Args:
c_data: right-hand-side of diffusion equation.
grid: grid object
axis: axis along which to impose boundary transformation
offset: offset of the right-hand-side
bc_values: boundary values along axis

Returns:
transformed right-hand-side
"""
return _add_or_subtract_linear_part_dirichlet(
c_data,
grid,
axis,
False,
offset,
bc_values)


def rhs_transform(
u: grids.GridArray,
bc: boundaries.BoundaryConditions,
) -> Array:
subtract_linear_part: bool = True,
) -> grids.GridArray:
"""Transforms the RHS of diffusion equation.

In case of constant dirichlet boundary conditions for heat equation
Expand All @@ -112,6 +172,7 @@ def _rhs_transform(
Args:
u: a GridArray that solves ∇²x = ∇²u for x.
bc: specifies boundary of u.
subtract_linear_part: if True, linear part is subtracted from rhs.

Returns:
u' s.t. u = u' + w where u' has 0 dirichlet bc and w is linear.
Expand All @@ -125,13 +186,14 @@ def _rhs_transform(
if bc.types[axis][i] == boundaries.BCType.DIRICHLET:
bc_values = [0., 0.]
bc_values[i] = bc.bc_values[axis][i]
u_data = _subtract_linear_part_dirichlet(u_data, u.grid, axis, u.offset,
bc_values)
u_data = _add_or_subtract_linear_part_dirichlet(u_data, u.grid, axis,
subtract_linear_part,
u.offset, bc_values)
elif bc.types[axis][i] == boundaries.BCType.NEUMANN:
if any(bc.bc_values[axis]):
raise NotImplementedError(
'transformation is not implemented for inhomogeneous Neumann bc.')
return u_data
return grids.GridArray(u_data, u.offset, u.grid)


def solve_cg(v: GridVariableVector,
Expand Down Expand Up @@ -182,7 +244,7 @@ def func(x):
# If dirichlet bc are supplied: only works for dirichlet bc that are linear
# functions on the boundary. Then u = u' + w where u' has 0 dirichlet bc and
# w is linear. Then u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u = u +
# (1 - ν Δt ∇²)⁻¹(ν Δt ∇²)u'. The function _rhs_transform subtracts
# (1 - ν Δt ∇²)⁻¹(ν Δt ∇²)u'. The function rhs_transform subtracts
# the linear part s.t. fast_diagonalization solves
# u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u'.
v_diffused = list()
Expand All @@ -203,7 +265,7 @@ def func(x):
circulant=circulant,
implementation=implementation)
u_interior = u.bc.trim_boundary(u.array)
u_interior_transformed = _rhs_transform(u_interior, u.bc)
u_interior_transformed = rhs_transform(u_interior, u.bc)
u_dt_diffused = grids.GridArray(
op(u_interior_transformed), u_interior.offset, u_interior.grid)
u_diffused = u_interior + u_dt_diffused
Expand Down
7 changes: 5 additions & 2 deletions jax_cfd/base/finite_differences.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ def forward_difference(u, axis=None):
def laplacian(u: GridVariable) -> GridArray:
"""Approximates the Laplacian of `u`."""
scales = np.square(1 / np.array(u.grid.step, dtype=u.dtype))
result = -2 * u.array * np.sum(scales)
bc = u.bc
u = u.trim_boundary()
result = u * 0
for axis in range(u.grid.ndim):
result += stencil_sum(u.shift(-1, axis), u.shift(+1, axis)) * scales[axis]
result += stencil_sum(bc.shift(u, -1, axis), -2 * u, bc.shift(
u, + 1, axis)) * scales[axis]
return result


Expand Down
146 changes: 99 additions & 47 deletions jax_cfd/base/subgrid_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@

import jax
from jax_cfd.base import boundaries
from jax_cfd.base import diffusion
from jax_cfd.base import equations
from jax_cfd.base import finite_differences
from jax_cfd.base import forcings
from jax_cfd.base import grids
from jax_cfd.base import interpolation
import numpy as np


GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
Expand All @@ -37,6 +37,44 @@
# TODO(pnorgaard) Refactor subgrid_models to interpolate, then differentiate


def centered_strain_rate_tensor(
v: grids.GridVariableVector,
) -> grids.GridArrayTensor:
"""Computes centered strain rate tensor.

Args:
v: velocity field.

Returns:
centered strain rate tensor.
"""

grid = grids.consistent_grid(*v)
v_centered = tuple(interpolation.linear(u, grid.cell_center) for u in v)

def make_strain_rate_bc(v):
# viscocity vanishes at the boundary.
types = []
for ax in range(grid.ndim):
if v[0].bc.types[ax][0] == boundaries.BCType.PERIODIC:
types.append((boundaries.BCType.PERIODIC, boundaries.BCType.PERIODIC))
elif v[0].bc.types[ax][0] == boundaries.BCType.DIRICHLET and v[
0].bc.types[ax][1] == boundaries.BCType.DIRICHLET:
types.append((boundaries.BCType.DIRICHLET, boundaries.BCType.DIRICHLET))
else:
raise ValueError(
f'boundary condition {v[0].bc.types[ax]} is not implemented')
return boundaries.HomogeneousBoundaryConditions(types)

strain_rate_bc = make_strain_rate_bc(v)
s_ij = grids.GridArrayTensor([[ # pylint: disable=g-complex-comprehension
strain_rate_bc.impose_bc(
0.5 * (finite_differences.central_difference(v_centered[i], j) +
finite_differences.central_difference(v_centered[j], i)))
for j in range(grid.ndim)] for i in range(grid.ndim)])
return s_ij


def smagorinsky_viscosity(
s_ij: grids.GridArrayTensor,
v: GridVariableVector,
Expand Down Expand Up @@ -73,35 +111,28 @@ def smagorinsky_viscosity(
# velocity and then computing s_ij via finite differences, producing
# a `GridVariableTensor`. Then no wrapper or GridArray/GridVariable
# conversion hacks are needed.
if not boundaries.has_all_periodic_boundary_conditions(*v):
raise ValueError('smagorinsky_viscosity only valid for periodic BC.')
bc = grids.unique_boundary_conditions(*v)

def wrapped_interp_fn(c, offset, v, dt):
return interpolate_fn(grids.GridVariable(c, bc), offset, v, dt).array

del dt, interpolate_fn
grid = grids.consistent_grid(*s_ij.ravel(), *v)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
s_ij_offsets = [array.offset for array in s_ij.ravel()]
unique_offsets = list(set(s_ij_offsets))
cell_center = grid.cell_center
interpolate_to_center = lambda x: wrapped_interp_fn(x, cell_center, v, dt)
centered_s_ij = np.vectorize(interpolate_to_center)(s_ij)
unique_offsets = list(
set([tuple(abs(o) for o in offset) for offset in unique_offsets]))
if len(unique_offsets) > 1 or unique_offsets[0] != grid.cell_center:
raise ValueError('This function requires cell-centered strain rate tensor.')
# geometric average
cutoff = np.prod(np.array(grid.step))**(1 / grid.ndim)
viscosity = (cs * cutoff)**2 * np.sqrt(
2 * np.trace(centered_s_ij.dot(centered_s_ij)))
viscosities_dict = {
offset: wrapped_interp_fn(viscosity, offset, v, dt).data
for offset in unique_offsets}
viscosities = [viscosities_dict[offset] for offset in s_ij_offsets]
return jax.tree_unflatten(jax.tree_util.tree_structure(s_ij), viscosities)
s_ij_array = grids.GridArrayTensor(
[[s_ij[i, j].array for j in range(grid.ndim)] for i in range(grid.ndim)])
s_abs = np.sqrt(
2 * np.trace(s_ij_array.dot(s_ij_array)))
viscosity = (cs * cutoff)**2
return viscosity * s_abs


def evm_model(
v: GridVariableVector,
viscosity_fn: ViscosityFn,
) -> GridArrayVector:
) -> GridVariableVector:
"""Computes acceleration due to eddy viscosity turbulence model.

Eddy viscosity models compute a turbulence closure term as a divergence of
Expand All @@ -120,18 +151,21 @@ def evm_model(
if not boundaries.has_all_periodic_boundary_conditions(*v):
raise ValueError('evm_model only valid for periodic BC.')
grid = grids.consistent_grid(*v)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
s_ij = grids.GridArrayTensor([
[0.5 * (finite_differences.forward_difference(v[i], j) + # pylint: disable=g-complex-comprehension
finite_differences.forward_difference(v[j], i))
for j in range(grid.ndim)]
for i in range(grid.ndim)])
s_ij = centered_strain_rate_tensor(v)
viscosity = viscosity_fn(s_ij, v)
tau = jax.tree_map(lambda x, y: -2. * x * y, viscosity, s_ij)
return tuple(-finite_differences.divergence( # pylint: disable=g-complex-comprehension
tuple(grids.GridVariable(t, bc) # use velocity bc to compute diverence
for t in tau[i, :]))
for i in range(grid.ndim))
tau = grids.GridArrayTensor([
[s.bc.impose_bc(2 * viscosity * s.array) for s in s_ij[i]]
for i in range(grid.ndim)])
strain_rate_div = tuple(-finite_differences.centered_divergence(tau[i])
for i in range(grid.ndim))
homogeneous_bc = lambda u: boundaries.HomogeneousBoundaryConditions(u.bc.types # pylint: disable=g-long-lambda
)
strain_rate_div = tuple(
homogeneous_bc(u).impose_bc(strain_div)
for u, strain_div in zip(v, strain_rate_div))
return tuple(
interpolation.linear(strain_div, u.offset)
for strain_div, u in zip(strain_rate_div, v))


# TODO(dkochkov) remove when b/160947162 is resolved.
Expand All @@ -140,7 +174,7 @@ def implicit_evm_solve_with_diffusion(
viscosity: float,
dt: float,
configured_evm_model: Callable, # pylint: disable=g-bare-generic
cg_kwargs: Optional[Mapping[str, Any]] = None
cg_kwargs: Optional[Mapping[str, Any]] = None,
) -> GridVariableVector:
"""Implicit solve for eddy viscosity model combined with diffusion.

Expand All @@ -162,27 +196,42 @@ def implicit_evm_solve_with_diffusion(
cg_kwargs = dict(cg_kwargs)
cg_kwargs.setdefault('tol', 1e-6)
cg_kwargs.setdefault('atol', 1e-6)

if not boundaries.has_all_periodic_boundary_conditions(*v):
raise ValueError(
'implicit_evm_solve_with_diffusion only valid for periodic BC.')
bc = grids.unique_boundary_conditions(*v)
vector_laplacian = np.vectorize(finite_differences.laplacian)

bc_vals = tuple(u.bc for u in v)
offset_vals = tuple(u.offset for u in v)
vector_laplacian = finite_differences.laplacian
homogeneous_bc = tuple(
boundaries.HomogeneousBoundaryConditions(bc_val.types)
for bc_val in bc_vals)
v_with_homogeneous_bc_array = tuple(
diffusion.rhs_transform(u.array, bc_val, True)
for u, bc_val in zip(v, bc_vals))
v_with_homogeneous_bc = tuple(
bc.impose_bc(u)
for u, bc in zip(v_with_homogeneous_bc_array, homogeneous_bc))
# the arg v from the outer function.
def linear_op(velocity):
v_var = tuple(grids.GridVariable(u, bc) for u in velocity)
v_var = tuple(
bc.pad_and_impose_bc(u, offset)
for u, bc, offset in zip(velocity, homogeneous_bc, offset_vals))
acceleration = configured_evm_model(v_var)
return tuple(
velocity - dt * (acceleration + viscosity * vector_laplacian(v_var)))
return tuple(v.trim_boundary() - dt *
(a.trim_boundary() + viscosity * vector_laplacian(v))
for v, a in zip(v_var, acceleration))

# We normally prefer fast diagonalization, but that requires an outer
# product structure for the linear operation, which doesn't hold here.
# TODO(shoyer): consider adding a preconditioner
v_prime, _ = jax.scipy.sparse.linalg.cg(linear_op, tuple(u.array for u in v),
**cg_kwargs)
v_prime, _ = jax.scipy.sparse.linalg.cg(
linear_op, tuple(u.trim_boundary() for u in v_with_homogeneous_bc),
**cg_kwargs)
v_prime = tuple(
interpolation.linear(bc.pad_and_impose_bc(u_prime), o)
for u_prime, bc, o in zip(v_prime, homogeneous_bc, offset_vals))
v_prime_with_constant_bc_array = tuple(
diffusion.rhs_transform(u.array, bc_val, False)
for u, bc_val in zip(v_prime, bc_vals))
return tuple(
grids.GridVariable(u_prime, u.bc) for u_prime, u in zip(v_prime, v))
bc.impose_bc(u) for u, bc in zip(v_prime_with_constant_bc_array, bc_vals))


def explicit_smagorinsky_navier_stokes(dt, cs, forcing, **kwargs):
Expand All @@ -206,10 +255,13 @@ def explicit_smagorinsky_navier_stokes(dt, cs, forcing, **kwargs):
smagorinsky_viscosity, dt=dt, cs=cs)
smagorinsky_acceleration = functools.partial(
evm_model, viscosity_fn=viscosity_fn)
def smagorinsky_acceleration_array(v):
v = smagorinsky_acceleration(v)
return tuple(u.array for u in v)
if forcing is None:
forcing = smagorinsky_acceleration
forcing = smagorinsky_acceleration_array
else:
forcing = forcings.sum_forcings(forcing, smagorinsky_acceleration)
forcing = forcings.sum_forcings(forcing, smagorinsky_acceleration_array)
return equations.semi_implicit_navier_stokes(dt=dt, forcing=forcing, **kwargs)


Expand Down
Loading