This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Making transformer encoder fully hybridized for export #789

Closed
eric-haibin-lin opened this issue Jun 21, 2019 · 5 comments
Labels
enhancement New feature or request

Comments

@eric-haibin-lin
Member

Currently BERT can be exported only with static sequence-length support. BERTEncoder inherits from the transformer encoder, which contains a few .shape API calls, making the transformer encoder not fully hybridizable and creating issues during export. As a result, the exported model only supports a static length. These calls are located here:
https://github.com/dmlc/gluon-nlp/blob/master/src/gluonnlp/model/transformer.py#L450-L463

We need to remove these calls to export a model that supports variable length.

handling arange

In particular, we have

length = inputs.shape[1]
arange = mx.nd.arange(length, ctx=valid_length.context, dtype=valid_length.dtype)

To remove these .shape calls, we have 2 options:

contrib.arange_like op with ndarray input

Instead of arange(inputs.shape[1], ...), we can introduce an arange_like op:

  • input: arr with shape (x,) and arbitrary data
  • output: an array with shape arr.shape and values [0, 1, 2, ..., size(arr) - 1].

With this op, we just need to slice the inputs along axis 1 and pass the slice to the arange_like op:

arr = inputs.slice(begin=(0, 0, 0), end=(1, None, 1)).reshape((-1,))  # shape (seq_len,)
arange = F.contrib.arange_like(arr)
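
For illustration, a minimal sketch of how this would look inside hybrid_forward (it assumes the proposed contrib.arange_like op is available; the block and the toy shapes are made up, not code from this issue):

import mxnet as mx
from mxnet.gluon import HybridBlock

class PositionIndex(HybridBlock):
    def hybrid_forward(self, F, inputs, valid_length):
        # inputs: (batch, seq_len, units); keep axis 1, drop the rest.
        arr = F.slice(inputs, begin=(0, 0, 0), end=(1, None, 1)).reshape((-1,))
        arange = F.contrib.arange_like(arr)  # (seq_len,), values 0 .. seq_len - 1
        # mask[b, t] = 1 if t < valid_length[b]
        mask = F.broadcast_lesser(arange.reshape((1, -1)),
                                  valid_length.reshape((-1, 1)))
        return arange, mask

block = PositionIndex()
block.hybridize()
arange, mask = block(mx.nd.zeros((2, 5, 8)), mx.nd.array([5, 3]))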

control flow op

Alternatively, we can use a control flow op (either foreach or while_loop) to loop N times, where N = inputs.shape[1]. Iteration i fills in the value i in the output "arange" array.

However, this may have high overhead when N is large (e.g. 512).
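
For illustration, a minimal sketch of the foreach variant (the step function and toy shapes are illustrative, not code from this issue):

import mxnet as mx

def step(data, state):
    # Emit the current index and increment the counter carried in the loop state.
    return state, state + 1

inputs = mx.nd.zeros((2, 5, 8))            # (batch, seq_len, hidden)
loop_data = inputs.transpose((1, 0, 2))    # foreach iterates over axis 0
outs, _ = mx.nd.contrib.foreach(step, loop_data, mx.nd.zeros((1,)))
arange = outs.reshape((-1,))               # [0., 1., 2., 3., 4.]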

handling other .shape calls

  • mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=1), axis=1, size=length) can be replaced with the broadcast_mul op and ones_like(arr)
  • inputs * math.sqrt(inputs.shape[-1]) can be replaced with the shape_nd op (a rough sketch of both replacements follows below)
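
For illustration, a rough sketch of both replacements, using the existing shape_array op as a stand-in for the shape_nd op mentioned above (toy shapes, not code from this issue):

import mxnet as mx
F = mx.nd  # would be the F passed into hybrid_forward

inputs = mx.nd.zeros((2, 5, 8))   # (batch, seq_len, units)
mask = mx.nd.ones((2, 5))         # (batch, seq_len)

# 1) broadcast_axes(expand_dims(mask, 1), axis=1, size=length)
#    -> broadcast_mul against a ones tensor derived from the inputs themselves.
ones = F.ones_like(F.slice_axis(inputs, axis=-1, begin=0, end=1))   # (batch, seq_len, 1)
mask_3d = F.broadcast_mul(F.expand_dims(mask, axis=1), ones)        # (batch, seq_len, seq_len)

# 2) inputs * math.sqrt(inputs.shape[-1])
#    -> read the last dimension from the graph instead of from Python.
last_dim = F.slice_axis(F.shape_array(inputs), axis=0, begin=-1, end=None).reshape((1, 1, 1))
scaled = F.broadcast_mul(inputs, F.sqrt(last_dim.astype('float32')))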

@TaoLv

@pengxin99
Contributor

Just wondering: the ops arange, broadcast_axes, and expand_dims can all be called on both NDArray and Symbol, so is all the effort here about removing the inputs.shape[] calls, because Symbol does not have a shape attribute?
Thanks @eric-haibin-lin @TaoLv

BTW, we can use export() to get the hybridized model directly from train_transformer.py.
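
For illustration, the general export() flow looks roughly like this (a toy network stands in for the transformer built in train_transformer.py):

import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()   # illustrative stand-in for the transformer model
net.add(nn.Dense(4))
net.initialize()

net.hybridize(static_alloc=True)
_ = net(mx.nd.zeros((2, 8)))           # one forward pass builds the cached graph
net.export('transformer', epoch=0)     # writes transformer-symbol.json and transformer-0000.params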

@nickguletskii

The contrib.index_array operator could probably be used to replace these shape calls:

http://mxnet.incubator.apache.org/versions/master/api/python/symbol/contrib.html#mxnet.symbol.contrib.index_array

For instance, the code in BaseTransformerEncoder can be replaced with something like this:

template = F.squeeze(F.slice_axis(inputs, axis=-1, begin=0, end=1), axis=-1)   # (batch, seq_len)
arange = F.squeeze(F.contrib.index_array(template, axes=(1,)), axis=-1)        # (batch, seq_len), each row 0 .. seq_len - 1
mask = F.broadcast_lesser(arange, valid_length.reshape((-1, 1)))
mask = F.broadcast_like(F.expand_dims(mask, axis=1), inputs, lhs_axes=(1,), rhs_axes=(1,))
...
template = inputs.slice(begin=(0, 0, 0), end=(1, None, 1)).reshape((-1,))      # (seq_len,)
steps = F.contrib.index_array(template)
steps = F.squeeze(steps, axis=-1)
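
For context, a quick illustration of what contrib.index_array returns (toy shapes, not code from this issue):

import mxnet as mx

x = mx.nd.zeros((2, 3))
idx = mx.nd.contrib.index_array(x, axes=(1,))
# idx has shape (2, 3, 1) with idx[i, j, 0] == j, so squeezing the last axis
# gives per-position indices without ever reading x.shape in Python.
print(idx.squeeze(axis=-1))   # [[0, 1, 2], [0, 1, 2]]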

@eric-haibin-lin
Member Author

@pengxin99 yes, the reason is that Symbol does not have a shape API.
@nickguletskii Nice proposal, I'll take a look at the index_array op.

@eric-haibin-lin
Member Author

Another alternative to fix arange is something like:

F.arange(max_length).slice_like(inputs, axis=1)

where max_length is something like 512.
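
For illustration, one way to make the axes line up (an adaptation of the snippet above with toy shapes; slice_like takes an axes argument):

import mxnet as mx

max_length = 512
inputs = mx.nd.zeros((2, 5, 8))                        # (batch, seq_len, hidden)
arange = mx.nd.arange(max_length).reshape((1, -1, 1))  # (1, max_length, 1)
arange = mx.nd.slice_like(arange, inputs, axes=(1,))   # (1, seq_len, 1)
arange = arange.reshape((-1,))                         # [0., 1., ..., 4.]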

@eric-haibin-lin
Member Author

Fixed using arange(infer_range=True)
