
[Bug][Relay] crash at shape checking in SimpleRNN at Keras frontend  #14868

@jikechao

Description


For the SimpleRNN model built in the reproduction script below, TVM converts it to the following Relay IR:

def @main(%input_1: Tensor[(2, 2, 2), float32], %v_param_2: Tensor[(2, 2), float32], %v_param_4: Tensor[(2), float32], %v_param_1: Tensor[(1, 2), float32], %v_param_3: Tensor[(2, 2), float32]) {
  %0 = nn.batch_flatten(%input_1);
  %1 = nn.dense(%0, %v_param_2, units=2);
  %2 = nn.batch_flatten(%v_param_1);
  %3 = nn.bias_add(%1, %v_param_4);
  %4 = nn.dense(%2, %v_param_3, units=2);
  %5 = add(%3, %4);
  %6 = tanh(%5);
  reshape(%6, newshape=[1, 2])
}

Compilation then crashes when building for the LLVM target.
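The mismatch is visible in the Relay IR above: with batch size 2, the tanh output `%6` has shape (2, 2), i.e. 4 elements, but the frontend emits `reshape(%6, newshape=[1, 2])`, which only covers 2. A minimal plain-Python sketch of the element-count check that TVM's `ReshapeRel` performs (the helper name is illustrative, not TVM API):

```python
from functools import reduce
from operator import mul

def reshape_compatible(data_shape, newshape):
    # A reshape is legal only if both shapes cover the same number of elements;
    # this mirrors the "oshape_sum == data_shape_sum" check in the error above.
    data_sum = reduce(mul, data_shape, 1)
    new_sum = reduce(mul, newshape, 1)
    return new_sum == data_sum

# RNN output for batch size 2 and units=2, reshaped as the frontend requests:
print(reshape_compatible((2, 2), (1, 2)))  # 4 vs. 2 elements -> False
print(reshape_compatible((1, 2), (1, 2)))  # with batch size 1 -> True
```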

Expected behavior

The model should compile and run successfully.

Actual behavior

Traceback (most recent call last):
  File "test.py", line 18, in <module>
    model = relay.build_module.create_executor("vm", mod, tvm.cpu(0), 'llvm', params).evaluate()
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/interpreter.py", line 171, in evaluate
    return self._make_executor()
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 219, in _make_executor
    self.executable = compile(self.mod, self.target)
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 67, in compile
    compiler.lower(mod, target, target_host)
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 126, in lower
    self._lower(mod, raw_targets)
  File "/workplace/software/tvm/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  15: TVMFuncCall
  14: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  13: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&)
  12: tvm::relay::vm::VMCompiler::LowerImpl(tvm::IRModule)
  11: tvm::relay::vm::VMCompiler::OptimizeModuleImpl(tvm::IRModule)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  1: tvm::relay::TypeSolver::Solve()
  0: _ZN3tvm7runtime6detail
  19: TVMFuncCall
  18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  17: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&)
  16: tvm::relay::vm::VMCompiler::LowerImpl(tvm::IRModule)
  15: tvm::relay::vm::VMCompiler::OptimizeModuleImpl(tvm::IRModule)
  14: tvm::transform::Pass::operator()(tvm::IRModule) const
  13: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  12: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  11: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  10: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  6: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  5: tvm::relay::TypeSolver::Solve()
  4: tvm::TypedEnvFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::operator()(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&) const
  3: _ZN3tvm7runtime13Pac
  2: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  1: tvm::relay::ReshapeRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  0: _ZN3tvm7runtime6detail
  File "/workplace/software/tvm/tvm/src/relay/analysis/type_solver.cc", line 643
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [07:01:21] /workplace/software/tvm/tvm/src/relay/op/tensor/transform.cc:794: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

  Check failed: oshape_sum == data_shape_sum (2 vs. 4) : Input tensor shape(2,2) and reshaped shape(1,2) are not compatible!

Steps to reproduce

import tvm
import tvm.relay as relay
from tensorflow import keras
from tensorflow.keras import layers, models

input_shape = (2, 2, 2)
x = layers.Input(shape=input_shape[1:], dtype='float32')

layer = keras.layers.SimpleRNN(units=2)
layer.set_weights(layer.get_weights())

y = layer(x)
model = models.Model(x, y)
model.summary()
mod, params = relay.frontend.from_keras(model, {'input_1': input_shape})
print(mod)
with tvm.transform.PassContext(opt_level=3):
    model = relay.build_module.create_executor("vm", mod, tvm.cpu(0), 'llvm', params).evaluate()
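The error disappears if the batch dimension is 1, which suggests the Keras frontend's SimpleRNN conversion (in `python/tvm/relay/frontend/keras.py`) hardcodes a batch of 1 in the final reshape instead of deriving it from the input. A hypothetical sketch of the shape the reshape arguably should use (`rnn_output_newshape` is an illustrative helper, not TVM code):

```python
def rnn_output_newshape(input_shape, units):
    # Keep the real batch dimension from the model input (batch, time, features)
    # rather than a hardcoded 1, so the element counts match.
    batch = input_shape[0]
    return (batch, units)

print(rnn_output_newshape((2, 2, 2), 2))  # (2, 2), matching the tanh output above
```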

Triage

  • frontend:keras

cc @Hzfengsy @echuraev
Could you help me check whether this crash is triggered by a bug in TVM?

cc @shingjan

Labels: needs-triage, type: bug