From 53504a05e64d5743542eb7f2cb91a2c5c0d62fc3 Mon Sep 17 00:00:00 2001 From: baoqiwen Date: Mon, 7 Apr 2025 19:20:02 +0800 Subject: [PATCH] Fix the bug of misaligned address in vectorization. --- paddle/cinn/optim/vectorize_for_trans.cc | 105 ++++++++++++++--------- 1 file changed, 63 insertions(+), 42 deletions(-) diff --git a/paddle/cinn/optim/vectorize_for_trans.cc b/paddle/cinn/optim/vectorize_for_trans.cc index e63d389ca5ab4e..33319f3250e7c9 100644 --- a/paddle/cinn/optim/vectorize_for_trans.cc +++ b/paddle/cinn/optim/vectorize_for_trans.cc @@ -141,18 +141,40 @@ Expr GetOriginOffsetWithVectorizeAxis(Expr offset, Var var_iter) { return origin_offset; } -bool CheckoutTensorAddrLegalCastToVectorize(const std::vector &shapes, - const int vectorize_factor) { - int64_t nums = 1; - for (auto &size : shapes) { - auto const_val = size.As(); +bool CheckTensorAddrLegalCastToVectorize(const std::vector &indices, + const std::vector &shapes, + const int vectorize_factor) { + auto flattened_value = 1; + for (int i = 0; i < indices.size(); ++i) { + auto const_val = shapes[i].As(); PADDLE_ENFORCE_NOT_NULL(const_val, ::common::errors::InvalidArgument( "vectorize tiling only support static shape")); - nums *= const_val->value; + ir::Expr index = indices[i]; + index = optim::ArithSimplify(index); + if (index.is_constant() && index.get_constant() == 0) { + // If the index is zero (indicating broadcast behavior), reset + // flattened_value to 1. + flattened_value = 1; + } else { + flattened_value *= const_val->value; + } } - if (nums % vectorize_factor == 0) return true; - return false; + return flattened_value % vectorize_factor == 0; +} + +bool CheckTensorOffsetZero(const Expr &offset, const Var &iter_var) { + Expr only_vectorize_axis_offset = + UpdateOffsetOnlyContainsVectorizeAxis(offset, iter_var); + Expr origin_offset = + GetOriginOffsetWithVectorizeAxis(only_vectorize_axis_offset, iter_var); + Expr compare = CalculateOffsetWithVectorizeAxis( + only_vectorize_axis_offset, origin_offset, iter_var, 1); + auto const_val = compare.As(); + if (!const_val || const_val->value != 0) { + return false; + } + return true; } class ForOpWithMultiScheduleBlockSupportVectorize @@ -331,7 +353,7 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator { * { * vectorize[4] for (v1, 0, 4) * { - * float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4] + * float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4] * } * } * } @@ -345,14 +367,39 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator { * { * vectorize[4] for (v1, 0, 4) * { - * float a[i, j, v1] = select(i < 2, float b[i, j, v1], float c[i - 2, + * float a[i, j, v1] = select(i < 2, float b[i, j, v1], float c[i - 2, * j, v1]) * } * } * } * c[i - 2, j, v1] when i = 0, j = 0, v1 = 0, offset = -128 * - * Situation 3 . Tensor offset with vectorize axis is continuous + * Situation 3. Do not handle the scenario where there is a % b in the + * computation of the index, but b % factor != 0. serial for (i, 0, 4) + * { + * serial for (j, 0, 16) + * { + * vectorize[4] for (v1, 0, 4) + * { + * float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) % 3] + * } + * } + * } + * + * misaligned address + * + * Situation 4. Do not handle the offset 0. + * { + * serial for (j, 0, 16) + * { + * vectorize[4] for (v1, 0, 4) + * { + * float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4] + * } + * } + * } + * + * misaligned address */ bool TensorCanVectorize(ir::LoadStoreAddrMnger *node, const std::vector &indices) { @@ -378,40 +425,14 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator { return false; } - Expr only_vectorize_axis_offset = - UpdateOffsetOnlyContainsVectorizeAxis(offset, iter_var_); - - Expr origin_offset = - GetOriginOffsetWithVectorizeAxis(only_vectorize_axis_offset, iter_var_); - bool offset_is_zero = true; - bool tensor_is_continuous = true; - for (int i = 1; i < factor_; i++) { - Expr compare = CalculateOffsetWithVectorizeAxis( - only_vectorize_axis_offset, origin_offset, iter_var_, i); - auto const_val = compare.As(); - if (!const_val) return false; - if (const_val->value != 0) { - offset_is_zero = false; - } - - if (const_val->value != i) { - tensor_is_continuous = false; - break; - } - } - - if (offset_is_zero) { - return false; - } - - if (!tensor_is_continuous) { - vectorize_tensors_.clear(); - scalar_tensor_without_vectorize_axis_.clear(); - schedule_block_can_vectorize_ = false; + // situation 3. Do not handle the scenario where there is a % b in the + // computation of the index, but b % factor != 0. + if (!CheckTensorAddrLegalCastToVectorize(indices, tensor->shape, factor_)) { return false; } - if (!CheckoutTensorAddrLegalCastToVectorize(tensor->shape, factor_)) { + // situation 4. Do not handle the offset 0. + if (CheckTensorOffsetZero(offset, iter_var_)) { return false; }