This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

TransformerEncoder expects valid_length float dtype, but BERTSentenceTransform returns int dtype #1014

Open
zeeshansayyed opened this issue Nov 19, 2019 · 4 comments
Labels: bug, numpyrefactor

Comments

@zeeshansayyed

I am using the BERT model from gluonnlp (bert_12_768_12, book_corpus_wiki_en_uncased).
Since I wanted to be able to switch between BERT and RoBERTa, I am using the BERTSentenceTransform from the current GitHub master instead of the one in the 0.8.1 release.

I created my own dataloader and other preprocessing steps very much in line with the code present in this example.
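
For reference, the pipeline looks roughly like this (a minimal sketch; the model loading follows the standard gluonnlp API, but the names and hyperparameters are illustrative, not my exact code):

import mxnet as mx
import gluonnlp as nlp

# Load the pretrained BERT model and its vocabulary.
bert, vocab = nlp.model.get_model(
    'bert_12_768_12',
    dataset_name='book_corpus_wiki_en_uncased',
    pretrained=True,
    ctx=mx.cpu(),
    use_pooler=False, use_decoder=False, use_classifier=False)

tokenizer = nlp.data.BERTTokenizer(vocab, lower=True)
# max_seq_length=128 is an illustrative choice.
transform = nlp.data.BERTSentenceTransform(tokenizer, max_seq_length=128, pair=False)

# The transform returns (input_ids, valid_length, segment_ids) as numpy
# arrays; on master, valid_length comes back with an int dtype.
input_ids, valid_length, segment_ids = transform(('hello world',))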

But I am facing a very weird error when I try to call the forward function of my network. The stack trace is as follows:

Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/ec2-user/projects/OneEncoder/tests/test_net.py", line 70, in <module>
    output_token_ids, output_valid_length
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/gluon/block.py", line 548, in __call__
    out = self.forward(*args)
  File "/home/ec2-user/projects/OneEncoder/net.py", line 122, in forward
    input_token_ids, input_token_types, input_valid_length
  File "/home/ec2-user/projects/OneEncoder/net.py", line 98, in encode
    input_token_ids, input_token_types, input_valid_length
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/gluonnlp/model/bert.py", line 429, in __call__
    valid_length, masked_positions)
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/gluon/block.py", line 548, in __call__
    out = self.forward(*args)
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/gluon/block.py", line 925, in forward
    return self.hybrid_forward(ndarray, x, *args, **params)
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/gluonnlp/model/bert.py", line 442, in hybrid_forward
    seq_out, attention_out = self._encode_sequence(inputs, token_types, valid_length)
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/gluonnlp/model/bert.py", line 478, in _encode_sequence
    outputs, additional_outputs = self.encoder(embedding, valid_length=valid_length)
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/gluonnlp/model/transformer.py", line 445, in __call__
    return super(BaseTransformerEncoder, self).__call__(inputs, states, valid_length)
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/gluonnlp/model/seq2seq_encoder_decoder.py", line 151, in __call__
    return super(Seq2SeqEncoder, self).__call__(inputs, valid_length, states)
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/gluon/block.py", line 548, in __call__
    out = self.forward(*args)
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/gluon/block.py", line 925, in forward
    return self.hybrid_forward(ndarray, x, *args, **params)
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/gluonnlp/model/transformer.py", line 507, in hybrid_forward
    F.reshape(valid_length, shape=(-1, 1)))
  File "<string>", line 46, in broadcast_lesser
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/_ctypes/ndarray.py", line 92, in _imperative_invoke
    ctypes.byref(out_stypes)))
  File "/home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/base.py", line 253, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [18:05:40] src/operator/contrib/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node  at 1-th input: expected float32, got int32
Stack trace:
  [bt] (0) /home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x4b04cb) [0x7f0bee4794cb]
  [bt] (1) /home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x556b53) [0x7f0bee51fb53]
  [bt] (2) /home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x557480) [0x7f0bee520480]
  [bt] (3) /home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x557757) [0x7f0bee520757]
  [bt] (4) /home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/libmxnet.so(mxnet::imperative::SetShapeType(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, mxnet::DispatchMode*)+0xf68) [0x7f0bf062b5e8]
  [bt] (5) /home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/libmxnet.so(mxnet::Imperative::Invoke(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&)+0x1db) [0x7f0bf0635a0b]
  [bt] (6) /home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2565409) [0x7f0bf052e409]
  [bt] (7) /home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/site-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x6f) [0x7f0bf052e9ff]
  [bt] (8) /home/ec2-user/anaconda3/envs/mtdnn/lib/python3.7/lib-dynload/../../libffi.so.6(ffi_call_unix64+0x4c) [0x7f0c2d16fec0]

Surprisingly, I do not get this error and everything works fine when I change these lines in the BERTSentenceTransform class as follows:

return np.array(input_ids, dtype='int32'), np.array(valid_length, dtype='float32'),\
            np.array(segment_ids, dtype='int32')

I am at my wit's end trying to debug this. Please help me. Also, if you want more information about my code, I can provide it.

Thanks

leezu (Contributor) commented Nov 20, 2019

It's due to MXNet's requirement that all inputs to an operator have the same dtype. The requirement will be relaxed as part of apache/mxnet#14253 (matching numpy's behavior).
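
A minimal repro of that dtype check (illustrative, not taken from your code):

import mxnet as mx

positions = mx.nd.arange(4).reshape((1, -1))                 # float32 by default
valid_length = mx.nd.array([2], dtype='int32').reshape((-1, 1))

# broadcast_lesser requires both inputs to share a dtype, so this line
# raises the same "expected float32, got int32" MXNetError as above:
#   mx.nd.broadcast_lesser(positions, valid_length)

# Casting one side resolves it:
mask = mx.nd.broadcast_lesser(positions, valid_length.astype('float32'))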

Unfortunately there seems to be an inconsistency in GluonNLP: transformer_encoder(inputs, states, valid_length) expects valid_length to have float dtype (the same dtype as inputs), but the transform function returns integer dtype. I'll take a look at fixing it, and at why our tests don't catch this problem.

leezu added the bug label Nov 20, 2019
leezu changed the title from "Very peculiar error related to dtype of valid_length while using bert model" to "TransformerEncoder expects valid_length float dtype, but BERTSentenceTransform returns int dtype" Nov 20, 2019
leezu (Contributor) commented Nov 20, 2019

The example you cite casts the valid_length dtype:

valid_length = valid_length.as_in_context(ctx).astype('float32')

For now you should do the same.
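
In the data loading loop that would look something like this (a sketch with illustrative names, assuming a DataLoader that yields the three outputs of the transform):

for token_ids, valid_length, segment_ids in dataloader:
    token_ids = token_ids.as_in_context(ctx)
    segment_ids = segment_ids.as_in_context(ctx)
    # The encoder compares valid_length against a float32 arange internally,
    # so cast it to float32 before the forward pass.
    valid_length = valid_length.as_in_context(ctx).astype('float32')
    seq_encoding = bert(token_ids, segment_ids, valid_length)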

zeeshansayyed (Author) commented

@leezu Yes, you are right. The example does cast it. I somehow missed that. Thanks for your help. Should I keep this issue open?

leezu (Contributor) commented Nov 21, 2019

I'm not aware of an immediate solution to the problem that wouldn't risk breaking other users' code. But it's still an inconsistency / bug that should be fixed. Maybe we can fix it when migrating to MXNet's new numpy-compatible interface, so let's leave the issue open.
