@@ -105,7 +105,8 @@ std::vector<bool> GetCascadableAxes(const Part& part) {
105105 return cascadable_axes;
106106}
107107
108- std::vector<StripeConfig> GenerateOutputStripeConfigs (const Part& part, int stripe_factors) {
108+ std::vector<StripeConfig> GenerateOutputStripeConfigs (const Part& part, int stripe_factors,
109+ bool enable_striping) {
109110 // If stripe_factors is <= 0, then we won't produce any StripeConfigs
110111 if (stripe_factors <= 0 ) {
111112 return std::vector<StripeConfig>();
@@ -134,7 +135,7 @@ std::vector<StripeConfig> GenerateOutputStripeConfigs(const Part& part, int stri
134135 auto axis = output_shape[i];
135136 auto axis_align = part->GetStripeAlignHint ()[i];
136137 std::set<int > axis_splits; // Note this is a set to remove duplicate splits
137- if (!cascadable_axes[i]) {
138+ if (!cascadable_axes[i] || (!enable_striping) ) {
138139 axis_splits.insert (axis);
139140 } else {
140141 for (float factor : factors) {
@@ -436,7 +437,7 @@ std::unordered_map<std::vector<Part>, std::vector<Plan>> GenerateGraphPlans(
436437 // output of a Plan. The number generated is a function of stripe_factors and the number of
437438 // cascadable dimensions in the Part.
438439 std::vector<StripeConfig> stripe_configs =
439- GenerateOutputStripeConfigs (part, options->stripe_factors );
440+ GenerateOutputStripeConfigs (part, options->stripe_factors , options-> enable_striping );
440441 // Check to see if the output Tensor is part of any existing open Plans
441442 if (stripe_configs_by_tensor.find (part->GetOutputTensor ()) != stripe_configs_by_tensor.end ()) {
442443 // If there are other open Plans which have this Part's output Tensor as an input, then
@@ -514,11 +515,12 @@ std::unordered_map<std::vector<Part>, std::vector<Plan>> GenerateGraphPlans(
514515}
515516
516517TVM_REGISTER_GLOBAL (" contrib.ethosu.cascader.GenerateOutputStripeConfigs" )
517- .set_body_typed([](Part part, int stripe_factors) {
518+ .set_body_typed([](Part part, int stripe_factors, bool enable_striping ) {
518519 if (stripe_factors < 0 ) {
519520 return Array<StripeConfig>();
520521 }
521- return Array<StripeConfig>(GenerateOutputStripeConfigs (part, stripe_factors));
522+ return Array<StripeConfig>(
523+ GenerateOutputStripeConfigs (part, stripe_factors, enable_striping));
522524 });
523525
524526TVM_REGISTER_GLOBAL (" contrib.ethosu.cascader.GenerateSinglePlans" )
0 commit comments