@@ -141,18 +141,40 @@ Expr GetOriginOffsetWithVectorizeAxis(Expr offset, Var var_iter) {
141141 return origin_offset;
142142}
143143
144- bool CheckoutTensorAddrLegalCastToVectorize (const std::vector<ir::Expr> &shapes,
145- const int vectorize_factor) {
146- int64_t nums = 1 ;
147- for (auto &size : shapes) {
148- auto const_val = size.As <ir::IntImm>();
144+ bool CheckTensorAddrLegalCastToVectorize (const std::vector<ir::Expr> &indices,
145+ const std::vector<ir::Expr> &shapes,
146+ const int vectorize_factor) {
147+ auto flattened_value = 1 ;
148+ for (int i = 0 ; i < indices.size (); ++i) {
149+ auto const_val = shapes[i].As <ir::IntImm>();
149150 PADDLE_ENFORCE_NOT_NULL (const_val,
150151 ::common::errors::InvalidArgument (
151152 " vectorize tiling only support static shape" ));
152- nums *= const_val->value ;
153+ ir::Expr index = indices[i];
154+ index = optim::ArithSimplify (index);
155+ if (index.is_constant () && index.get_constant () == 0 ) {
156+ // If the index is zero (indicating broadcast behavior), reset
157+ // flattened_value to 1.
158+ flattened_value = 1 ;
159+ } else {
160+ flattened_value *= const_val->value ;
161+ }
153162 }
154- if (nums % vectorize_factor == 0 ) return true ;
155- return false ;
163+ return flattened_value % vectorize_factor == 0 ;
164+ }
165+
166+ bool CheckTensorOffsetZero (const Expr &offset, const Var &iter_var) {
167+ Expr only_vectorize_axis_offset =
168+ UpdateOffsetOnlyContainsVectorizeAxis (offset, iter_var);
169+ Expr origin_offset =
170+ GetOriginOffsetWithVectorizeAxis (only_vectorize_axis_offset, iter_var);
171+ Expr compare = CalculateOffsetWithVectorizeAxis (
172+ only_vectorize_axis_offset, origin_offset, iter_var, 1 );
173+ auto const_val = compare.As <ir::IntImm>();
174+ if (!const_val || const_val->value != 0 ) {
175+ return false ;
176+ }
177+ return true ;
156178}
157179
158180class ForOpWithMultiScheduleBlockSupportVectorize
@@ -331,7 +353,7 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
331353 * {
332354 * vectorize[4] for (v1, 0, 4)
333355 * {
334- * float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4]
356+ * float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4]
335357 * }
336358 * }
337359 * }
@@ -345,14 +367,39 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
345367 * {
346368 * vectorize[4] for (v1, 0, 4)
347369 * {
348- * float a[i, j, v1] = select(i < 2, float b[i, j, v1], float c[i - 2,
370+ * float a[i, j, v1] = select(i < 2, float b[i, j, v1], float c[i - 2,
349371 * j, v1])
350372 * }
351373 * }
352374 * }
353375 * c[i - 2, j, v1] when i = 0, j = 0, v1 = 0, offset = -128
354376 *
355- * Situation 3 . Tensor offset with vectorize axis is continuous
377+ * Situation 3. Do not handle the scenario where there is a % b in the
378+ * computation of the index, but b % factor != 0. serial for (i, 0, 4)
379+ * {
380+ * serial for (j, 0, 16)
381+ * {
382+ * vectorize[4] for (v1, 0, 4)
383+ * {
384+ * float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) % 3]
385+ * }
386+ * }
387+ * }
388+ *
389+ * misaligned address
390+ *
391+ * Situation 4. Do not handle a zero offset along the vectorize axis.
392+ * {
393+ * serial for (j, 0, 16)
394+ * {
395+ * vectorize[4] for (v1, 0, 4)
396+ * {
397+ * float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4]
398+ * }
399+ * }
400+ * }
401+ *
402+ * misaligned address
356403 */
357404 bool TensorCanVectorize (ir::LoadStoreAddrMnger *node,
358405 const std::vector<ir::Expr> &indices) {
@@ -378,40 +425,14 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
378425 return false ;
379426 }
380427
381- Expr only_vectorize_axis_offset =
382- UpdateOffsetOnlyContainsVectorizeAxis (offset, iter_var_);
383-
384- Expr origin_offset =
385- GetOriginOffsetWithVectorizeAxis (only_vectorize_axis_offset, iter_var_);
386- bool offset_is_zero = true ;
387- bool tensor_is_continuous = true ;
388- for (int i = 1 ; i < factor_; i++) {
389- Expr compare = CalculateOffsetWithVectorizeAxis (
390- only_vectorize_axis_offset, origin_offset, iter_var_, i);
391- auto const_val = compare.As <ir::IntImm>();
392- if (!const_val) return false ;
393- if (const_val->value != 0 ) {
394- offset_is_zero = false ;
395- }
396-
397- if (const_val->value != i) {
398- tensor_is_continuous = false ;
399- break ;
400- }
401- }
402-
403- if (offset_is_zero) {
404- return false ;
405- }
406-
407- if (!tensor_is_continuous) {
408- vectorize_tensors_.clear ();
409- scalar_tensor_without_vectorize_axis_.clear ();
410- schedule_block_can_vectorize_ = false ;
428+ // situation 3. Do not handle the scenario where there is a % b in the
429+ // computation of the index, but b % factor != 0.
430+ if (!CheckTensorAddrLegalCastToVectorize (indices, tensor->shape , factor_)) {
411431 return false ;
412432 }
413433
414- if (!CheckoutTensorAddrLegalCastToVectorize (tensor->shape , factor_)) {
434+ // situation 4. Do not handle the offset 0.
435+ if (CheckTensorOffsetZero (offset, iter_var_)) {
415436 return false ;
416437 }
417438
0 commit comments