Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 63 additions & 42 deletions paddle/cinn/optim/vectorize_for_trans.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,40 @@ Expr GetOriginOffsetWithVectorizeAxis(Expr offset, Var var_iter) {
return origin_offset;
}

bool CheckoutTensorAddrLegalCastToVectorize(const std::vector<ir::Expr> &shapes,
const int vectorize_factor) {
int64_t nums = 1;
for (auto &size : shapes) {
auto const_val = size.As<ir::IntImm>();
bool CheckTensorAddrLegalCastToVectorize(const std::vector<ir::Expr> &indices,
const std::vector<ir::Expr> &shapes,
const int vectorize_factor) {
auto flattened_value = 1;
for (int i = 0; i < indices.size(); ++i) {
auto const_val = shapes[i].As<ir::IntImm>();
PADDLE_ENFORCE_NOT_NULL(const_val,
::common::errors::InvalidArgument(
"vectorize tiling only support static shape"));
nums *= const_val->value;
ir::Expr index = indices[i];
index = optim::ArithSimplify(index);
if (index.is_constant() && index.get_constant() == 0) {
// If the index is zero (indicating broadcast behavior), reset
// flattened_value to 1.
flattened_value = 1;
} else {
flattened_value *= const_val->value;
}
}
if (nums % vectorize_factor == 0) return true;
return false;
return flattened_value % vectorize_factor == 0;
}

bool CheckTensorOffsetZero(const Expr &offset, const Var &iter_var) {
Expr only_vectorize_axis_offset =
UpdateOffsetOnlyContainsVectorizeAxis(offset, iter_var);
Expr origin_offset =
GetOriginOffsetWithVectorizeAxis(only_vectorize_axis_offset, iter_var);
Expr compare = CalculateOffsetWithVectorizeAxis(
only_vectorize_axis_offset, origin_offset, iter_var, 1);
auto const_val = compare.As<ir::IntImm>();
if (!const_val || const_val->value != 0) {
return false;
}
return true;
}

class ForOpWithMultiScheduleBlockSupportVectorize
Expand Down Expand Up @@ -331,7 +353,7 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
* {
* vectorize[4] for (v1, 0, 4)
* {
* float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4]
* float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4]
* }
* }
* }
Expand All @@ -345,14 +367,39 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
* {
* vectorize[4] for (v1, 0, 4)
* {
* float a[i, j, v1] = select(i < 2, float b[i, j, v1], float c[i - 2,
* float a[i, j, v1] = select(i < 2, float b[i, j, v1], float c[i - 2,
* j, v1])
* }
* }
* }
* c[i - 2, j, v1] when i = 0, j = 0, v1 = 0, offset = -128
*
* Situation 3 . Tensor offset with vectorize axis is continuous
* Situation 3. Do not handle the scenario where there is a % b in the
* computation of the index, but b % factor != 0. serial for (i, 0, 4)
* {
* serial for (j, 0, 16)
* {
* vectorize[4] for (v1, 0, 4)
* {
* float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) % 3]
* }
* }
* }
*
* misaligned address
*
* Situation 4. Do not handle the offset 0.
* {
* serial for (j, 0, 16)
* {
* vectorize[4] for (v1, 0, 4)
* {
* float a[i, j, v1] = float b[(i * 64 + j * 4 + v1) / 4]
* }
* }
* }
*
* misaligned address
*/
bool TensorCanVectorize(ir::LoadStoreAddrMnger *node,
const std::vector<ir::Expr> &indices) {
Expand All @@ -378,40 +425,14 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
return false;
}

Expr only_vectorize_axis_offset =
UpdateOffsetOnlyContainsVectorizeAxis(offset, iter_var_);

Expr origin_offset =
GetOriginOffsetWithVectorizeAxis(only_vectorize_axis_offset, iter_var_);
bool offset_is_zero = true;
bool tensor_is_continuous = true;
for (int i = 1; i < factor_; i++) {
Expr compare = CalculateOffsetWithVectorizeAxis(
only_vectorize_axis_offset, origin_offset, iter_var_, i);
auto const_val = compare.As<ir::IntImm>();
if (!const_val) return false;
if (const_val->value != 0) {
offset_is_zero = false;
}

if (const_val->value != i) {
tensor_is_continuous = false;
break;
}
}

if (offset_is_zero) {
return false;
}

if (!tensor_is_continuous) {
vectorize_tensors_.clear();
scalar_tensor_without_vectorize_axis_.clear();
schedule_block_can_vectorize_ = false;
// situation 3. Do not handle the scenario where there is a % b in the
// computation of the index, but b % factor != 0.
if (!CheckTensorAddrLegalCastToVectorize(indices, tensor->shape, factor_)) {
return false;
}

if (!CheckoutTensorAddrLegalCastToVectorize(tensor->shape, factor_)) {
// situation 4. Do not handle the offset 0.
if (CheckTensorOffsetZero(offset, iter_var_)) {
return false;
}

Expand Down