Skip to content

Commit 018ffdd

Browse files
authored
Fix the bug of misaligned address in vectorization (#72106)
1 parent 46c897e commit 018ffdd

File tree

1 file changed

+63
-42
lines changed

1 file changed

+63
-42
lines changed

paddle/cinn/optim/vectorize_for_trans.cc

Lines changed: 63 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,40 @@ Expr GetOriginOffsetWithVectorizeAxis(Expr offset, Var var_iter) {
141141
return origin_offset;
142142
}
143143

144-
bool CheckoutTensorAddrLegalCastToVectorize(const std::vector<ir::Expr> &shapes,
145-
const int vectorize_factor) {
146-
int64_t nums = 1;
147-
for (auto &size : shapes) {
148-
auto const_val = size.As<ir::IntImm>();
144+
bool CheckTensorAddrLegalCastToVectorize(const std::vector<ir::Expr> &indices,
145+
const std::vector<ir::Expr> &shapes,
146+
const int vectorize_factor) {
147+
auto flattened_value = 1;
148+
for (int i = 0; i < indices.size(); ++i) {
149+
auto const_val = shapes[i].As<ir::IntImm>();
149150
PADDLE_ENFORCE_NOT_NULL(const_val,
150151
::common::errors::InvalidArgument(
151152
"vectorize tiling only support static shape"));
152-
nums *= const_val->value;
153+
ir::Expr index = indices[i];
154+
index = optim::ArithSimplify(index);
155+
if (index.is_constant() && index.get_constant() == 0) {
156+
// If the index is zero (indicating broadcast behavior), reset
157+
// flattened_value to 1.
158+
flattened_value = 1;
159+
} else {
160+
flattened_value *= const_val->value;
161+
}
153162
}
154-
if (nums % vectorize_factor == 0) return true;
155-
return false;
163+
return flattened_value % vectorize_factor == 0;
164+
}
165+
166+
bool CheckTensorOffsetZero(const Expr &offset, const Var &iter_var) {
167+
Expr only_vectorize_axis_offset =
168+
UpdateOffsetOnlyContainsVectorizeAxis(offset, iter_var);
169+
Expr origin_offset =
170+
GetOriginOffsetWithVectorizeAxis(only_vectorize_axis_offset, iter_var);
171+
Expr compare = CalculateOffsetWithVectorizeAxis(
172+
only_vectorize_axis_offset, origin_offset, iter_var, 1);
173+
auto const_val = compare.As<ir::IntImm>();
174+
if (!const_val || const_val->value != 0) {
175+
return false;
176+
}
177+
return true;
156178
}
157179

158180
class ForOpWithMultiScheduleBlockSupportVectorize
@@ -331,7 +353,7 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
331353
* {
332354
* vectorize[4] for (v1, 0, 4)
333355
* {
334-
* float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4]
356+
* float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4]
335357
* }
336358
* }
337359
* }
@@ -345,14 +367,39 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
345367
* {
346368
* vectorize[4] for (v1, 0, 4)
347369
* {
348-
* float a[i, j, v1] = select(i < 2, float b[i, j, v1], float c[i - 2,
370+
* float a[i, j, v1] = select(i < 2, float b[i, j, v1], float c[i - 2,
349371
* j, v1])
350372
* }
351373
* }
352374
* }
353375
* c[i - 2, j, v1] when i = 0, j = 0, v1 = 0, offset = -128
354376
*
355-
* Situation 3 . Tensor offset with vectorize axis is continuous
377+
* Situation 3. Do not handle the scenario where there is a % b in the
378+
* computation of the index, but b % factor != 0. serial for (i, 0, 4)
379+
* {
380+
* serial for (j, 0, 16)
381+
* {
382+
* vectorize[4] for (v1, 0, 4)
383+
* {
384+
* float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) % 3]
385+
* }
386+
* }
387+
* }
388+
*
389+
* misaligned address
390+
*
391+
* Situation 4. Do not handle the offset 0.
392+
* {
393+
* serial for (j, 0, 16)
394+
* {
395+
* vectorize[4] for (v1, 0, 4)
396+
* {
397+
* float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4]
398+
* }
399+
* }
400+
* }
401+
*
402+
* misaligned address
356403
*/
357404
bool TensorCanVectorize(ir::LoadStoreAddrMnger *node,
358405
const std::vector<ir::Expr> &indices) {
@@ -378,40 +425,14 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
378425
return false;
379426
}
380427

381-
Expr only_vectorize_axis_offset =
382-
UpdateOffsetOnlyContainsVectorizeAxis(offset, iter_var_);
383-
384-
Expr origin_offset =
385-
GetOriginOffsetWithVectorizeAxis(only_vectorize_axis_offset, iter_var_);
386-
bool offset_is_zero = true;
387-
bool tensor_is_continuous = true;
388-
for (int i = 1; i < factor_; i++) {
389-
Expr compare = CalculateOffsetWithVectorizeAxis(
390-
only_vectorize_axis_offset, origin_offset, iter_var_, i);
391-
auto const_val = compare.As<ir::IntImm>();
392-
if (!const_val) return false;
393-
if (const_val->value != 0) {
394-
offset_is_zero = false;
395-
}
396-
397-
if (const_val->value != i) {
398-
tensor_is_continuous = false;
399-
break;
400-
}
401-
}
402-
403-
if (offset_is_zero) {
404-
return false;
405-
}
406-
407-
if (!tensor_is_continuous) {
408-
vectorize_tensors_.clear();
409-
scalar_tensor_without_vectorize_axis_.clear();
410-
schedule_block_can_vectorize_ = false;
428+
// situation 3. Do not handle the scenario where there is a % b in the
429+
// computation of the index, but b % factor != 0.
430+
if (!CheckTensorAddrLegalCastToVectorize(indices, tensor->shape, factor_)) {
411431
return false;
412432
}
413433

414-
if (!CheckoutTensorAddrLegalCastToVectorize(tensor->shape, factor_)) {
434+
// situation 4. Do not handle the offset 0.
435+
if (CheckTensorOffsetZero(offset, iter_var_)) {
415436
return false;
416437
}
417438

0 commit comments

Comments
 (0)