diff --git a/returnn/frontend/_backend.py b/returnn/frontend/_backend.py
index 1df3d433c..c22ac50ba 100644
--- a/returnn/frontend/_backend.py
+++ b/returnn/frontend/_backend.py
@@ -890,13 +890,7 @@ def combine(
         return out
 
     @staticmethod
-    def gather(
-        source: Tensor,
-        *,
-        indices: Union[Tensor, int],
-        axis: Dim,
-        clip_to_valid: bool = False,
-    ) -> Tensor:
+    def gather(source: Tensor, *, indices: Union[Tensor, int], axis: Dim, clip_to_valid: bool = False) -> Tensor:
         """
         Gathers slices on a specified axis from the source using indices.
         If the source is of the shape ``[B,D,F1]``, and indices of shape ``[B,F2]``,
diff --git a/returnn/tf/frontend_layers/_backend.py b/returnn/tf/frontend_layers/_backend.py
index 5c0e054fa..bdcc55230 100644
--- a/returnn/tf/frontend_layers/_backend.py
+++ b/returnn/tf/frontend_layers/_backend.py
@@ -638,13 +638,7 @@ def combine(
         return rfl.make_layer({"class": "combine", "from": [a, b], "kind": kind, **kwargs}, name=kind)
 
     @staticmethod
-    def gather(
-        source: Tensor,
-        *,
-        indices: Union[Tensor, int],
-        axis: Dim,
-        clip_to_valid: bool = False,
-    ) -> Tensor:
+    def gather(source: Tensor, *, indices: Union[Tensor, int], axis: Dim, clip_to_valid: bool = False) -> Tensor:
         """gather"""
         args = {}
         if clip_to_valid:
diff --git a/returnn/tf/frontend_low_level/_backend.py b/returnn/tf/frontend_low_level/_backend.py
index b60681772..30da51f76 100644
--- a/returnn/tf/frontend_low_level/_backend.py
+++ b/returnn/tf/frontend_low_level/_backend.py
@@ -553,3 +553,28 @@ def reduce(source: _TT, *, mode: str, axis: Union[Dim, Sequence[Dim]], use_mask:
         y = tf_util.optional_mul(y, correction_factor)
         out_data.raw_tensor = y
         return out_data
+
+    @staticmethod
+    def clip_by_value(
+        x: Tensor,
+        clip_value_min: Union[Tensor, rf.RawTensorTypes],
+        clip_value_max: Union[Tensor, rf.RawTensorTypes],
+        *,
+        allow_broadcast_all_sources: bool = False,
+    ) -> Tensor:
+        """clip by value"""
+        clip_value_min = rf.convert_to_tensor(clip_value_min, _backend=TFBackend, device=x.device)
+        clip_value_max = rf.convert_to_tensor(clip_value_max, _backend=TFBackend, device=x.device)
+        out = Tensor.get_common_data(
+            [x, clip_value_min, clip_value_max],
+            allow_broadcast_all_sources=allow_broadcast_all_sources,
+            name="clip_by_value",
+        )
+        out.dtype = x.dtype
+        out.sparse_dim = x.sparse_dim
+        out.feature_dim = x.feature_dim
+        x_bc_raw = x.copy_compatible_to_dims_raw(out.dims)
+        min_bc_raw = clip_value_min.copy_compatible_to_dims_raw(out.dims)
+        max_bc_raw = clip_value_max.copy_compatible_to_dims_raw(out.dims)
+        out.raw_tensor = tf.clip_by_value(x_bc_raw, min_bc_raw, max_bc_raw)
+        return out
diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index 6286e9128..a1324a7a6 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -1564,6 +1564,8 @@ def __init__(self, position: Union[LayerBase, int], axis: Union[Dim, str], clip_
         :param clip_to_valid: if True, the indices will be clipped to the valid range of the input
             Also taking seq lengths into account.
         """
+        import returnn.frontend as rf
+
         super(GatherLayer, self).__init__(**kwargs)
         self.position = position
 
@@ -1579,10 +1581,7 @@ def __init__(self, position: Union[LayerBase, int], axis: Union[Dim, str], clip_
             dyn_size_ext = dim.dyn_size_ext
             if not dyn_size_ext:
                 dyn_size_ext = Data.from_tensor(tf.shape(input_data.placeholder)[old_gather_axis])
-            common = Data.get_common_data([position_data, dyn_size_ext])
-            position_data = position_data.copy_compatible_to(common, check_sparse=False, check_dtype=False)
-            dyn_size_ext = dyn_size_ext.copy_compatible_to(common, check_sparse=False, check_dtype=False)
-            position_data.placeholder = tf.clip_by_value(position_data.placeholder, 0, dyn_size_ext.placeholder - 1)
+            position_data = rf.clip_by_value(position_data, 0, dyn_size_ext - 1)
 
         # determine all common axes of input_data and position_data
         common_axes_input, common_axes_position, input_axes, position_axes = self._get_common_input_position_axes(
diff --git a/returnn/torch/frontend/_backend.py b/returnn/torch/frontend/_backend.py
index ec5c68df7..88a42fc4b 100644
--- a/returnn/torch/frontend/_backend.py
+++ b/returnn/torch/frontend/_backend.py
@@ -902,13 +902,7 @@ def full(
         )
 
     @staticmethod
-    def gather(
-        source: Tensor,
-        *,
-        indices: Union[Tensor, int],
-        axis: Dim,
-        clip_to_valid: bool = False,
-    ) -> Tensor:
+    def gather(source: Tensor, *, indices: Union[Tensor, int], axis: Dim, clip_to_valid: bool = False) -> Tensor:
         """
         Gather.
 
@@ -933,19 +927,10 @@ def gather(
             raise TypeError(f"Unsupported type for indices: {type(indices)}")
         axis_int = source.get_axis_from_description(axis, allow_int=False)
         if clip_to_valid:
-            indices = indices.copy()
-            dim: Dim = source.dims[axis_int]
-            if dim.dyn_size_ext:
-                assert dim.dyn_size_ext.dims_set.issubset(
-                    indices.dims_set
-                ), f"gather with clip_to_valid: indices ({indices}) dims must be a superset of {dim} dyn-size"
-                size = dim.dyn_size_ext.copy_compatible_to(indices, check_sparse=False)
-                indices.raw_tensor = torch.clamp(
-                    indices.raw_tensor,
-                    torch.tensor(0, device=indices.raw_tensor.device),
-                    (size.raw_tensor - 1).to(indices.raw_tensor.device),
-                )
+            if axis.dyn_size_ext:
+                indices = rf.clip_by_value(indices, 0, axis.dyn_size_ext - 1)
             else:
+                indices = indices.copy()
                 indices.raw_tensor = torch.clamp(indices.raw_tensor, 0, source.raw_tensor.shape[axis_int] - 1)
         index_own_dims = [dim for dim in indices.dims if dim not in source.dims or dim == axis]
         out = Tensor(
diff --git a/tests/test_rf_array.py b/tests/test_rf_array.py
index bc4515193..183957a96 100644
--- a/tests/test_rf_array.py
+++ b/tests/test_rf_array.py
@@ -334,6 +334,23 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):
     run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)
 
 
+def test_gather_time_static_clip_to_valid():
+    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
+    in_dim = Dim(7, name="in")
+    extern_data_template = TensorDict(
+        {
+            "data": Tensor("data", [batch_dim, time_dim, in_dim], feature_dim=in_dim, dtype="float32"),
+        }
+    )
+
+    def _forward_step(*, extern_data: TensorDict, **_kwargs):
+        x = extern_data["data"]
+        out = rf.gather(x, indices=0, axis=time_dim, clip_to_valid=True)
+        out.mark_as_default_output(shape=(batch_dim, in_dim))
+
+    run_model(extern_data_template, lambda *, epoch, step: rf.Module(), _forward_step)
+
+
 def test_slice():
     time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
     in_dim = Dim(7, name="in")