Skip to content

Commit

Permalink
Merge branch 'tensorflow:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
mjyun01 authored Feb 12, 2024
2 parents f9da466 + c0abb86 commit 5d30c00
Show file tree
Hide file tree
Showing 15 changed files with 444 additions and 33 deletions.
62 changes: 38 additions & 24 deletions official/modeling/fast_training/experimental/tf2_utils_2x_wide.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ def expand_vector(v: np.ndarray) -> np.ndarray:
def expand_1_axis(w: np.ndarray,
epsilon: float,
axis: int) -> np.ndarray:
"""Expands either the first dimension or the last dimension of w.
"""Expands either the first or last dimension of w.
If `axis = 0`, the following constraint will be satisfied:
If `axis = 0`, the following expression will be satisfied:
matmul(x, w) ==
matmul(expand_vector(x), expand_1_axis(w, epsilon=0.1, axis=0))
matmul(expand_vector(x), expand_1_axis(w, axis=0))
If `axis = -1`, the following constraint will be satisfied if `epsilon = 0.0`:
If `axis = -1` and `epsilon = 0.0`, the following constraint will be
satisfied:
expand_vector(matmul(x, w)) ==
2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1))
Expand All @@ -54,9 +55,12 @@ def expand_1_axis(w: np.ndarray,
Returns:
Expanded numpy array.
"""
assert axis in (0, -1), (
"Only support expanding the first or the last dimension. "
"Got: {}".format(axis))

if axis not in (0, -1):
raise ValueError(
"Only support expanding the first or the last dimension. "
"Got: {}".format(axis)
)

rank = len(w.shape)

Expand All @@ -65,7 +69,7 @@ def expand_1_axis(w: np.ndarray,

sign_flip = np.array([1, -1])
for _ in range(rank - 1):
sign_flip = np.expand_dims(sign_flip, axis=-1 if axis == 0 else 0)
sign_flip = np.expand_dims(sign_flip, axis=axis - 1)
sign_flip = np.tile(sign_flip,
[w.shape[0]] + [1] * (rank - 2) + [w.shape[-1]])

Expand All @@ -76,9 +80,9 @@ def expand_1_axis(w: np.ndarray,

def expand_2_axes(w: np.ndarray,
epsilon: float) -> np.ndarray:
"""Expands the first dimension and the last dimension of w.
"""Expands the first and last dimension of w.
The following constraint will be satisfied:
This operation satisfies the following expression:
expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w))
Args:
Expand Down Expand Up @@ -109,8 +113,8 @@ def var_to_var(var_from: tf.Variable,
epsilon: float):
"""Expands a variable to another variable.
Assume the shape of `var_from` is (a, b, ..., y, z), the shape of `var_to`
can be (a, ..., z * 2), (a * 2, ..., z * 2), (a * 2, ..., z)
Assuming the shape of `var_from` is (a, b, ..., y, z), then shape of `var_to`
must be one of (a, ..., z * 2), (a * 2, ..., z * 2), or (a * 2, ..., z).
If the shape of `var_to` is (a, ..., 2 * z):
For any x, tf.matmul(x, var_to) ~= expand_vector(tf.matmul(x, var_from)) / 2
Expand All @@ -131,21 +135,30 @@ def var_to_var(var_from: tf.Variable,

if shape_from == shape_to:
var_to.assign(var_from)
return

var_from_np = var_from.numpy()

if len(shape_from) == len(shape_to) == 1:
var_to.assign(expand_vector(var_from_np))
return

elif len(shape_from) == 1 and len(shape_to) == 1:
var_to.assign(expand_vector(var_from.numpy()))
a_from, z_from = shape_from[0], shape_from[-1]
a_to, z_to = shape_to[0], shape_to[-1]

elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] == shape_to[-1]:
var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=0))
if a_to == 2 * a_from and z_to == z_from:
var_to.assign(expand_1_axis(var_from_np, epsilon=epsilon, axis=0))
return

elif shape_from[0] == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]:
var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=-1))
if a_to == a_from and z_to == 2 * z_from:
var_to.assign(expand_1_axis(var_from_np, epsilon=epsilon, axis=-1))
return

elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]:
var_to.assign(expand_2_axes(var_from.numpy(), epsilon=epsilon))
if a_to == 2 * a_from and z_to == 2 * z_from:
var_to.assign(expand_2_axes(var_from_np, epsilon=epsilon))
return

else:
raise ValueError("Shape not supported, {}, {}".format(shape_from, shape_to))
raise ValueError("Shape not supported, {}, {}".format(shape_from, shape_to))


def model_to_model_2x_wide(model_from: tf.Module,
Expand All @@ -170,8 +183,9 @@ def model_to_model_2x_wide(model_from: tf.Module,
assert model_narrow([[1, 2, 3]]) == model_wide([[1, 1, 2, 2, 3, 3]])
```
We assume that `model_from` and `model_to` has the same architecture and only
widths of them differ.
We assume that `model_from` and `model_to` have the same architecture and
differ
only in widths.
Args:
model_from: input model to expand.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def test_expand_3d_tensor_axis_2(self):
o1 = np.matmul(x, w1)
self.assertAllClose(o0, np.sum(o1.reshape(2, 2), axis=-1))

def test_relations(self):
x = np.array([10, 11])
w = np.random.rand(2, 2)
# matmul(x, w) == matmul(expand_vector(x), expand_1_axis(w, axis=0))
lhs = np.matmul(x, w)
rhs = np.matmul(
tf2_utils_2x_wide.expand_vector(x),
tf2_utils_2x_wide.expand_1_axis(w, epsilon=0.1, axis=0),
)
self.assertAllClose(lhs, rhs)
# expand_vector(matmul(x, w)) ==
# 2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1))
lhs = tf2_utils_2x_wide.expand_vector(np.matmul(x, w))
rhs = 2 * np.matmul(
x, tf2_utils_2x_wide.expand_1_axis(w, epsilon=0.0, axis=-1)
)
self.assertAllClose(lhs, rhs)
# expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w))
lhs = tf2_utils_2x_wide.expand_vector(np.matmul(x, w))
rhs = np.matmul(
tf2_utils_2x_wide.expand_vector(x),
tf2_utils_2x_wide.expand_2_axes(w, epsilon=0.1),
)
self.assertAllClose(lhs, rhs)

def test_end_to_end(self):
"""Covers expand_vector, expand_2_axes, and expand_1_axis."""
model_narrow = tf_keras.Sequential()
Expand Down
8 changes: 3 additions & 5 deletions official/recommendation/uplift/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,16 @@
# limitations under the License.

"""Defines enum keys used by the keras uplift modeling library."""
import enum


class StrEnum(str, enum.Enum):
"""An Enum represented by a string."""
import enum


class TwoTowerOutputKeys(StrEnum):
class TwoTowerOutputKeys(enum.StrEnum):
"""Keys for training and inference output tensors."""

CONTROL_PREDICTIONS = "control_predictions"
TREATMENT_PREDICTIONS = "treatment_predictions"
UPLIFT_PREDICTIONS = "uplift_predictions"
IS_TREATMENT = "is_treatment"
TRUE_LOGITS = "true_logits"
TRUE_PREDICTIONS = "true_predictions"
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ def call(
true_logits = tf.where(
is_treatment, outputs.treatment_logits, outputs.control_logits
)
true_predictions = tf.where(
is_treatment, treatment_predictions, control_predictions
)

# Create a new tensor since ExtensionTypes are immutable.
return types.TwoTowerTrainingOutputs(
Expand All @@ -131,6 +134,7 @@ def call(
treatment_predictions=treatment_predictions,
uplift=uplift,
true_logits=true_logits,
true_predictions=true_predictions,
is_treatment=is_treatment,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,16 @@ def test_true_logits_correspond_to_control_and_treatment_logits(self):
outputs = layer(inputs)
self.assertAllEqual(tf.constant([2, -1, 3]), outputs.true_logits)

def test_true_preds_correspond_to_control_and_treatment_preds(self):
layer = self._get_layer(inverse_link_fn=tf.nn.relu)
inputs = self._get_inputs(
control_logits=tf.constant([2, 0, 3]),
treatment_logits=tf.constant([-1, 2, 1]),
is_treatment=tf.constant([1, 1, 0]),
)
outputs = layer(inputs)
self.assertAllEqual(tf.constant([0, 2, 3]), outputs.true_predictions)

def test_is_treatment_tensor_gets_converted_to_boolean_tensor(self):
layer = self._get_layer()
inputs = self._get_inputs(is_treatment=tf.constant([1, 1, 0]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
class TrueLogitsLossTest(tf.test.TestCase, parameterized.TestCase):

def _get_y_pred(self, **kwargs):
# The shared embedding and control/treatment/uplift predictions are
# The shared embedding and control/treatment/uplift/true predictions are
# distracting from the test logic.
return types.TwoTowerTrainingOutputs(
shared_embedding=tf.zeros((3, 1)),
control_predictions=tf.zeros((3, 1)),
treatment_predictions=tf.zeros((3, 1)),
true_predictions=tf.zeros((3, 1)),
uplift=tf.zeros((3, 1)),
**kwargs,
)
Expand Down
1 change: 1 addition & 0 deletions official/recommendation/uplift/metrics/label_mean_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _get_y_pred(
control_logits=tf.ones_like(is_treatment),
treatment_logits=tf.ones_like(is_treatment),
true_logits=tf.ones_like(is_treatment),
true_predictions=tf.ones_like(is_treatment),
is_treatment=is_treatment,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _get_y_pred(
control_logits=tf.ones_like(is_treatment),
treatment_logits=tf.ones_like(is_treatment),
true_logits=tf.ones_like(is_treatment),
true_predictions=tf.ones_like(is_treatment),
is_treatment=is_treatment,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _get_y_pred(
control_logits=tf.ones_like(is_treatment),
treatment_logits=tf.ones_like(is_treatment),
true_logits=tf.ones_like(is_treatment),
true_predictions=tf.ones_like(is_treatment),
is_treatment=is_treatment,
)

Expand Down
1 change: 1 addition & 0 deletions official/recommendation/uplift/metrics/uplift_mean_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def _get_y_pred(
control_logits=tf.ones_like(is_treatment),
treatment_logits=tf.ones_like(is_treatment),
true_logits=tf.ones_like(is_treatment),
true_predictions=tf.ones_like(is_treatment),
is_treatment=is_treatment,
)

Expand Down
5 changes: 5 additions & 0 deletions official/recommendation/uplift/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,18 @@ class TwoTowerTrainingOutputs(TwoTowerPredictionOutputs):
the corresponding value in the `is_treatment` tensor. It will contain
treatment group logits for the `is_treatment == 1` entries and control
group logits otherwise.
true_predictions: predictions for either the control or treatment group,
depending on the corresponding value in the `is_treatment` tensor. It will
contain treatment group predictions for the `is_treatment == 1` entries
and control group predictions otherwise.
is_treatment: a boolean `tf.Tensor` indicating if the example belongs to the
treatment group (True) or control group (False).
"""

__name__ = "TwoTowerTrainingOutputs"

true_logits: tf.Tensor
true_predictions: tf.Tensor
is_treatment: tf.Tensor

# TODO(b/281776818): Override __validate__ to assert that the true logits is
Expand Down
34 changes: 33 additions & 1 deletion official/vision/configs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Common configurations."""

import dataclasses
from typing import List, Optional
from typing import List, Optional, Sequence

# Import libraries

Expand Down Expand Up @@ -100,6 +100,35 @@ class MixupAndCutmix(hyperparams.Config):
label_smoothing: float = 0.1


@dataclasses.dataclass
class SSDRandomCropParam(hyperparams.Config):
min_object_covered: float = 0.0
min_box_overlap: float = 0.5
prob_to_apply: float = 0.85


@dataclasses.dataclass
class SSDRandomCrop(hyperparams.Config):
"""Configuration for SSDRandomCrop.
Liu et al., SSD: Single shot multibox detector
https://arxiv.org/abs/1512.02325.
"""
ssd_random_crop_params: Sequence[SSDRandomCropParam] = dataclasses.field(
default_factory=lambda: (
SSDRandomCropParam(min_object_covered=0.0),
SSDRandomCropParam(min_object_covered=0.1),
SSDRandomCropParam(min_object_covered=0.3),
SSDRandomCropParam(min_object_covered=0.5),
SSDRandomCropParam(min_object_covered=0.7),
SSDRandomCropParam(min_object_covered=0.9),
SSDRandomCropParam(min_object_covered=1.0),
)
)
aspect_ratio_range: tuple[float, float] = (0.5, 2.0)
area_range: tuple[float, float] = (0.1, 1.0)


@dataclasses.dataclass
class Augmentation(hyperparams.OneOfConfig):
"""Configuration for input data augmentation.
Expand All @@ -112,6 +141,9 @@ class Augmentation(hyperparams.OneOfConfig):
type: Optional[str] = None
randaug: RandAugment = dataclasses.field(default_factory=RandAugment)
autoaug: AutoAugment = dataclasses.field(default_factory=AutoAugment)
ssd_random_crop: SSDRandomCrop = dataclasses.field(
default_factory=SSDRandomCrop
)


@dataclasses.dataclass
Expand Down
7 changes: 7 additions & 0 deletions official/vision/dataloaders/retinanet_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ def __init__(self,
translate_const=aug_type.randaug.translate_const,
prob_to_apply=aug_type.randaug.prob_to_apply,
exclude_ops=aug_type.randaug.exclude_ops)
elif aug_type.type == 'ssd_random_crop':
logging.info('Using SSD Random Crop.')
self._augmenter = augment.SSDRandomCrop(
params=aug_type.ssd_random_crop.ssd_random_crop_params,
aspect_ratio_range=aug_type.ssd_random_crop.aspect_ratio_range,
area_range=aug_type.ssd_random_crop.area_range,
)
else:
raise ValueError(f'Augmentation policy {aug_type.type} not supported.')

Expand Down
Loading

0 comments on commit 5d30c00

Please sign in to comment.