diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 3bdaa49aea8..c823a119219 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1276,9 +1276,9 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { std::string genLoadBlockDim() { std::stringstream ss; const auto& pdim_map = kernel_->summary().parallel_dimension_map; - Val* tidx = pdim_map.getRawLoad(ParallelType::TIDx); - Val* tidy = pdim_map.getRawLoad(ParallelType::TIDy); - Val* tidz = pdim_map.getRawLoad(ParallelType::TIDz); + Val* tidx = pdim_map.getRawAsync(ParallelType::TIDx); + Val* tidy = pdim_map.getRawAsync(ParallelType::TIDy); + Val* tidz = pdim_map.getRawAsync(ParallelType::TIDz); int64_t num_threads = tidx->value().as() + tidy->value().as() + tidz->value().as(); NVF_ERROR( diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 86bdc8ae240..6486ad46396 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -275,7 +275,7 @@ Val* ParallelDimensionMap::getRawCompute(ParallelType pt) const { return raw; } -Val* ParallelDimensionMap::getRawLoad(ParallelType pt) const { +Val* ParallelDimensionMap::getRawAsync(ParallelType pt) const { if (warp_specialized_types_.count(pt)) { return IrBuilder::create( getWarpSpecializationPaddedVal(pt), DataType::Index); diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index ceb5218ec24..a2c6d80e292 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -50,10 +50,10 @@ class ParallelDimensionMap { Val* getRawCompute(ParallelType pt) const; //! Get the "load" parallel dimension on the given ParallelType. In case - //! of no warp specialization, this is the same as getRaw(pt). If we are doing - //! warp specialization on pt, the result is 1, because the last of pt is used - //! for loading circular buffer tensors. - Val* getRawLoad(ParallelType pt) const; + //! of no warp specialization, this is the same as getRaw(pt). For warp + //! 
specialization on pt, the result is padded_value. The last part of + //! pt is used for the AsyncWarp warp group. + Val* getRawAsync(ParallelType pt) const; //! The padded val ensures that CTA has 128 threads for the AsyncWarp. This //! function returns the padded val for the warp specialized ParallelType.