Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1319,13 +1319,22 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
// int32. Therefore, we need to divide by the ratio of their
// sizes in that case.
int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();
index_str =
PrintExpr(arith::Analyzer().Simplify(truncdiv(index, div_factor)));

os << "*("
<< "(" << ptr_cast(t) << vid << ")"
<< " + " << index_str << " / " << div_factor << ")";
os << "*((" << ptr_cast(t) << vid << ")" << " + " << index_str << ")";
} else if (t == buffer_element_dtype) {
os << buffer_str << "[" << index_str << "]";
} else {
// Fix fp4 pointer arithmetic: fp4 elements are 4-bit packed 2 per byte.
// fp4* + n incorrectly advances n bytes (skipping 2n elements).
int div_factor = 1;
if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1) {
div_factor = 2;
}
index_str =
PrintExpr(arith::Analyzer().Simplify(truncdiv(index, div_factor)));

os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
}

Expand Down Expand Up @@ -2177,7 +2186,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
<< "[(i % 8) / 4 * " + smem_stride +
" * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
"+ (i % 4) * " + smem_stride +
" + threadIdx.x / 4 + (i / 8) * 8];\n";
" + threadIdx.x / 4 + (i / 8) * 8];\n";
os << "}\n";
} else {
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
Expand Down
Loading