diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index aea41c399..9b20a4b6d 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1319,13 +1319,22 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, // int32. Therefore, we need to divide by the ratio of their // sizes in that case. int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes(); + index_str = + PrintExpr(arith::Analyzer().Simplify(truncdiv(index, div_factor))); - os << "*(" - << "(" << ptr_cast(t) << vid << ")" - << " + " << index_str << " / " << div_factor << ")"; + os << "*((" << ptr_cast(t) << vid << ")" << " + " << index_str << ")"; } else if (t == buffer_element_dtype) { os << buffer_str << "[" << index_str << "]"; } else { + // Fix fp4 pointer arithmetic: fp4 elements are 4-bit packed 2 per byte. + // fp4* + n incorrectly advances n bytes (skipping 2n elements). + int div_factor = 1; + if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1) { + div_factor = 2; + } + index_str = + PrintExpr(arith::Analyzer().Simplify(truncdiv(index, div_factor))); + os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; } @@ -2177,7 +2186,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride + "+ (i % 4) * " + smem_stride + - " + threadIdx.x / 4 + (i / 8) * 8];\n"; + " + threadIdx.x / 4 + (i / 8) * 8];\n"; os << "}\n"; } else { std::string smem_elem_offset = this->PrintExpr(op->args[6]);