Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -819,13 +819,13 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
PrimExpr local_indices_flattened =
local_tensor.OffsetOf(local_indices_transformed).back();
if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, col_var->var,
col_var->dom->extent, 2, analyzer)) {
IndicesCanVectorize(local_indices_flattened, col_var->var,
col_var->dom->extent, 2, analyzer)) {
is_transposed = false;
} else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans,
local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, row_var->var,
row_var->dom->extent, 2, analyzer)) {
IndicesCanVectorize(local_indices_flattened, row_var->var,
row_var->dom->extent, 2, analyzer)) {
is_transposed = true;
} else {
// TMA ldmatrix/stmatrix cannot support non-8x8 layout, will fall back to
Expand All @@ -841,8 +841,8 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
return LowerNormalCopy(T, analyzer);
}
PrimExpr flattened_indice = shared_tensor.OffsetOf(shared_indices).back();
if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
loop_vars.back()->dom->extent, 8, analyzer)) {
if (!IndicesCanVectorize(flattened_indice, loop_vars.back()->var,
loop_vars.back()->dom->extent, 8, analyzer)) {
// TMA ldmatrix/stmatrix cannot support a layout that is not 16-byte
// contiguous, so fall back to normal copy
return LowerNormalCopy(T, analyzer);
Expand Down
22 changes: 12 additions & 10 deletions src/transform/loop_vectorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
<< ")" << "\n";
}
// vector_size may be greater than local/fragment buffers' vector_size.
// In such case, we need to re-validate if the indices are invariant
// at the new vector_size boundary. If not invariant, take GCD.
// In such case, we need to re-validate if the indices are vectorizable
// at the new vector_size boundary. If not, take GCD.
for (const auto &info : local_fragment_buffers) {
if (vector_size_ > info.vector_size && !info.indices.empty()) {
// Compute elem_offset from indices and strides
Expand All @@ -289,8 +289,9 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
for (size_t i = 0; i < info.indices.size(); ++i) {
elem_offset += info.indices[i] * strides[i];
}
if (!IsExprInvariantInVectorBoundary(
elem_offset, inner_for_->loop_var, vector_size_, analyzer_)) {
if (!IndicesCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_,
analyzer_)) {
// Not vectorizable at this vector_size, need to take GCD
int old_vector_size = vector_size_;
vector_size_ = arith::ZeroAwareGCD(vector_size_, info.vector_size);
Expand Down Expand Up @@ -578,9 +579,9 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
}
// 4. Try to find max vectorize size for this buffer
while (buffer_vec_size > 1 &&
!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, buffer_vec_size,
analyzer_)) {
!IndicesCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, buffer_vec_size,
analyzer_)) {
buffer_vec_size /= 2;
}
return buffer_vec_size;
Expand Down Expand Up @@ -721,9 +722,10 @@ bool IsExprInvariantInVectorBoundary(const PrimExpr &expr, Var var,
return false;
}

bool IndiceCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer) {
bool IndicesCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size,
arith::Analyzer *analyzer) {
ICHECK(target_vectorized_size >= 1);
if (target_vectorized_size == 1)
return true;
Expand Down
6 changes: 3 additions & 3 deletions src/transform/loop_vectorize.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ bool IsExprInvariantInVectorBoundary(const PrimExpr &expr, Var var,
int target_vectorized_size,
arith::Analyzer *analyzer);

bool IndiceCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer);
bool IndicesCanVectorize(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer);

} // namespace tl
} // namespace tvm
Expand Down
Loading