RF gather, clip_to_valid with static index and dyn axis
albertz committed Sep 7, 2024
1 parent 1b5530d commit 0ee25d9
Showing 6 changed files with 51 additions and 37 deletions.
8 changes: 1 addition & 7 deletions returnn/frontend/_backend.py
@@ -890,13 +890,7 @@ def combine(
         return out
 
     @staticmethod
-    def gather(
-        source: Tensor,
-        *,
-        indices: Union[Tensor, int],
-        axis: Dim,
-        clip_to_valid: bool = False,
-    ) -> Tensor:
+    def gather(source: Tensor, *, indices: Union[Tensor, int], axis: Dim, clip_to_valid: bool = False) -> Tensor:
         """
         Gathers slices on a specified axis from the source using indices.
         If the source is of the shape ``[B,D,F1]``, and indices of shape ``[B,F2]``,
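For context, a minimal usage sketch of the gather API described in the docstring above. The dim names are illustrative; the `dyn_size_ext - 1` arithmetic is the same pattern this commit uses internally.

import returnn.frontend as rf
from returnn.tensor import Tensor, Dim, batch_dim

time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))  # dynamic axis
feat_dim = Dim(7, name="feat")

def pick_last_frame(x: Tensor) -> Tensor:
    # Per-sequence last valid index; clip_to_valid guards short/empty seqs.
    last = time_dim.dyn_size_ext - 1
    return rf.gather(x, indices=last, axis=time_dim, clip_to_valid=True)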
8 changes: 1 addition & 7 deletions returnn/tf/frontend_layers/_backend.py
@@ -638,13 +638,7 @@ def combine(
         return rfl.make_layer({"class": "combine", "from": [a, b], "kind": kind, **kwargs}, name=kind)
 
     @staticmethod
-    def gather(
-        source: Tensor,
-        *,
-        indices: Union[Tensor, int],
-        axis: Dim,
-        clip_to_valid: bool = False,
-    ) -> Tensor:
+    def gather(source: Tensor, *, indices: Union[Tensor, int], axis: Dim, clip_to_valid: bool = False) -> Tensor:
         """gather"""
         args = {}
         if clip_to_valid:
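For reference, this frontend lowers rf.gather to a net-dict "gather" layer (GatherLayer, changed below in returnn/tf/layers/basic.py). A rough sketch of the resulting dict; the layer name and "from"/"position" values are placeholders:

network = {
    "output": {
        "class": "gather",
        "from": "data",
        "position": "indices",  # layer name, or a static int
        "axis": "T",
        "clip_to_valid": True,
    },
}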
25 changes: 25 additions & 0 deletions returnn/tf/frontend_low_level/_backend.py
@@ -553,3 +553,28 @@ def reduce(source: _TT, *, mode: str, axis: Union[Dim, Sequence[Dim]], use_mask:
             y = tf_util.optional_mul(y, correction_factor)
         out_data.raw_tensor = y
         return out_data
+
+    @staticmethod
+    def clip_by_value(
+        x: Tensor,
+        clip_value_min: Union[Tensor, rf.RawTensorTypes],
+        clip_value_max: Union[Tensor, rf.RawTensorTypes],
+        *,
+        allow_broadcast_all_sources: bool = False,
+    ) -> Tensor:
+        """clip by value"""
+        clip_value_min = rf.convert_to_tensor(clip_value_min, _backend=TFBackend, device=x.device)
+        clip_value_max = rf.convert_to_tensor(clip_value_max, _backend=TFBackend, device=x.device)
+        out = Tensor.get_common_data(
+            [x, clip_value_min, clip_value_max],
+            allow_broadcast_all_sources=allow_broadcast_all_sources,
+            name="clip_by_value",
+        )
+        out.dtype = x.dtype
+        out.sparse_dim = x.sparse_dim
+        out.feature_dim = x.feature_dim
+        x_bc_raw = x.copy_compatible_to_dims_raw(out.dims)
+        min_bc_raw = clip_value_min.copy_compatible_to_dims_raw(out.dims)
+        max_bc_raw = clip_value_max.copy_compatible_to_dims_raw(out.dims)
+        out.raw_tensor = tf.clip_by_value(x_bc_raw, min_bc_raw, max_bc_raw)
+        return out
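Stripped of the Tensor dim bookkeeping, the new method boils down to broadcasting all three inputs to common dims and calling tf.clip_by_value. A toy sketch of that final step, with made-up values:

import tensorflow as tf

x = tf.constant([[-2.0, 0.5, 3.0]])   # already broadcast to the common shape
lo = tf.constant(0.0)                 # scalar min
hi = tf.constant([[1.0, 1.0, 2.0]])   # per-element max
print(tf.clip_by_value(x, lo, hi).numpy())  # [[0.  0.5 2. ]]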
7 changes: 3 additions & 4 deletions returnn/tf/layers/basic.py
@@ -1564,6 +1564,8 @@ def __init__(self, position: Union[LayerBase, int], axis: Union[Dim, str], clip_
         :param clip_to_valid: if True, the indices will be clipped to the valid range of the input
             Also taking seq lengths into account.
         """
+        import returnn.frontend as rf
+
         super(GatherLayer, self).__init__(**kwargs)
         self.position = position
 
@@ -1579,10 +1581,7 @@ def __init__(self, position: Union[LayerBase, int], axis: Union[Dim, str], clip_
             dyn_size_ext = dim.dyn_size_ext
             if not dyn_size_ext:
                 dyn_size_ext = Data.from_tensor(tf.shape(input_data.placeholder)[old_gather_axis])
-            common = Data.get_common_data([position_data, dyn_size_ext])
-            position_data = position_data.copy_compatible_to(common, check_sparse=False, check_dtype=False)
-            dyn_size_ext = dyn_size_ext.copy_compatible_to(common, check_sparse=False, check_dtype=False)
-            position_data.placeholder = tf.clip_by_value(position_data.placeholder, 0, dyn_size_ext.placeholder - 1)
+            position_data = rf.clip_by_value(position_data, 0, dyn_size_ext - 1)
 
             # determine all common axes of input_data and position_data
             common_axes_input, common_axes_position, input_axes, position_axes = self._get_common_input_position_axes(
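The behavior this refactor preserves, in plain TF with toy values: each position is clamped into [0, seq_len - 1] of its own sequence, so padded or short sequences cannot be indexed out of range.

import tensorflow as tf

positions = tf.constant([5, 2])  # requested index per sequence, shape [B]
seq_lens = tf.constant([3, 4])   # dynamic sizes, shape [B]
print(tf.clip_by_value(positions, 0, seq_lens - 1).numpy())  # [2 2]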
23 changes: 4 additions & 19 deletions returnn/torch/frontend/_backend.py
@@ -902,13 +902,7 @@ def full(
         )
 
     @staticmethod
-    def gather(
-        source: Tensor,
-        *,
-        indices: Union[Tensor, int],
-        axis: Dim,
-        clip_to_valid: bool = False,
-    ) -> Tensor:
+    def gather(source: Tensor, *, indices: Union[Tensor, int], axis: Dim, clip_to_valid: bool = False) -> Tensor:
         """
         Gather.
 
@@ -933,19 +927,10 @@ def gather(
             raise TypeError(f"Unsupported type for indices: {type(indices)}")
         axis_int = source.get_axis_from_description(axis, allow_int=False)
         if clip_to_valid:
-            indices = indices.copy()
-            dim: Dim = source.dims[axis_int]
-            if dim.dyn_size_ext:
-                assert dim.dyn_size_ext.dims_set.issubset(
-                    indices.dims_set
-                ), f"gather with clip_to_valid: indices ({indices}) dims must be a superset of {dim} dyn-size"
-                size = dim.dyn_size_ext.copy_compatible_to(indices, check_sparse=False)
-                indices.raw_tensor = torch.clamp(
-                    indices.raw_tensor,
-                    torch.tensor(0, device=indices.raw_tensor.device),
-                    (size.raw_tensor - 1).to(indices.raw_tensor.device),
-                )
+            if axis.dyn_size_ext:
+                indices = rf.clip_by_value(indices, 0, axis.dyn_size_ext - 1)
             else:
+                indices = indices.copy()
                 indices.raw_tensor = torch.clamp(indices.raw_tensor, 0, source.raw_tensor.shape[axis_int] - 1)
         index_own_dims = [dim for dim in indices.dims if dim not in source.dims or dim == axis]
         out = Tensor(
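In raw torch terms, the dynamic branch amounts to a per-row clamp against the sequence lengths. A sketch of what rf.clip_by_value effectively computes after dim broadcasting, with made-up values:

import torch

indices = torch.tensor([[0, 4, 9], [1, 2, 3]])  # [B, F]
lens = torch.tensor([5, 3])                     # dyn sizes per seq, [B]
max_idx = (lens - 1).unsqueeze(-1)              # [B, 1], broadcasts over F
print(torch.clamp(indices, min=0, max=max_idx))
# tensor([[0, 4, 4],
#         [1, 2, 2]])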
17 changes: 17 additions & 0 deletions tests/test_rf_array.py
@@ -334,6 +334,23 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):
     run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)
 
 
+def test_gather_time_static_clip_to_valid():
+    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
+    in_dim = Dim(7, name="in")
+    extern_data_template = TensorDict(
+        {
+            "data": Tensor("data", [batch_dim, time_dim, in_dim], feature_dim=in_dim, dtype="float32"),
+        }
+    )
+
+    def _forward_step(*, extern_data: TensorDict, **_kwargs):
+        x = extern_data["data"]
+        out = rf.gather(x, indices=0, axis=time_dim, clip_to_valid=True)
+        out.mark_as_default_output(shape=(batch_dim, in_dim))
+
+    run_model(extern_data_template, lambda *, epoch, step: rf.Module(), _forward_step)
+
+
 def test_slice():
     time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
     in_dim = Dim(7, name="in")
