CumConcatLayer: First draft
Zettelkasten committed Nov 22, 2020
1 parent efe762b commit bb04e4c
Showing 1 changed file with 142 additions and 0 deletions.
returnn/tf/layers/rec.py
@@ -8,6 +8,7 @@
import typing
import tensorflow as tf
import returnn.tf.compat as tf_compat

try:
  from tensorflow.python.ops.nn import rnn_cell
except ImportError:
@@ -8202,3 +8203,144 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
    # length will be ``network.get_rec_step_index() + 1``.
    data.shape = (1, None, n_out)
    return data


class CumConcatLayer(_ConcatInputLayer):
"""
Concatenates all previous frames of a time-axis. This should be used inside a RecLayer.
If this layer receives input of the form [..., D] for every time step t, it will add a new time-dimension and output
a tensor of form [t, ..., D] in the t-th time frame where all frames <= t are concatenated. Namely, for the first
frame this will return a tensor of form [1, ..., D], for the second [2, ..., D], and so on.
This axis t this layer concatenates along will be marked as ``stag:rec-history``.
If the time-axis already exists (e.g. if the recurrent loop this layer is in was automatically optimized away), this
layer will not add an additional time-axis but rather change the DimensionTag of the existing time axis to
``stag:rec-history`` and simply return the input-tensor unchanged.
This imitates the behaviour of being in the final time frame if in a recurrent loop.
This layer should be used in combination with the :class:`RecHistoryMaskLayer` to prevent access to future frames when
the loop is optimized away.
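
  A hypothetical usage sketch inside a :class:`RecLayer` unit (layer names are illustrative, not part of this commit)::

      "output": {"class": "rec", "from": "data", "unit": {
        "history": {"class": "cum_concat", "from": "data:source"},  # frame t: [t, ..., D], axis 'stag:rec-history'
        ...
      }}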
"""
  layer_class = "cum_concat"
  recurrent = True  # order matters

  def __init__(self, axis=":i", **kwargs):
    """
    :param str axis: which time axis to accumulate over, or "rec-frame" / ":i" to accumulate over the innermost
      recurrent time dimension. See :func:`Data.get_axis_from_description`.
    """
    super(CumConcatLayer, self).__init__(**kwargs)

    data = self.output
    # The rec-history axis is always created by get_out_data_from_opts.
    axis = data.get_axis_by_tag_name("rec-history")

    if self.network.is_inside_rec_layer():
      # Inside a RecLayer, accumulate the previous frames over the new rec-history axis.
      last_frames = self._rec_previous_layer.rec_vars_outputs["state"]  # [t, ..., D]
      current_frame = data.placeholder  # [1, ..., D]
      concat_frames = tf.concat([last_frames, current_frame], axis=axis)  # [t+1, ..., D]
      self.rec_vars_outputs["state"] = concat_frames
      data.placeholder = concat_frames
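      # Hedged note: inside the loop, every sequence in the batch has advanced the same number of steps,
      # so the dynamic size of the rec-history axis is simply (step index + 1), tiled over the batch.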
      dyn_size = tf.tile(tf.expand_dims(self.network.get_rec_step_index() + 1, axis=0), [data.get_batch_dim()])
    else:
      # If not inside a RecLayer, this layer is a no-op:
      # leave data.placeholder unchanged, only adjust the dim tag.
      dyn_size = tf.identity(data.get_dynamic_size(axis))

    # We already set the size_placeholder to a dummy rec-history tag before; now do it properly.
    from returnn.tf.util.basic import DimensionTag
    tag = DimensionTag(
      description="rec-history:%s" % self.get_absolute_name(),
      kind=DimensionTag.Types.Time)
    data.size_placeholder[data.get_batch_axis_excluding_batch(axis)] = dyn_size
    tag.set_tag_on_size_tensor(dyn_size)

  @classmethod
  def _get_rec_history_axis(cls, data, axis):
    """
    :param Data data:
    :param str axis:
    :rtype: int|None
    """
    if axis.lower() in [":i", "rec-frame"]:
      # TODO: Will data.time_dim_axis always be the time axis corresponding to the recurrent loop?
      return data.time_dim_axis
    else:
      try:
        return data.get_axis_from_description(axis)
      except AssertionError:
        # Assume the axis does not exist.
        return None

  @classmethod
  def get_out_data_from_opts(cls, name, network, sources, axis=":i", **kwargs):
    """
    :param str name:
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param str axis:
    :rtype: Data
    """
    data = get_concat_sources_data_template(sources, name="%s_output" % name)
    axis = cls._get_rec_history_axis(data, axis)

    if axis is None:
      # TODO: Is there a better way to figure out whether we are actually in a RecLayer?
      # Assume we are inside a RecLayer; add a new rec-history time axis.
      axis = 0
      # The placeholder gets dim 1 (i.e. exactly one slice for frame t), but the shape has to be None (= variable).
      data = data.copy_add_spatial_dim(dim=1, spatial_dim_axis=axis)
      axis_wo_batch = data.get_batch_axis_excluding_batch(axis)
      data.shape = data.shape[:axis_wo_batch] + (None,) + data.shape[axis_wo_batch + 1:]
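      # Hedged example: a per-frame template of shape [B, D] becomes [1, B, D] here, with the template
      # shape entry for the new axis set to None, since the accumulated history grows every step.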

      dyn_size = tf.zeros(shape=(), dtype=tf.int32)  # temporary dummy, will be overridden in __init__
    else:
      dyn_size = tf.identity(data.get_dynamic_size(axis))

    # Temporarily set the size_placeholder. This is somewhat hacky, but following layers need to know what our
    # output axis is called, otherwise they cannot use 'stag:rec-history' as an axis during their template
    # construction in get_out_data_from_opts.
    from returnn.tf.util.basic import DimensionTag
    tag = DimensionTag(
      description="rec-history",
      kind=DimensionTag.Types.Time)
    data.size_placeholder[data.get_batch_axis_excluding_batch(axis)] = dyn_size
    tag.set_tag_on_size_tensor(dyn_size)
    return data

  @classmethod
  def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, axis=":i", sources=(), **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param tf.Tensor batch_dim:
    :param TFNetworkRecLayer.RecLayer|LayerBase rec_layer:
    :param str axis:
    :param list[LayerBase] sources:
    :rtype: dict[str,tf.Tensor]
    """
    if network.is_inside_rec_layer():
      data = get_concat_sources_data_template(sources)
      data = data.copy_add_spatial_dim(dim=0, spatial_dim_axis=0)
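      # Hedged note: dim=0 means the initial state is empty along the accumulated axis,
      # so the first loop iteration's concat in __init__ yields exactly the first frame.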

return {"state": tf.zeros(data.get_batch_shape(batch_dim=batch_dim), dtype=data.dtype)}
else:
return {}

  @classmethod
  def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :rtype: dict[str, tf.TensorShape]
    """
    if network.is_inside_rec_layer():
      data = get_concat_sources_data_template(sources)
      data = data.copy_add_spatial_dim(dim=1, spatial_dim_axis=0)
      axis = data.time_dim_axis
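      # Hedged note: the state grows by one frame per iteration, so the tf.while_loop shape
      # invariant must leave the accumulated axis unspecified (None).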

      shape = data.get_batch_shape(batch_dim=None)
      shape = shape[:axis] + (None,) + shape[axis + 1:]
      return {"state": tf.TensorShape(shape)}
    else:
      return {}
