Commit ff8fb71

andyl7an and The Meridian Authors authored and committed
[JAX] Refactor transformers to use backend abstraction module
PiperOrigin-RevId: 795670081
1 parent 467fabd commit ff8fb71

4 files changed: +179 −107 lines changed
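The change swaps direct `tf.*` calls in the transformers for ops re-exported from `meridian.backend`, which binds each op name to either the JAX or TensorFlow implementation at import time. A minimal, illustrative sketch of that dispatch pattern (simplified and partly hypothetical; the real module reads the backend from `config.get_backend()` and covers many more ops) might look like this:

# Illustrative sketch of backend dispatch, not the actual Meridian module.
# The selector and structure here are assumptions for clarity.
_BACKEND = "jax"  # in Meridian this comes from config.get_backend()

if _BACKEND == "jax":
  import jax.numpy as ops

  def divide_no_nan(x, y):
    # Mirror tf.math.divide_no_nan: 0 wherever the denominator is 0.
    return ops.where(y != 0, ops.divide(x, y), 0.0)

  reduce_mean = ops.mean
  reduce_std = ops.std
  ones_like = ops.ones_like
else:
  import tensorflow as tf

  divide_no_nan = tf.math.divide_no_nan
  reduce_mean = tf.reduce_mean
  reduce_std = tf.math.reduce_std
  ones_like = tf.ones_like

Callers then write `backend.divide_no_nan(...)`, `backend.reduce_mean(...)`, and so on, without importing either framework directly.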

meridian/backend/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -153,6 +153,19 @@ def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
   return x.astype(dtype)
 
 
+def _jax_divide_no_nan(x, y):
+  """JAX implementation for divide_no_nan."""
+  import jax.numpy as jnp
+
+  return jnp.where(y != 0, jnp.divide(x, y), 0.0)
+
+
+def _jax_numpy_function(*args, **kwargs):  # pylint: disable=unused-argument
+  raise NotImplementedError(
+      "backend.numpy_function is not implemented for the JAX backend."
+  )
+
+
 # --- Backend Initialization ---
 _BACKEND = config.get_backend()
 
@@ -184,6 +197,7 @@ class _JaxErrors:
   stack = ops.stack
   zeros = ops.zeros
   ones = ops.ones
+  ones_like = ops.ones_like
   repeat = ops.repeat
   where = ops.where
   transpose = ops.transpose
@@ -194,6 +208,12 @@ class _JaxErrors:
   exp = ops.exp
   log = ops.log
   reduce_sum = ops.sum
+  reduce_mean = ops.mean
+  reduce_std = ops.std
+  reduce_any = ops.any
+  is_nan = ops.isnan
+  divide_no_nan = _jax_divide_no_nan
+  numpy_function = _jax_numpy_function
 
   float32 = ops.float32
   bool_ = ops.bool_
@@ -230,6 +250,7 @@ def set_random_seed(seed: int) -> None:  # pylint: disable=unused-argument
   stack = ops.stack
   zeros = ops.zeros
   ones = ops.ones
+  ones_like = ops.ones_like
   repeat = ops.repeat
   where = ops.where
   transpose = ops.transpose
@@ -240,6 +261,12 @@ def set_random_seed(seed: int) -> None:  # pylint: disable=unused-argument
   exp = ops.math.exp
   log = ops.math.log
   reduce_sum = ops.reduce_sum
+  reduce_mean = ops.reduce_mean
+  reduce_std = ops.math.reduce_std
+  reduce_any = ops.reduce_any
+  is_nan = ops.math.is_nan
+  divide_no_nan = ops.math.divide_no_nan
+  numpy_function = ops.numpy_function
 
   float32 = ops.float32
   bool_ = ops.bool
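As a quick sanity check (illustrative only, not part of this commit), the new JAX shim should agree with `tf.math.divide_no_nan` on forward values, returning 0 wherever the denominator is 0:

# Hypothetical parity check between the JAX shim and TensorFlow's op.
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

def _jax_divide_no_nan(x, y):
  return jnp.where(y != 0, jnp.divide(x, y), 0.0)

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y = np.array([2.0, 0.0, 4.0], dtype=np.float32)

np.testing.assert_allclose(
    np.asarray(_jax_divide_no_nan(x, y)),
    tf.math.divide_no_nan(x, y).numpy(),
)  # both give [0.5, 0.0, 0.75]

One general caveat with the `jnp.where` pattern: forward values are correct, but gradients through the discarded zero-denominator branch can still be NaN under JAX autodiff; the usual workaround is to also mask the denominator before dividing.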

meridian/backend/test_utils.py

Lines changed: 40 additions & 4 deletions
@@ -24,7 +24,11 @@
 
 
 def assert_allclose(
-    a: ArrayLike, b: ArrayLike, rtol: float = 1e-6, atol: float = 1e-6
+    a: ArrayLike,
+    b: ArrayLike,
+    rtol: float = 1e-6,
+    atol: float = 1e-6,
+    err_msg: str = "",
 ):
   """Backend-agnostic assertion to check if two array-like objects are close.
 
@@ -37,23 +41,55 @@ def assert_allclose(
     b: The second array-like object to compare.
     rtol: The relative tolerance parameter.
     atol: The absolute tolerance parameter.
+    err_msg: The error message to be printed in case of failure.
 
   Raises:
     AssertionError: If the two arrays are not equal within the given tolerance.
   """
-  np.testing.assert_allclose(np.array(a), np.array(b), rtol=rtol, atol=atol)
+  np.testing.assert_allclose(
+      np.array(a), np.array(b), rtol=rtol, atol=atol, err_msg=err_msg
+  )
 
 
-def assert_allequal(a: ArrayLike, b: ArrayLike):
+def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
   """Backend-agnostic assertion to check if two array-like objects are equal.
 
   This function converts both inputs to NumPy arrays before comparing them.
 
   Args:
     a: The first array-like object to compare.
     b: The second array-like object to compare.
+    err_msg: The error message to be printed in case of failure.
 
   Raises:
     AssertionError: If the two arrays are not equal.
   """
-  np.testing.assert_array_equal(np.array(a), np.array(b))
+  np.testing.assert_array_equal(np.array(a), np.array(b), err_msg=err_msg)
+
+
+def assert_all_finite(a: ArrayLike, err_msg: str = ""):
+  """Backend-agnostic assertion to check if all elements in an array are finite.
+
+  Args:
+    a: The array-like object to check.
+    err_msg: The error message to be printed in case of failure.
+
+  Raises:
+    AssertionError: If the array contains non-finite values.
+  """
+  if not np.all(np.isfinite(np.array(a))):
+    raise AssertionError(err_msg or "Array contains non-finite values.")
+
+
+def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
+  """Backend-agnostic assertion to check if all elements are non-negative.
+
+  Args:
+    a: The array-like object to check.
+    err_msg: The error message to be printed in case of failure.
+
+  Raises:
+    AssertionError: If the array contains negative values.
+  """
+  if not np.all(np.array(a) >= 0):
+    raise AssertionError(err_msg or "Array contains negative values.")
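For illustration (values are made up), the new helpers compose with the existing backend-agnostic assertions and accept NumPy arrays, TF tensors, or JAX arrays, since every input is converted with `np.array`:

# Illustrative usage of the new assertion helpers.
import numpy as np

from meridian.backend import test_utils

scaled = np.array([[0.2, 0.0], [1.5, 0.7]], dtype=np.float32)

test_utils.assert_allclose(scaled, scaled + 1e-9, err_msg="scaling drifted")
test_utils.assert_all_finite(scaled, err_msg="scaling produced inf or NaN")
test_utils.assert_all_non_negative(scaled, err_msg="scaling produced negative values")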

meridian/model/transformers.py

Lines changed: 49 additions & 48 deletions
@@ -15,8 +15,9 @@
 """Contains data transformers for various inputs of the Meridian model."""
 
 import abc
+
+from meridian import backend
 import numpy as np
-import tensorflow as tf
 
 
 __all__ = [
@@ -31,14 +32,14 @@ class TensorTransformer(abc.ABC):
   """Abstract class for data transformers."""
 
   @abc.abstractmethod
-  @tf.function(jit_compile=True)
-  def forward(self, tensor: tf.Tensor) -> tf.Tensor:
+  @backend.function(jit_compile=True)
+  def forward(self, tensor: backend.Tensor) -> backend.Tensor:
     """Transforms a given tensor."""
     raise NotImplementedError("`forward` must be implemented.")
 
   @abc.abstractmethod
-  @tf.function(jit_compile=True)
-  def inverse(self, tensor: tf.Tensor) -> tf.Tensor:
+  @backend.function(jit_compile=True)
+  def inverse(self, tensor: backend.Tensor) -> backend.Tensor:
     """Transforms back a given tensor."""
     raise NotImplementedError("`inverse` must be implemented.")
 
@@ -52,8 +53,8 @@ class MediaTransformer(TensorTransformer):
 
   def __init__(
       self,
-      media: tf.Tensor,
-      population: tf.Tensor,
+      media: backend.Tensor,
+      population: backend.Tensor,
   ):
     """`MediaTransformer` constructor.
 
@@ -63,43 +64,43 @@ def __init__(
       population: A tensor of dimension `(n_geos,)` containing the population of
         each geo, used to compute the scale factors.
     """
-    population_scaled_media = tf.math.divide_no_nan(
-        media, population[:, tf.newaxis, tf.newaxis]
+    population_scaled_media = backend.divide_no_nan(
+        media, population[:, backend.newaxis, backend.newaxis]
     )
     # Replace zeros with NaNs
-    population_scaled_media_nan = tf.where(
+    population_scaled_media_nan = backend.where(
        population_scaled_media == 0, np.nan, population_scaled_media
    )
    # Tensor of medians of the positive portion of `media`. Used as a component
    # for scaling.
-    self._population_scaled_median_m = tf.numpy_function(
+    self._population_scaled_median_m = backend.numpy_function(
        func=lambda x: np.nanmedian(x, axis=[0, 1]),
        inp=[population_scaled_media_nan],
-        Tout=tf.float32,
+        Tout=backend.float32,
    )
-    if tf.reduce_any(tf.math.is_nan(self._population_scaled_median_m)):
+    if backend.reduce_any(backend.is_nan(self._population_scaled_median_m)):
      raise ValueError(
          "MediaTransformer has a NaN population-scaled non-zero median due to"
          " a media channel with either all zeroes or all NaNs."
      )
    # Tensor of dimensions (`n_geos` x 1) of weights for scaling `metric`.
-    self._scale_factors_gm = tf.einsum(
+    self._scale_factors_gm = backend.einsum(
        "g,m->gm", population, self._population_scaled_median_m
    )
 
   @property
   def population_scaled_median_m(self):
     return self._population_scaled_median_m
 
-  @tf.function(jit_compile=True)
-  def forward(self, tensor: tf.Tensor) -> tf.Tensor:
+  @backend.function(jit_compile=True)
+  def forward(self, tensor: backend.Tensor) -> backend.Tensor:
     """Scales a given tensor using the stored scale factors."""
-    return tensor / self._scale_factors_gm[:, tf.newaxis, :]
+    return tensor / self._scale_factors_gm[:, backend.newaxis, :]
 
-  @tf.function(jit_compile=True)
-  def inverse(self, tensor: tf.Tensor) -> tf.Tensor:
+  @backend.function(jit_compile=True)
+  def inverse(self, tensor: backend.Tensor) -> backend.Tensor:
     """Scales a given tensor using the inversed stored scale factors."""
-    return tensor * self._scale_factors_gm[:, tf.newaxis, :]
+    return tensor * self._scale_factors_gm[:, backend.newaxis, :]
 
 
 class CenteringAndScalingTransformer(TensorTransformer):
@@ -113,9 +114,9 @@ class CenteringAndScalingTransformer(TensorTransformer):
 
   def __init__(
       self,
-      tensor: tf.Tensor,
-      population: tf.Tensor,
-      population_scaling_id: tf.Tensor | None = None,
+      tensor: backend.Tensor,
+      population: backend.Tensor,
+      population_scaling_id: backend.Tensor | None = None,
   ):
     """`CenteringAndScalingTransformer` constructor.
 
@@ -129,25 +130,25 @@ def __init__(
         scaled by population.
     """
     if population_scaling_id is not None:
-      self._population_scaling_factors = tf.where(
+      self._population_scaling_factors = backend.where(
           population_scaling_id,
           population[:, None],
-          tf.ones_like(population)[:, None],
+          backend.ones_like(population)[:, None],
       )
       population_scaled_tensor = (
           tensor / self._population_scaling_factors[:, None, :]
       )
-      self._means = tf.reduce_mean(population_scaled_tensor, axis=(0, 1))
-      self._stdevs = tf.math.reduce_std(population_scaled_tensor, axis=(0, 1))
+      self._means = backend.reduce_mean(population_scaled_tensor, axis=(0, 1))
+      self._stdevs = backend.reduce_std(population_scaled_tensor, axis=(0, 1))
     else:
       self._population_scaling_factors = None
-      self._means = tf.reduce_mean(tensor, axis=(0, 1))
-      self._stdevs = tf.math.reduce_std(tensor, axis=(0, 1))
+      self._means = backend.reduce_mean(tensor, axis=(0, 1))
+      self._stdevs = backend.reduce_std(tensor, axis=(0, 1))
 
-  @tf.function(jit_compile=True)
+  @backend.function(jit_compile=True)
   def forward(
-      self, tensor: tf.Tensor, apply_population_scaling: bool = True
-  ) -> tf.Tensor:
+      self, tensor: backend.Tensor, apply_population_scaling: bool = True
+  ) -> backend.Tensor:
     """Scales a given tensor using the stored coefficients.
 
     Args:
@@ -161,10 +162,10 @@ def forward(
         and self._population_scaling_factors is not None
     ):
       tensor /= self._population_scaling_factors[:, None, :]
-    return tf.math.divide_no_nan(tensor - self._means, self._stdevs)
+    return backend.divide_no_nan(tensor - self._means, self._stdevs)
 
-  @tf.function(jit_compile=True)
-  def inverse(self, tensor: tf.Tensor) -> tf.Tensor:
+  @backend.function(jit_compile=True)
+  def inverse(self, tensor: backend.Tensor) -> backend.Tensor:
     """Scales back a given tensor using the stored coefficients."""
     scaled_tensor = tensor * self._stdevs + self._means
     return (
@@ -183,8 +184,8 @@ class KpiTransformer(TensorTransformer):
 
   def __init__(
       self,
-      kpi: tf.Tensor,
-      population: tf.Tensor,
+      kpi: backend.Tensor,
+      population: backend.Tensor,
   ):
     """`KpiTransformer` constructor.
 
@@ -195,11 +196,11 @@ def __init__(
      each geo, used to to compute the population scale factors.
    """
    self._population = population
-    population_scaled_kpi = tf.math.divide_no_nan(
-        kpi, self._population[:, tf.newaxis]
+    population_scaled_kpi = backend.divide_no_nan(
+        kpi, self._population[:, backend.newaxis]
    )
-    self._population_scaled_mean = tf.reduce_mean(population_scaled_kpi)
-    self._population_scaled_stdev = tf.math.reduce_std(population_scaled_kpi)
+    self._population_scaled_mean = backend.reduce_mean(population_scaled_kpi)
+    self._population_scaled_stdev = backend.reduce_std(population_scaled_kpi)
 
   @property
   def population_scaled_mean(self):
@@ -209,18 +210,18 @@ def population_scaled_mean(self):
   def population_scaled_stdev(self):
     return self._population_scaled_stdev
 
-  @tf.function(jit_compile=True)
-  def forward(self, tensor: tf.Tensor) -> tf.Tensor:
+  @backend.function(jit_compile=True)
+  def forward(self, tensor: backend.Tensor) -> backend.Tensor:
     """Scales a given tensor using the stored coefficients."""
-    return tf.math.divide_no_nan(
-        tf.math.divide_no_nan(tensor, self._population[:, tf.newaxis])
+    return backend.divide_no_nan(
+        backend.divide_no_nan(tensor, self._population[:, backend.newaxis])
         - self._population_scaled_mean,
         self._population_scaled_stdev,
     )
 
-  @tf.function(jit_compile=True)
-  def inverse(self, tensor: tf.Tensor) -> tf.Tensor:
+  @backend.function(jit_compile=True)
+  def inverse(self, tensor: backend.Tensor) -> backend.Tensor:
     """Scales back a given tensor using the stored coefficients."""
     return (
         tensor * self._population_scaled_stdev + self._population_scaled_mean
-    ) * self._population[:, tf.newaxis]
+    ) * self._population[:, backend.newaxis]
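As a rough usage sketch (shapes and values are invented; real callers pass tensors created by the model, so the NumPy inputs here are an assumption for brevity), the refactored transformers still round-trip data through `forward` and `inverse` regardless of the active backend:

# Hypothetical round-trip check for KpiTransformer after the refactor.
import numpy as np

from meridian.backend import test_utils
from meridian.model import transformers

n_geos, n_times = 3, 5
rng = np.random.default_rng(0)
kpi = rng.uniform(1.0, 10.0, size=(n_geos, n_times)).astype(np.float32)
population = np.array([1e5, 5e5, 2e6], dtype=np.float32)

kpi_transformer = transformers.KpiTransformer(kpi, population)
scaled = kpi_transformer.forward(kpi)

# inverse(forward(x)) should recover x up to float32 tolerance.
test_utils.assert_allclose(
    kpi_transformer.inverse(scaled), kpi, rtol=1e-4, atol=1e-4
)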
