diff --git a/src/op/copy.cc b/src/op/copy.cc index 2c01db367..31a0c0092 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -819,13 +819,13 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, PrimExpr local_indices_flattened = local_tensor.OffsetOf(local_indices_transformed).back(); if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) && - IndiceCanVectorize(local_indices_flattened, col_var->var, - col_var->dom->extent, 2, analyzer)) { + IndicesCanVectorize(local_indices_flattened, col_var->var, + col_var->dom->extent, 2, analyzer)) { is_transposed = false; } else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans, local_layout_thread_map) && - IndiceCanVectorize(local_indices_flattened, row_var->var, - row_var->dom->extent, 2, analyzer)) { + IndicesCanVectorize(local_indices_flattened, row_var->var, + row_var->dom->extent, 2, analyzer)) { is_transposed = true; } else { // TMA ldmatrix/stmatrix cannot support non-8x8 layout, will be fallback to @@ -841,8 +841,8 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, return LowerNormalCopy(T, analyzer); } PrimExpr flattened_indice = shared_tensor.OffsetOf(shared_indices).back(); - if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var, - loop_vars.back()->dom->extent, 8, analyzer)) { + if (!IndicesCanVectorize(flattened_indice, loop_vars.back()->var, + loop_vars.back()->dom->extent, 8, analyzer)) { // TMA ldmatrix/stmatrix cannot support non-16 bytes continuous layout, will // be fallback to normal copy return LowerNormalCopy(T, analyzer); diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 29424a6d8..222a6e79a 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -279,8 +279,8 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { << ")" << "\n"; } // vector_size may be greater than local/fragment buffers' vector_size. - // In such case, we need to re-validate if the indices are invariant - // at the new vector_size boundary. If not invariant, take GCD. + // In such case, we need to re-validate if the indices are vectorizable + // at the new vector_size boundary. If not, take GCD. for (const auto &info : local_fragment_buffers) { if (vector_size_ > info.vector_size && !info.indices.empty()) { // Compute elem_offset from indices and strides @@ -289,8 +289,9 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { for (size_t i = 0; i < info.indices.size(); ++i) { elem_offset += info.indices[i] * strides[i]; } - if (!IsExprInvariantInVectorBoundary( - elem_offset, inner_for_->loop_var, vector_size_, analyzer_)) { + if (!IndicesCanVectorize(elem_offset, inner_for_->loop_var, + inner_for_->extent, vector_size_, + analyzer_)) { // Not invariant at this vector_size, need to take GCD int old_vector_size = vector_size_; vector_size_ = arith::ZeroAwareGCD(vector_size_, info.vector_size); @@ -578,9 +579,9 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { } // 4. Try to find max vectorize size for this buffer while (buffer_vec_size > 1 && - !IndiceCanVectorize(elem_offset, inner_for_->loop_var, - inner_for_->extent, buffer_vec_size, - analyzer_)) { + !IndicesCanVectorize(elem_offset, inner_for_->loop_var, + inner_for_->extent, buffer_vec_size, + analyzer_)) { buffer_vec_size /= 2; } return buffer_vec_size; @@ -721,9 +722,10 @@ bool IsExprInvariantInVectorBoundary(const PrimExpr &expr, Var var, return false; } -bool IndiceCanVectorize(const PrimExpr &expr, Var var, - const PrimExpr &iter_var_size, - int target_vectorized_size, arith::Analyzer *analyzer) { +bool IndicesCanVectorize(const PrimExpr &expr, Var var, + const PrimExpr &iter_var_size, + int target_vectorized_size, + arith::Analyzer *analyzer) { ICHECK(target_vectorized_size >= 1); if (target_vectorized_size == 1) return true; diff --git a/src/transform/loop_vectorize.h b/src/transform/loop_vectorize.h index 591f047e0..214d703f0 100644 --- a/src/transform/loop_vectorize.h +++ b/src/transform/loop_vectorize.h @@ -55,9 +55,9 @@ bool IsExprInvariantInVectorBoundary(const PrimExpr &expr, Var var, int target_vectorized_size, arith::Analyzer *analyzer); -bool IndiceCanVectorize(const PrimExpr &expr, Var var, - const PrimExpr &iter_var_size, - int target_vectorized_size, arith::Analyzer *analyzer); +bool IndicesCanVectorize(const PrimExpr &expr, Var var, + const PrimExpr &iter_var_size, + int target_vectorized_size, arith::Analyzer *analyzer); } // namespace tl } // namespace tvm