diff --git a/src/layout/layout.cc b/src/layout/layout.cc
index ccccc903d..8b0d37cb8 100644
--- a/src/layout/layout.cc
+++ b/src/layout/layout.cc
@@ -874,7 +874,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
                  element_size);
            })
       .def("tl.make_linear_layout",
-           [](Array<PrimExpr> shape) { return makeLinearLayout(shape); });
+           [](Array<PrimExpr> shape) { return makeLinearLayout(shape); })
+      .def("tl.make_gemm_fragment_8x8", []() { return makeGemmFragment8x8(); })
+      .def("tl.make_gemm_fragment_8x8_transposed",
+           []() { return makeGemmFragment8x8Transposed(); });
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
diff --git a/src/op/copy.cc b/src/op/copy.cc
index 711d87afc..2c01db367 100644
--- a/src/op/copy.cc
+++ b/src/op/copy.cc
@@ -773,6 +773,20 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
   Buffer shared_tensor = is_ldmatrix ? src : dst;
   Buffer local_tensor = is_ldmatrix ? dst : src;
 
+  Array<Range> local_region = is_ldmatrix ? src_range : dst_range;
+  bool is_full_range = true;
+  for (size_t i = 0; i < local_region.size(); i++) {
+    if (!analyzer->CanProveEqual(local_region[i]->extent,
+                                 local_tensor->shape[i])) {
+      is_full_range = false;
+      break;
+    }
+  }
+  if (!is_full_range) {
+    // ldmatrix/stmatrix only supports full-range copies, so fall back to a
+    // normal copy
+    return LowerNormalCopy(T, analyzer);
+  }
   Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
   Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
@@ -787,14 +801,6 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
   }
   Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
 
-  Array<PrimExpr> shared_indices_transformed = shared_indices;
-  Layout shared_layout;
-  if (T.buffer_remap.count(shared_tensor)) {
-    shared_layout = T.layout_map[shared_tensor];
-    shared_tensor = T.buffer_remap[shared_tensor];
-    shared_indices_transformed = shared_layout->Forward(shared_indices);
-  }
-
   // Check local_layout follows 8x8 layout
   // LDSM/STSM instructions require 8x8 matrix fragment layout
   // This matches the warp-level matrix multiplication pattern used in tensor
@@ -834,8 +840,7 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
     // be fallback to normal copy
     return LowerNormalCopy(T, analyzer);
   }
-  PrimExpr flattened_indice =
-      shared_tensor.OffsetOf(shared_indices_transformed).back();
+  PrimExpr flattened_indice = shared_tensor.OffsetOf(shared_indices).back();
   if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
                           loop_vars.back()->dom->extent, 8, analyzer)) {
     // TMA ldmatrix/stmatrix cannot support non-16 bytes continuous layout, will
@@ -853,11 +858,16 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
     // be fallback to normal copy
     return LowerNormalCopy(T, analyzer);
  }
   // Do the lowering here, try vectorized ldmatrix/stmatrix by 4/2/1
+  // Note: local_tensor is in local (register) scope here, not shared memory.
   PrimExpr extent = local_tensor->shape[0];
   int num = 1;
   if (analyzer->CanProveEqual(FloorMod(extent, 8), 0))
+    // 16x16 -> full warp, so we use .x4; with 32 threads in a warp, each
+    // thread holds 4 elements
     num = 4;
   else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0))
+    // 8x16 -> half warp, so we use .x2; with 32 threads in a warp, each
+    // thread holds 2 elements
     num = 2;
   Array<PrimExpr> args;
@@ -875,18 +885,21 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
   Layout inv = local_layout->Inverse();
   Array<PrimExpr> shared_coords;
   PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
-  if (!is_transposed)
-    shared_coords = inv->Forward(
-        {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
-         warp + FloorMod(T.thread_var, 8) * 4});
-  else
-    shared_coords = inv->Forward(
-        {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
-             FloorMod(T.thread_var, 2),
-         warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
+  if (!is_transposed) {
+    auto local_index = analyzer->Simplify(
+        local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num));
+    auto thread_index =
+        analyzer->Simplify(warp + FloorMod(T.thread_var, 8) * 4);
+    shared_coords = inv->Forward({local_index, thread_index});
+  } else {
+    auto local_index = analyzer->Simplify(
+        local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
+        FloorMod(T.thread_var, 2));
+    auto thread_index =
+        analyzer->Simplify(warp + FloorDiv(FloorMod(T.thread_var, 8), 2));
+    shared_coords = inv->Forward({local_index, thread_index});
+  }
   shared_coords.pop_back(); // remove rep
-  if (shared_layout.defined())
-    shared_coords = shared_layout->Forward(shared_coords);
   PrimExpr shared_addr = shared_tensor.access_ptr(
       is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
       shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index 5fa90ba95..3679f0b62 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -667,7 +667,28 @@ ParallelOpNode::ComputeLoopLayoutFromBuffer(const Buffer &buffer,
       << buffer << "` of layout " << src_layout->DebugOutput() << '\n';
 
   Fragment result;
+
+  // Check if access indices match loop vars AND loop ranges match layout input
+  // shape. We can only use src_layout directly when both conditions are met.
+  // Otherwise, if the loop is a sub-region of the buffer (e.g., the loop is
+  // 4x128 but the buffer layout is 64x128), using the full layout would cause
+  // out-of-bounds indexing.
+  bool can_use_src_layout_directly = false;
   if (IsCommonAccessIndice(buffer)) {
+    auto input_shape = src_layout->InputShape();
+    if (input_shape.size() == loop_vars_.size()) {
+      can_use_src_layout_directly = true;
+      for (size_t i = 0; i < loop_vars_.size(); i++) {
+        if (!analyzer_.CanProveEqual(loop_vars_[i]->dom->extent,
+                                     input_shape[i])) {
+          can_use_src_layout_directly = false;
+          break;
+        }
+      }
+    }
+  }
+
+  if (can_use_src_layout_directly) {
     result = src_layout;
   } else {
     Var rep;
diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc
index 5f2cf1a4c..3db3f9aa4 100644
--- a/src/transform/lower_tile_op.cc
+++ b/src/transform/lower_tile_op.cc
@@ -202,7 +202,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
     arith::Analyzer analyzer;
     LowerTileOpPass substituter(&analyzer);
     // Trace the buffer map for tvm_access_ptr
-    substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end());
+    // Insert both handle var and data var as keys for lookup
+    for (const auto &[param_var, buffer] : f->buffer_map) {
+      substituter.buffer_map_.insert(
+          {param_var, buffer}); // handle key (e.g., dQ_handle)
+      substituter.buffer_map_.insert(
+          {buffer->data, buffer}); // data key (e.g., dQ)
+    }
     for (const auto &[_, buffer] : f->buffer_map) {
       substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
     }
@@ -299,7 +305,91 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
         << "Invalid access ptr for permuted layout: " << access_ptr;
     auto access_ptr_call = Downcast<Call>(access_ptr);
     if (access_ptr_call->op.same_as(builtin::tvm_access_ptr())) {
-      LOG(FATAL) << "Transformation for tvm_access_ptr is not implemented yet";
+      // tvm_access_ptr format: (dtype, data, offset, extent, rw_mask)
+      auto buffer_var = Downcast<Var>(access_ptr_call->args[1]);
+
+      // Find original buffer from buffer_map_ using buffer_var
+      auto it = buffer_map_.find(buffer_var);
+      if (it == buffer_map_.end()) {
+        // If not found, buffer_var might be a new var after remap
+        // Do reverse lookup in var_remap_
+        for (const auto &[old_var, new_var] : var_remap_) {
+          if (new_var.same_as(buffer_var)) {
+            it = buffer_map_.find(old_var);
+            break;
+          }
+        }
+      }
+
+      if (it == buffer_map_.end()) {
+        return result; // Buffer not found, no transformation needed
+      }
+
+      Buffer original_buffer = it->second;
+
+      // Check if this buffer has a layout
+      if (!layout_map_.count(original_buffer)) {
+        return result; // No layout, no transformation needed
+      }
+
+      Layout layout = layout_map_[original_buffer];
+      Buffer new_buffer = buffer_remap_[original_buffer];
+
+      // In TMA context, swizzle is encoded in TMA descriptor parameters
+      // rather than in memory indices, so we only update buffer data
+      // without recomputing indices.
+      if (in_tma_context_) {
+        Array<PrimExpr> new_args = access_ptr_call->args;
+        new_args.Set(1, new_buffer->data); // Only replace data var
+        layout_remap_.Set(new_buffer, layout);
+        result.rewritten = true;
+        result.expr =
+            Call(access_ptr_call->dtype, access_ptr_call->op, new_args,
+                 access_ptr_call->annotations, access_ptr_call->span);
+        return result;
+      }
+
+      // Get the offset from tvm_access_ptr args[2]
+      PrimExpr elem_offset = access_ptr_call->args[2];
+      if (offset.defined()) {
+        elem_offset = elem_offset + offset.value();
+      }
+      // Get original and new buffer shapes
+      Array<PrimExpr> old_shape = original_buffer->shape;
+      Array<PrimExpr> new_shape = new_buffer->shape;
+      // Convert linear offset to multi-dimensional indices
+      Array<PrimExpr> multi_dim_indices;
+      PrimExpr remaining_offset = elem_offset;
+      for (int i = static_cast<int>(old_shape.size()) - 1; i >= 0; --i) {
+        multi_dim_indices.insert(
+            multi_dim_indices.begin(),
+            analyzer_->Simplify(floormod(remaining_offset, old_shape[i])));
+        remaining_offset = floordiv(remaining_offset, old_shape[i]);
+      }
+      // Apply layout transformation
+      auto forward_indices = layout->Forward(multi_dim_indices);
+      PrimExpr new_offset = 0;
+      PrimExpr stride_offset = 1;
+      for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
+        new_offset += forward_indices[i] * stride_offset;
+        stride_offset *= new_shape[i];
+      }
+      // Verify that access is within a single tile
+      ICHECK(is_zero(analyzer_->Simplify(remaining_offset)))
+          << "Access offset exceeds tile bounds, remaining_offset: "
+          << remaining_offset;
+      new_offset = analyzer_->Simplify(new_offset);
+      Array<PrimExpr> new_indices;
+      layout_remap_.Set(new_buffer, layout);
+
+      // Build new tvm_access_ptr call with new buffer and offset
+      Array<PrimExpr> new_args = access_ptr_call->args;
+      new_args.Set(1, new_buffer->data); // Replace data var
+      new_args.Set(2, new_offset);       // Replace offset
+      result.rewritten = true;
+      result.expr = Call(access_ptr_call->dtype, access_ptr_call->op, new_args,
+                         access_ptr_call->annotations, access_ptr_call->span);
+      return result;
     } else if (access_ptr_call->op.same_as(builtin::address_of())) {
       Optional<PrimExpr> resolved = ResolveBufferLoad(access_ptr_call->args[0]);
       ICHECK(resolved.defined())
@@ -322,6 +412,31 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
         << "but got indices size: " << indices.size()
         << " and shape size: " << old_shape.size();
 
+    Buffer remap_key = FindRemapBuffer(load->buffer).value_or(load->buffer);
+    Optional<Layout> layout = FindLayout(remap_key);
+    if (!layout.defined() || !buffer_map_.count(remap_key->data)) {
+      return result;
+    }
+    auto new_buffer = buffer_remap_.count(remap_key)
+                          ? buffer_remap_[remap_key]
+                          : load->buffer;
+    auto new_shape = new_buffer->shape;
+
+    // In TMA context, swizzle is encoded in TMA descriptor parameters
+    // rather than in memory indices, so we only update buffer data
+    // without recomputing indices.
+    if (in_tma_context_) {
+      Array<PrimExpr> new_args = {BufferLoad(new_buffer, indices)};
+      if (buffer_remap_.count(remap_key)) {
+        layout_remap_.Set(new_buffer, layout.value());
+      }
+      result.rewritten = true;
+      result.expr =
+          Call(access_ptr_call->dtype, access_ptr_call->op, new_args,
+               access_ptr_call->annotations, access_ptr_call->span);
+      return result;
+    }
+
     PrimExpr elem_offset = 0;
     PrimExpr stride = 1;
 
@@ -333,16 +448,6 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
     }
 
     PrimExpr smem_offset = elem_offset + (offset.defined() ?
                                               offset.value() : 0);
-    Buffer remap_key = FindRemapBuffer(load->buffer).value_or(load->buffer);
-    Optional<Layout> layout = FindLayout(remap_key);
-    if (!layout.defined() || !buffer_map_.count(remap_key->data)) {
-      return result;
-    }
-    auto new_buffer = buffer_remap_.count(remap_key)
-                          ? buffer_remap_[remap_key]
-                          : load->buffer;
-    auto new_shape = new_buffer->shape;
-
     auto buffer_map_iter = buffer_map_.find(Downcast<Var>(remap_key->data));
     int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
@@ -442,55 +547,166 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
   }
 
   PrimExpr VisitExpr_(const tir::CallNode *op) final {
-    if ((!has_tma_) && (op->op.same_as(tl::tma_load()) ||
-                        op->op.same_as(tl::tma_load_im2col()) ||
-                        op->op.same_as(tl::tma_store()))) {
+    if (op->op.same_as(tl::tma_load()) ||
+        op->op.same_as(tl::tma_load_im2col()) ||
+        op->op.same_as(tl::tma_store())) {
+      // Skip TMA-related calls; they are transformed implicitly.
       has_tma_ = true;
-    }
-
-    Array ptx_instructions = {builtin::ptx_ldmatrix(), builtin::mma_store()};
-
-    if (std::find(ptx_instructions.begin(), ptx_instructions.end(), op->op) ==
-        ptx_instructions.end()) {
+      in_tma_context_ = true;
       auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
+      in_tma_context_ = false;
       return call;
-    } else {
-      is_ptx_ = true;
     }
-    // Rewrite from/to shared or shared.dyn to/from local
-    auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
-    if (call->op.same_as(builtin::ptx_ldmatrix())) {
+
+    if (is_ptx_) {
+      return Downcast(op);
+    }
+
+    // Handle ptx_ldmatrix
+    if (op->op.same_as(builtin::ptx_ldmatrix())) {
+      is_ptx_ = true;
+      auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
+      is_ptx_ = false;
       // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
       // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
       //           or T.address_of(buffer, offset)
       PrimExpr access_ptr = call->args[5];
       PrimExpr smem_offset = call->args[6];
-      Call address_of_call = Downcast<Call>(access_ptr);
-      if (!address_of_call->op.same_as(builtin::address_of())) {
+      Call access_ptr_call = Downcast<Call>(access_ptr);
+
+      // Handle both tvm_access_ptr and address_of
+      if (access_ptr_call->op.same_as(builtin::tvm_access_ptr())) {
+        auto new_access_ptr =
+            HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
+        if (new_access_ptr.rewritten) {
+          auto new_call = call.CopyOnWrite();
+          new_call->args.Set(5, new_access_ptr.expr);
+          new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
+        }
+      } else if (access_ptr_call->op.same_as(builtin::address_of())) {
+        Optional<PrimExpr> resolved =
+            ResolveBufferLoad(access_ptr_call->args[0]);
+        ICHECK(resolved.defined())
+            << "Invalid address_of argument for permuted layout: "
+            << access_ptr_call->args[0];
+        PrimExpr load_expr = resolved.value();
+        if (!load_expr.same_as(access_ptr_call->args[0])) {
+          auto call_node = call.CopyOnWrite();
+          call_node->args.Set(
+              5, Call(access_ptr_call->dtype, access_ptr_call->op, {load_expr},
+                      access_ptr_call->annotations, access_ptr_call->span));
+          access_ptr_call = Downcast<Call>(call->args[5]);
+          access_ptr = call->args[5];
+        }
+        auto new_access_ptr =
+            HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
+        if (new_access_ptr.rewritten) {
+          auto new_call = call.CopyOnWrite();
+          new_call->args.Set(5, new_access_ptr.expr);
+          new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
+        }
+      } else {
         LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr;
       }
-      Optional<PrimExpr> resolved = ResolveBufferLoad(address_of_call->args[0]);
-      ICHECK(resolved.defined())
"Invalid address_of argument for permuted layout: " - << address_of_call->args[0]; - PrimExpr load_expr = resolved.value(); - if (!load_expr.same_as(address_of_call->args[0])) { - auto call_node = call.CopyOnWrite(); - call_node->args.Set(5, Call(address_of_call->dtype, address_of_call->op, - {load_expr}, address_of_call->annotations, - address_of_call->span)); - address_of_call = Downcast(call->args[5]); - access_ptr = call->args[5]; - } - BufferLoad load = Downcast(address_of_call->args[0]); - auto new_access_ptr = - HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype); - if (new_access_ptr.rewritten) { - auto new_call = call.CopyOnWrite(); - new_call->args.Set(5, new_access_ptr.expr); - new_call->args.Set(6, IntImm(smem_offset->dtype, 0)); + return call; + } + + if (op->op.same_as(tl::ptx_ldmatrix())) { + is_ptx_ = true; + auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + is_ptx_ = false; + // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset) + // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask) + // or T.address_of(buffer, offset) + PrimExpr access_ptr = call->args[2]; + Call access_ptr_call = Downcast(access_ptr); + + // Handle both tvm_access_ptr and address_of + if (access_ptr_call->op.same_as(builtin::tvm_access_ptr())) { + auto new_access_ptr = + HandleAccessPtrAndOffset(access_ptr, std::nullopt, call->dtype); + if (new_access_ptr.rewritten) { + auto new_call = call.CopyOnWrite(); + new_call->args.Set(2, new_access_ptr.expr); + } + } else if (access_ptr_call->op.same_as(builtin::address_of())) { + Optional resolved = + ResolveBufferLoad(access_ptr_call->args[0]); + ICHECK(resolved.defined()) + << "Invalid address_of argument for permuted layout: " + << access_ptr_call->args[0]; + PrimExpr load_expr = resolved.value(); + if (!load_expr.same_as(access_ptr_call->args[0])) { + auto call_node = call.CopyOnWrite(); + call_node->args.Set( + 2, Call(access_ptr_call->dtype, access_ptr_call->op, {load_expr}, + access_ptr_call->annotations, access_ptr_call->span)); + access_ptr_call = Downcast(call->args[2]); + access_ptr = call->args[2]; + } + auto new_access_ptr = + HandleAccessPtrAndOffset(access_ptr, std::nullopt, call->dtype); + if (new_access_ptr.rewritten) { + auto new_call = call.CopyOnWrite(); + new_call->args.Set(2, new_access_ptr.expr); + } + } else { + LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr; } - } else if (call->op.same_as(builtin::mma_store())) { + return call; + } + + // Handle tl::ptx_stmatrix + if (op->op.same_as(tl::ptx_stmatrix())) { + is_ptx_ = true; + auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); + is_ptx_ = false; + // form: T.ptx_stmatrix(trans, num, smem_ptr, value0, value1, ...) 
+      // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
+      //           or T.address_of(buffer, offset)
+      PrimExpr access_ptr = call->args[2];
+      Call access_ptr_call = Downcast<Call>(access_ptr);
+
+      // Handle both tvm_access_ptr and address_of
+      if (access_ptr_call->op.same_as(builtin::tvm_access_ptr())) {
+        auto new_access_ptr =
+            HandleAccessPtrAndOffset(access_ptr, std::nullopt, call->dtype);
+        if (new_access_ptr.rewritten) {
+          auto new_call = call.CopyOnWrite();
+          new_call->args.Set(2, new_access_ptr.expr);
+        }
+      } else if (access_ptr_call->op.same_as(builtin::address_of())) {
+        Optional<PrimExpr> resolved =
+            ResolveBufferLoad(access_ptr_call->args[0]);
+        ICHECK(resolved.defined())
+            << "Invalid address_of argument for permuted layout: "
+            << access_ptr_call->args[0];
+        PrimExpr load_expr = resolved.value();
+        if (!load_expr.same_as(access_ptr_call->args[0])) {
+          auto call_node = call.CopyOnWrite();
+          call_node->args.Set(
+              2, Call(access_ptr_call->dtype, access_ptr_call->op, {load_expr},
+                      access_ptr_call->annotations, access_ptr_call->span));
+          access_ptr_call = Downcast<Call>(call->args[2]);
+          access_ptr = call->args[2];
+        }
+        auto new_access_ptr =
+            HandleAccessPtrAndOffset(access_ptr, std::nullopt, call->dtype);
+        if (new_access_ptr.rewritten) {
+          auto new_call = call.CopyOnWrite();
+          new_call->args.Set(2, new_access_ptr.expr);
+        }
+      } else {
+        LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr;
+      }
-    } else if (call->op.same_as(builtin::mma_store())) {
+      return call;
+    }
+
+    // Handle mma_store
+    if (op->op.same_as(builtin::mma_store())) {
+      is_ptx_ = true;
+      auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
+      is_ptx_ = false;
       // because we will directly store result to Buffer instead of calling
       // mma_store now
       auto access_ptr = call->args[2];
@@ -500,10 +716,22 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
         auto new_call = call.CopyOnWrite();
         new_call->args.Set(2, new_access_ptr.expr);
       }
-    } else {
-      LOG(FATAL) << "Invalid call node: " << call;
+      return call;
+    }
+
+    // Handle standalone tvm_access_ptr calls with layout transformation
+    if (op->op.same_as(builtin::tvm_access_ptr())) {
+      auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
+      auto new_access_ptr =
+          HandleAccessPtrAndOffset(call, std::nullopt, call->dtype);
+      if (new_access_ptr.rewritten) {
+        return new_access_ptr.expr;
+      }
+      return call;
     }
-    is_ptx_ = false;
+
+    // Default: visit normally
+    auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
     return call;
   }
@@ -852,6 +1080,11 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
   std::unordered_map buffer_map_;
   Map var_remap_;
   bool has_tma_{false};
+  // Flag to indicate we are inside a TMA context (tma_load, tma_load_im2col,
+  // tma_store). When true, HandleAccessPtrAndOffset only updates buffer data
+  // without recomputing indices, since swizzle is encoded in TMA descriptor
+  // parameters rather than in memory indices.
+  bool in_tma_context_{false};
 };
 
 namespace transform {
diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py
index 777802d2c..a46318aed 100644
--- a/tilelang/layout/__init__.py
+++ b/tilelang/layout/__init__.py
@@ -12,5 +12,7 @@
     make_half_bank_swizzled_layout,  # noqa: F401
     make_quarter_bank_swizzled_layout,  # noqa: F401
     make_linear_layout,  # noqa: F401
+    make_gemm_fragment_8x8,  # noqa: F401
+    make_gemm_fragment_8x8_transposed,  # noqa: F401
 )
 from .gemm_sp import make_cutlass_metadata_layout  # noqa: F401
diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py
index beaf3b6b5..06ad000e3 100644
--- a/tilelang/layout/swizzle.py
+++ b/tilelang/layout/swizzle.py
@@ -196,3 +196,28 @@ def make_linear_layout(buffer_or_load_or_region: Buffer | BufferLoad | BufferReg
     """
     _, shape, _ = _get_buffer_info(buffer_or_load_or_region)
     return _ffi_api.make_linear_layout(list(shape))
+
+
+def make_gemm_fragment_8x8():
+    """
+    Create a standard 8x8 GEMM fragment layout for ldmatrix/stmatrix.
+
+    This layout matches the warp-level matrix multiplication pattern used in tensor cores.
+
+    Returns:
+        Fragment: An 8x8 fragment layout
+    """
+    return _ffi_api.make_gemm_fragment_8x8()
+
+
+def make_gemm_fragment_8x8_transposed():
+    """
+    Create a transposed 8x8 GEMM fragment layout for ldmatrix/stmatrix.
+
+    This layout is the transposed version of make_gemm_fragment_8x8, useful for
+    different access patterns in matrix operations.
+
+    Returns:
+        Fragment: A transposed 8x8 fragment layout
+    """
+    return _ffi_api.make_gemm_fragment_8x8_transposed()
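
Note on the copy.cc change above: the x4/x2/x1 choice in LowerLDSMCopy follows from the first fragment dimension, per the comments added in the diff. The sketch below is illustrative only and not part of the patch; the function name ldsm_vector_width and the register-count framing are my own, taken from those comments rather than from the TileLang source.

def ldsm_vector_width(extent: int) -> int:
    """Pick how many 8x8 matrices one ldmatrix/stmatrix instruction moves."""
    if extent % 8 == 0:
        return 4   # 16x16 tile -> full warp (.x4), 4 values per thread per the diff's comment
    if extent % 4 == 0:
        return 2   # 8x16 tile -> half warp (.x2), 2 values per thread per the diff's comment
    return 1       # otherwise a single 8x8 matrix (.x1)

assert [ldsm_vector_width(e) for e in (16, 8, 4, 2)] == [4, 4, 2, 1]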
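Note on the parallel.cc change: src_layout is reused verbatim only when every loop extent equals the corresponding layout input extent; otherwise a fresh layout is derived for the sub-region. A minimal Python sketch of that guard follows; the function and argument names are hypothetical, only the condition mirrors can_use_src_layout_directly.

def can_use_src_layout_directly(loop_extents, layout_input_shape) -> bool:
    """True only if the loop covers the layout's full input shape."""
    return (len(loop_extents) == len(layout_input_shape) and
            all(l == s for l, s in zip(loop_extents, layout_input_shape)))

assert can_use_src_layout_directly([64, 128], [64, 128])
assert not can_use_src_layout_directly([4, 128], [64, 128])  # 4x128 sub-loop of a 64x128 buffer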
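Note on the lower_tile_op.cc change: for a non-TMA tvm_access_ptr, HandleAccessPtrAndOffset now delinearizes the element offset against the original buffer shape, pushes the indices through the buffer's layout, and relinearizes against the remapped shape. The pure-Python sketch below mirrors that arithmetic only; delinearize, linearize, remap_access_ptr_offset, and the XOR "swizzle" used in the demo are placeholders of mine, not TileLang APIs.

def delinearize(offset, shape):
    """Convert a linear element offset into multi-dimensional indices."""
    indices = []
    for extent in reversed(shape):
        indices.insert(0, offset % extent)
        offset //= extent
    assert offset == 0, "access offset exceeds tile bounds"
    return indices

def linearize(indices, shape):
    """Convert multi-dimensional indices back into a linear offset."""
    offset, stride = 0, 1
    for idx, extent in zip(reversed(indices), reversed(shape)):
        offset += idx * stride
        stride *= extent
    return offset

def remap_access_ptr_offset(elem_offset, old_shape, new_shape, forward):
    """Delinearize against the original shape, apply the layout's forward
    mapping, then relinearize against the remapped buffer shape."""
    return linearize(forward(delinearize(elem_offset, old_shape)), new_shape)

if __name__ == "__main__":
    old_shape = new_shape = [8, 64]
    # Placeholder row-XOR permutation standing in for Layout::Forward.
    swizzle = lambda idx: [idx[0], idx[1] ^ ((idx[0] % 8) * 8)]
    print(remap_access_ptr_offset(3 * 64 + 5, old_shape, new_shape, swizzle))  # -> 221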
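Finally, a hypothetical usage sketch for the two new Python bindings. The import path follows the tilelang/layout/__init__.py hunk above; how the returned fragments are consumed downstream (e.g., as copy/annotation layouts) is not shown in this diff, so only construction is demonstrated.

from tilelang.layout import make_gemm_fragment_8x8, make_gemm_fragment_8x8_transposed

frag = make_gemm_fragment_8x8()                # 8x8 ldmatrix/stmatrix fragment layout
frag_t = make_gemm_fragment_8x8_transposed()   # transposed variant for the other access pattern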