@@ -33,23 +33,28 @@ namespace ethosu {
3333namespace cascader {
3434
3535void BlockConfigNode::VisitAttrs (AttrVisitor* v) {
36- Array<Integer> tmp_arr = make_array (output_shape_);
36+ Array<Integer> tmp_arr = make_array (input_shape_);
37+ v->Visit (" _input_shape" , &tmp_arr);
38+ tmp_arr = make_array (output_shape_);
3739 v->Visit (" _output_shape" , &tmp_arr);
3840}
3941
40- BlockConfig::BlockConfig (const std::vector<int >& output_shape, int compute_cycles ,
41- int output_cycles) {
42+ BlockConfig::BlockConfig (const std::vector<int >& input_shape, const std::vector< int >& output_shape ,
43+ int compute_cycles, int output_cycles) {
4244 auto n = make_object<BlockConfigNode>();
45+ n->input_shape_ = std::move (input_shape);
4346 n->output_shape_ = std::move (output_shape);
4447 n->compute_cycles_ = compute_cycles;
4548 n->output_cycles_ = output_cycles;
4649 data_ = std::move (n);
4750}
4851
4952TVM_REGISTER_GLOBAL (" contrib.ethosu.cascader.BlockConfig" )
50- .set_body_typed([](Array<Integer> output_shape, int compute_cycles, int output_cycles) {
53+ .set_body_typed([](Array<Integer> input_shape, Array<Integer> output_shape, int compute_cycles,
54+ int output_cycles) {
55+ std::vector<int > vinput_shape = make_vector<int , Integer>(input_shape);
5156 std::vector<int > voutput_shape = make_vector<int , Integer>(output_shape);
52- return BlockConfig (voutput_shape, compute_cycles, output_cycles);
57+ return BlockConfig (vinput_shape, voutput_shape, compute_cycles, output_cycles);
5358 });
5459
5560TVM_REGISTER_NODE_TYPE (BlockConfigNode);
0 commit comments