Error when converting nn.Transformer #133

Open
HaoKang-Timmy opened this issue Sep 26, 2021 · 0 comments


This is how I export the ONNX model from PyTorch:

import torch
import torch.nn as nn


class Former(nn.Module):
    def __init__(self):
        super().__init__()
        # Project the 10 input features to nn.Transformer's default d_model of 512
        self.linear1 = nn.Linear(10, 512)
        self.linear2 = nn.Linear(10, 512)
        self.transformer = nn.Transformer()

    def forward(self, input):
        input1 = self.linear1(input)  # encoder input (src)
        input2 = self.linear2(input)  # decoder input (tgt)
        output = self.transformer(input1, input2)
        return output


# Shape is (sequence length, batch size, features): nn.Transformer is not batch-first by default
src = torch.rand(1, 1, 10)
model = Former()
torch.onnx.export(model, src, "transformer.onnx", verbose=True, input_names=["input"], opset_version=11)

Then I tried to convert it to Keras with onnx2keras:

from onnx2keras import onnx_to_keras
import keras
import onnx

onnx_model = onnx.load('transformer.onnx')  # the file written by torch.onnx.export above
k_model = onnx_to_keras(onnx_model, ['input'])  # 'input' matches input_names in the export

keras.models.save_model(k_model, './kerasModel.h5', overwrite=True, include_optimizer=True)

The conversion fails with this error:

Traceback (most recent call last):
  File "/home/kh/Documents/onnx2keras1.py", line 7, in <module>
    k_model = onnx_to_keras(onnx_model, ['input'])
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/onnx2keras/converter.py", line 175, in onnx_to_keras
    AVAILABLE_CONVERTERS[node_type](
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/onnx2keras/elementwise_layers.py", line 60, in convert_elementwise_add
    input_0 = ensure_tf_type(layers[node.input[0]], layers[list(layers)[0]], name="%s_const1" % keras_name)
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/onnx2keras/utils.py", line 45, in ensure_tf_type
    return lambda_layer(fake_input_layer)
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/keras/engine/base_layer.py", line 976, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1114, in _functional_construction_call
    outputs = self._keras_tensor_symbolic_call(
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/keras/engine/base_layer.py", line 848, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/keras/engine/base_layer.py", line 888, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/keras/layers/core.py", line 903, in call
    result = self.function(inputs, **kwargs)
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/onnx2keras/utils.py", line 42, in target_layer
    return tf.constant(inp, dtype=inp.dtype)
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 271, in constant
    return _constant_impl(value, dtype, shape, name, verify_shape=False,
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 288, in _constant_impl
    tensor_util.make_tensor_proto(
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/framework/tensor_util.py", line 564, in make_tensor_proto
    append_fn(tensor_proto, proto_values)
  File "tensorflow/python/framework/fast_tensor_util.pyx", line 127, in tensorflow.python.framework.fast_tensor_util.AppendObjectArrayToTensorProto
  File "/home/kh/anaconda3/envs/3.8/lib/python3.8/site-packages/tensorflow/python/util/compat.py", line 86, in as_bytes
    raise TypeError('Expected binary or unicode string, got %r' %
TypeError: Expected binary or unicode string, got 1536
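The last frames show where the conversion breaks: ensure_tf_type calls tf.constant(inp, dtype=inp.dtype), and the failure inside AppendObjectArrayToTensorProto suggests that inp is a NumPy array with dtype=object. TensorFlow maps that dtype to its string type, so it tries to encode the integer 1536 as bytes and raises the TypeError. (1536 is 3 × 512, three times the default d_model of nn.Transformer, which would fit the packed query/key/value projection in nn.MultiheadAttention; that is only a guess, the traceback does not confirm it.) Below is a minimal sketch of a possible pre-conversion step, assuming the offending constant is an object array holding plain numbers. coerce_object_array is a hypothetical helper, not part of onnx2keras:

```python
import numpy as np


def coerce_object_array(value):
    """Hypothetical helper (not part of onnx2keras).

    Converts an object-dtype NumPy array of Python numbers into a numeric
    array, so that tf.constant(arr, dtype=arr.dtype) would no longer see
    dtype=object. Arrays with any other dtype are returned unchanged.
    """
    arr = np.asarray(value)
    if arr.dtype != object:
        return arr
    # Rebuilding from Python scalars lets NumPy infer an integer or float dtype
    return np.array(arr.tolist())


# An object array of the kind that appears to reach ensure_tf_type
bad = np.array([1536], dtype=object)
fixed = coerce_object_array(bad)
print(bad.dtype, fixed.dtype, fixed.tolist())
```

Where such a step would belong (for example in ensure_tf_type just before the tf.constant call, or in whichever converter produced the object array) is an assumption; this only illustrates why an object-dtype constant trips tf.constant.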