15
15
"""Contains data transformers for various inputs of the Meridian model."""
16
16
17
17
import abc
18
+
19
+ from meridian import backend
18
20
import numpy as np
19
- import tensorflow as tf
20
21
21
22
22
23
__all__ = [
@@ -31,14 +32,14 @@ class TensorTransformer(abc.ABC):
31
32
"""Abstract class for data transformers."""
32
33
33
34
@abc .abstractmethod
34
- @tf .function (jit_compile = True )
35
- def forward (self , tensor : tf .Tensor ) -> tf .Tensor :
35
+ @backend .function (jit_compile = True )
36
+ def forward (self , tensor : backend .Tensor ) -> backend .Tensor :
36
37
"""Transforms a given tensor."""
37
38
raise NotImplementedError ("`forward` must be implemented." )
38
39
39
40
@abc .abstractmethod
40
- @tf .function (jit_compile = True )
41
- def inverse (self , tensor : tf .Tensor ) -> tf .Tensor :
41
+ @backend .function (jit_compile = True )
42
+ def inverse (self , tensor : backend .Tensor ) -> backend .Tensor :
42
43
"""Transforms back a given tensor."""
43
44
raise NotImplementedError ("`inverse` must be implemented." )
44
45
@@ -52,8 +53,8 @@ class MediaTransformer(TensorTransformer):
52
53
53
54
def __init__ (
54
55
self ,
55
- media : tf .Tensor ,
56
- population : tf .Tensor ,
56
+ media : backend .Tensor ,
57
+ population : backend .Tensor ,
57
58
):
58
59
"""`MediaTransformer` constructor.
59
60
@@ -63,43 +64,43 @@ def __init__(
63
64
population: A tensor of dimension `(n_geos,)` containing the population of
64
65
each geo, used to compute the scale factors.
65
66
"""
66
- population_scaled_media = tf . math .divide_no_nan (
67
- media , population [:, tf .newaxis , tf .newaxis ]
67
+ population_scaled_media = backend .divide_no_nan (
68
+ media , population [:, backend .newaxis , backend .newaxis ]
68
69
)
69
70
# Replace zeros with NaNs
70
- population_scaled_media_nan = tf .where (
71
+ population_scaled_media_nan = backend .where (
71
72
population_scaled_media == 0 , np .nan , population_scaled_media
72
73
)
73
74
# Tensor of medians of the positive portion of `media`. Used as a component
74
75
# for scaling.
75
- self ._population_scaled_median_m = tf .numpy_function (
76
+ self ._population_scaled_median_m = backend .numpy_function (
76
77
func = lambda x : np .nanmedian (x , axis = [0 , 1 ]),
77
78
inp = [population_scaled_media_nan ],
78
- Tout = tf .float32 ,
79
+ Tout = backend .float32 ,
79
80
)
80
- if tf .reduce_any (tf . math .is_nan (self ._population_scaled_median_m )):
81
+ if backend .reduce_any (backend .is_nan (self ._population_scaled_median_m )):
81
82
raise ValueError (
82
83
"MediaTransformer has a NaN population-scaled non-zero median due to"
83
84
" a media channel with either all zeroes or all NaNs."
84
85
)
85
86
# Tensor of dimensions (`n_geos` x 1) of weights for scaling `metric`.
86
- self ._scale_factors_gm = tf .einsum (
87
+ self ._scale_factors_gm = backend .einsum (
87
88
"g,m->gm" , population , self ._population_scaled_median_m
88
89
)
89
90
90
91
@property
91
92
def population_scaled_median_m (self ):
92
93
return self ._population_scaled_median_m
93
94
94
- @tf .function (jit_compile = True )
95
- def forward (self , tensor : tf .Tensor ) -> tf .Tensor :
95
+ @backend .function (jit_compile = True )
96
+ def forward (self , tensor : backend .Tensor ) -> backend .Tensor :
96
97
"""Scales a given tensor using the stored scale factors."""
97
- return tensor / self ._scale_factors_gm [:, tf .newaxis , :]
98
+ return tensor / self ._scale_factors_gm [:, backend .newaxis , :]
98
99
99
- @tf .function (jit_compile = True )
100
- def inverse (self , tensor : tf .Tensor ) -> tf .Tensor :
100
+ @backend .function (jit_compile = True )
101
+ def inverse (self , tensor : backend .Tensor ) -> backend .Tensor :
101
102
"""Scales a given tensor using the inversed stored scale factors."""
102
- return tensor * self ._scale_factors_gm [:, tf .newaxis , :]
103
+ return tensor * self ._scale_factors_gm [:, backend .newaxis , :]
103
104
104
105
105
106
class CenteringAndScalingTransformer (TensorTransformer ):
@@ -113,9 +114,9 @@ class CenteringAndScalingTransformer(TensorTransformer):
113
114
114
115
def __init__ (
115
116
self ,
116
- tensor : tf .Tensor ,
117
- population : tf .Tensor ,
118
- population_scaling_id : tf .Tensor | None = None ,
117
+ tensor : backend .Tensor ,
118
+ population : backend .Tensor ,
119
+ population_scaling_id : backend .Tensor | None = None ,
119
120
):
120
121
"""`CenteringAndScalingTransformer` constructor.
121
122
@@ -129,25 +130,25 @@ def __init__(
129
130
scaled by population.
130
131
"""
131
132
if population_scaling_id is not None :
132
- self ._population_scaling_factors = tf .where (
133
+ self ._population_scaling_factors = backend .where (
133
134
population_scaling_id ,
134
135
population [:, None ],
135
- tf .ones_like (population )[:, None ],
136
+ backend .ones_like (population )[:, None ],
136
137
)
137
138
population_scaled_tensor = (
138
139
tensor / self ._population_scaling_factors [:, None , :]
139
140
)
140
- self ._means = tf .reduce_mean (population_scaled_tensor , axis = (0 , 1 ))
141
- self ._stdevs = tf . math .reduce_std (population_scaled_tensor , axis = (0 , 1 ))
141
+ self ._means = backend .reduce_mean (population_scaled_tensor , axis = (0 , 1 ))
142
+ self ._stdevs = backend .reduce_std (population_scaled_tensor , axis = (0 , 1 ))
142
143
else :
143
144
self ._population_scaling_factors = None
144
- self ._means = tf .reduce_mean (tensor , axis = (0 , 1 ))
145
- self ._stdevs = tf . math .reduce_std (tensor , axis = (0 , 1 ))
145
+ self ._means = backend .reduce_mean (tensor , axis = (0 , 1 ))
146
+ self ._stdevs = backend .reduce_std (tensor , axis = (0 , 1 ))
146
147
147
- @tf .function (jit_compile = True )
148
+ @backend .function (jit_compile = True )
148
149
def forward (
149
- self , tensor : tf .Tensor , apply_population_scaling : bool = True
150
- ) -> tf .Tensor :
150
+ self , tensor : backend .Tensor , apply_population_scaling : bool = True
151
+ ) -> backend .Tensor :
151
152
"""Scales a given tensor using the stored coefficients.
152
153
153
154
Args:
@@ -161,10 +162,10 @@ def forward(
161
162
and self ._population_scaling_factors is not None
162
163
):
163
164
tensor /= self ._population_scaling_factors [:, None , :]
164
- return tf . math .divide_no_nan (tensor - self ._means , self ._stdevs )
165
+ return backend .divide_no_nan (tensor - self ._means , self ._stdevs )
165
166
166
- @tf .function (jit_compile = True )
167
- def inverse (self , tensor : tf .Tensor ) -> tf .Tensor :
167
+ @backend .function (jit_compile = True )
168
+ def inverse (self , tensor : backend .Tensor ) -> backend .Tensor :
168
169
"""Scales back a given tensor using the stored coefficients."""
169
170
scaled_tensor = tensor * self ._stdevs + self ._means
170
171
return (
@@ -183,8 +184,8 @@ class KpiTransformer(TensorTransformer):
183
184
184
185
def __init__ (
185
186
self ,
186
- kpi : tf .Tensor ,
187
- population : tf .Tensor ,
187
+ kpi : backend .Tensor ,
188
+ population : backend .Tensor ,
188
189
):
189
190
"""`KpiTransformer` constructor.
190
191
@@ -195,11 +196,11 @@ def __init__(
195
196
each geo, used to to compute the population scale factors.
196
197
"""
197
198
self ._population = population
198
- population_scaled_kpi = tf . math .divide_no_nan (
199
- kpi , self ._population [:, tf .newaxis ]
199
+ population_scaled_kpi = backend .divide_no_nan (
200
+ kpi , self ._population [:, backend .newaxis ]
200
201
)
201
- self ._population_scaled_mean = tf .reduce_mean (population_scaled_kpi )
202
- self ._population_scaled_stdev = tf . math .reduce_std (population_scaled_kpi )
202
+ self ._population_scaled_mean = backend .reduce_mean (population_scaled_kpi )
203
+ self ._population_scaled_stdev = backend .reduce_std (population_scaled_kpi )
203
204
204
205
@property
205
206
def population_scaled_mean (self ):
@@ -209,18 +210,18 @@ def population_scaled_mean(self):
209
210
def population_scaled_stdev (self ):
210
211
return self ._population_scaled_stdev
211
212
212
- @tf .function (jit_compile = True )
213
- def forward (self , tensor : tf .Tensor ) -> tf .Tensor :
213
+ @backend .function (jit_compile = True )
214
+ def forward (self , tensor : backend .Tensor ) -> backend .Tensor :
214
215
"""Scales a given tensor using the stored coefficients."""
215
- return tf . math .divide_no_nan (
216
- tf . math . divide_no_nan (tensor , self ._population [:, tf .newaxis ])
216
+ return backend .divide_no_nan (
217
+ backend . divide_no_nan (tensor , self ._population [:, backend .newaxis ])
217
218
- self ._population_scaled_mean ,
218
219
self ._population_scaled_stdev ,
219
220
)
220
221
221
- @tf .function (jit_compile = True )
222
- def inverse (self , tensor : tf .Tensor ) -> tf .Tensor :
222
+ @backend .function (jit_compile = True )
223
+ def inverse (self , tensor : backend .Tensor ) -> backend .Tensor :
223
224
"""Scales back a given tensor using the stored coefficients."""
224
225
return (
225
226
tensor * self ._population_scaled_stdev + self ._population_scaled_mean
226
- ) * self ._population [:, tf .newaxis ]
227
+ ) * self ._population [:, backend .newaxis ]
0 commit comments