This repository has been archived by the owner on Jan 15, 2024. It is now read-only.
Making transformer encoder fully hybridized for export #789
Labels: enhancement (New feature or request)
Comments
Just wondering, the op … BTW, we can use …
The … For instance, the code in …
@pengxin99 Yes, the reason is that symbol does not have a shape API.
Another alternative to fix arange is something like:
…
where max_length is something like 512.
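The code in that comment did not survive. One plausible reading, offered only as an assumption, is to build the arange once for a fixed max_length and then crop it to the runtime sequence length with a shape-agnostic slice (MXNet's slice_like is one op of that kind). The pure-Python sketch below mimics that data flow; MAX_LENGTH, ARANGE_MAX, slice_like_axis0 and positions are hypothetical names, and plain lists stand in for NDArrays.

```python
MAX_LENGTH = 512  # hypothetical upper bound on sequence length (assumption)

# Built once, independent of any input: [0, 1, ..., MAX_LENGTH - 1]
ARANGE_MAX = list(range(MAX_LENGTH))


def slice_like_axis0(data, shape_like):
    """Stand-in for a slice_like-style op: crop `data` along axis 0 to the
    length of `shape_like`. In a graph, that length would be read by the op
    at run time instead of being a Python constant fixed at export time."""
    if len(shape_like) > len(data):
        raise ValueError("sequence longer than MAX_LENGTH")
    return data[:len(shape_like)]


def positions(seq):
    """Position indices for one sequence of token vectors, shape (length, units)."""
    return slice_like_axis0(ARANGE_MAX, seq)


print(positions([[0.0] * 4] * 3))  # [0, 1, 2]
```

The trade-off of this design is a hard cap: sequences longer than max_length cannot be handled without rebuilding the precomputed range.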
Fixed using …
Currently, BERT can be exported only with static-length support. BERTEncoder inherits from the Transformer encoder, which contains a few .shape API calls; these make the Transformer encoder not fully hybridizable and cause issues during export. As a result, the exported model supports only a static sequence length. These calls are located here: https://github.com/dmlc/gluon-nlp/blob/master/src/gluonnlp/model/transformer.py#L450-L463
We need to remove these calls to export a model that supports variable length.
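Why a .shape call pins the exported length can be shown with a toy model of tracing. This is an illustration under stated assumptions, not MXNet's actual export path: trace_static and trace_dynamic are hypothetical functions, and a Python closure stands in for the exported graph.

```python
def trace_static(example):
    """Toy 'export': run the forward logic once on an example and keep the
    result as a fixed recipe. Reading the length here (like inputs.shape[1])
    turns it into a Python constant baked into the recipe."""
    length = len(example)
    return lambda seq: list(range(length))  # ignores seq's actual length


def trace_dynamic(example):
    """Same recipe, but the length is looked up from the runtime input, which
    is what shape-agnostic graph ops make possible. `example` is unused and
    kept only for symmetry with trace_static."""
    return lambda seq: list(range(len(seq)))


static_model = trace_static([[0.0]] * 3)
dynamic_model = trace_dynamic([[0.0]] * 3)
print(static_model([[0.0]] * 5))   # [0, 1, 2]: stuck at the export-time length
print(dynamic_model([[0.0]] * 5))  # [0, 1, 2, 3, 4]
```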
handling arange
In particular, we have
To remove these .shape calls, we have 2 options:
contrib.arange_like op with ndarray input
Instead of contrib.arange(arr.shape[1], ...), we can introduce an arange_like op:
input: arr with shape (x,) and arbitrary data
output: an array with shape arr.shape and values [0, 1, 2, ..., size(arr) - 1]
With this op, we just need to slice the inputs on axis 1 and pass the slice to the arange_like op:
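The arange_like semantics described above can be pinned down with a small pure-Python reference. It is a sketch, not the MXNet operator: nested lists stand in for NDArrays, slice_steps is a hypothetical helper that takes inputs[0, :, 0], and accepting arrays of any rank (filled in row-major order) goes beyond the proposal, which only specifies 1-D input.

```python
from itertools import count


def arange_like(arr):
    """Reference semantics of the proposed arange_like op: return an array
    with the same shape as `arr` holding 0, 1, ..., size(arr) - 1 in
    row-major order. The values stored in `arr` are ignored."""
    counter = count()

    def fill(x):
        if isinstance(x, list):
            return [fill(v) for v in x]
        return next(counter)

    return fill(arr)


def slice_steps(inputs):
    """Hypothetical helper: take inputs[0, :, 0] from a (batch, length, units)
    nested list, giving a 1-D array of shape (length,)."""
    return [step[0] for step in inputs[0]]


# batch=2, length=3, units=2
inputs = [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
          [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]]
steps = arange_like(slice_steps(inputs))
print(steps)  # [0, 1, 2]
```

Because the op reads only the shape of its argument, no Python integer taken from .shape has to appear in the graph.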
control flow op
Alternatively, we can use a control flow op (either foreach or a while loop) to loop N times, where N = inputs.shape[1]. Iteration i fills in the value i in the output "arange" array.
However, this may have high overhead when N is large (e.g. 512), because the loop body runs once per position.
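The loop-based option can be sketched with a pure-Python stand-in whose body(data, states) -> (outputs, states) contract is modeled on MXNet's contrib.foreach. The stand-in is not MXNet's implementation; arange_via_loop is a hypothetical name, and the input is assumed to be length-major, shape (length, batch, units), so that the loop runs over the sequence axis. It also counts body invocations to make the cost visible.

```python
def foreach(body, data, init_states):
    """Pure-Python stand-in for a foreach-style control flow op: iterate over
    axis 0 of `data`, calling body(item, states) -> (output, new_states)."""
    outputs, states = [], init_states
    for item in data:
        out, states = body(item, states)
        outputs.append(out)
    return outputs, states


def arange_via_loop(steps):
    """Build [0, 1, ..., N - 1] with N = len(steps), one iteration per position.
    Returns the arange and the number of body invocations."""
    calls = []

    def body(_item, states):
        i = states[0]
        calls.append(i)      # record one body invocation per position
        return i, [i + 1]    # emit i, carry i + 1 to the next iteration

    outputs, _ = foreach(body, steps, [0])
    return outputs, len(calls)


out, n_calls = arange_via_loop([[[0.0]]] * 4)
print(out, n_calls)  # [0, 1, 2, 3] 4
```

The body count equals the sequence length, so a 512-token input means 512 sequential iterations just to produce position indices.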
handling other .shape calls
mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=1), axis=1, size=length) can be replaced with a broadcast_mul op with ones_like(arr).
inputs * math.sqrt(inputs.shape[-1]) can be replaced with the shape_nd op (@TaoLv).
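Both replacements above can be checked with a pure-Python sketch, under assumptions: mask has shape (batch, length); arr is taken to be a (batch, length, 1) slice of inputs, since the text does not say which array feeds ones_like; and shape_nd is modeled as a function that returns the shape as data. All helpers are hypothetical list-based stand-ins, not MXNet APIs.

```python
import math


def expand_dims_axis1(mask):
    """(batch, length) -> (batch, 1, length)."""
    return [[list(row)] for row in mask]


def broadcast_axes_axis1(x, size):
    """Original approach: repeat axis 1 of a (batch, 1, length) array `size`
    times. `size` must be a Python int, i.e. it comes from .shape."""
    return [[list(b[0]) for _ in range(size)] for b in x]


def ones_like(x):
    return [ones_like(v) for v in x] if isinstance(x, list) else 1.0


def broadcast_mul(a, b):
    """Elementwise product with size-1 broadcasting for equal-rank nested lists."""
    if not isinstance(a, list):
        return a * b
    n = max(len(a), len(b))
    return [broadcast_mul(a[0] if len(a) == 1 else a[i],
                          b[0] if len(b) == 1 else b[i]) for i in range(n)]


def shape_nd(x):
    """Stand-in for a shape op that returns the shape as data (non-empty input)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims


# valid_length = [2, 1], length = 3
mask = [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
inputs = [[[0.5] * 4 for _ in range(3)] for _ in range(2)]  # (batch=2, length=3, units=4)

# Ones of shape (batch, length, 1), derived from inputs rather than from `length`
ones = ones_like([[[step[0]] for step in seq] for seq in inputs])
new_mask = broadcast_mul(expand_dims_axis1(mask), ones)           # (2, 3, 3)
old_mask = broadcast_axes_axis1(expand_dims_axis1(mask), size=3)  # (2, 3, 3)
assert new_mask == old_mask

# sqrt(units) computed from the shape as data, then broadcast over inputs
scale = math.sqrt(shape_nd(inputs)[-1])  # sqrt(4) = 2.0
scaled = broadcast_mul(inputs, [[[scale]]])
```

Multiplying a (batch, 1, length) array by a (batch, length, 1) array of ones gives the same (batch, length, length) result as repeating axis 1, without ever naming length as a Python integer.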