diff --git a/projects/hipblaslt/clients/gtest/matmul_gtest.yaml b/projects/hipblaslt/clients/gtest/matmul_gtest.yaml index bcac68806d1..4f12f251d77 100644 --- a/projects/hipblaslt/clients/gtest/matmul_gtest.yaml +++ b/projects/hipblaslt/clients/gtest/matmul_gtest.yaml @@ -863,6 +863,25 @@ Tests: unit_check: 1 gpu_arch: '942' +- name: matmul_extapi_swizzleA + category: pre_checkin + function: + matmul: *real_precisions_swizzleA_support + M: [128, 129] + N: [128, 129] + K: [128, 129] + lda: [129, 2440] + batch_count: [1, 5] + transA: T + transB: N + use_ext: [1] + use_ext_setproblem: [0, 1] + alpha: 1 + beta: [ 0.0, 2.0 ] + bias_vector: [0, 1] + unit_check: 1 + gpu_arch: '942' + ### GFX12 FP8 - name: matmul_f8_bf8_dst_fp32_gfx12 category: pre_checkin diff --git a/projects/hipblaslt/clients/include/testing_matmul.hpp b/projects/hipblaslt/clients/include/testing_matmul.hpp index 4db594fb7ad..cea2e506f1d 100644 --- a/projects/hipblaslt/clients/include/testing_matmul.hpp +++ b/projects/hipblaslt/clients/include/testing_matmul.hpp @@ -1353,6 +1353,8 @@ void testing_matmul_with_bias(const Arguments& arg, std::vector alpha_in(gemm_count); + bool do_swizzle_a = arg.swizzle_a && isSwizzleSupported(TiA); + // Need to split into two for loop to calculate the rotating buffer int64_t totalRotatingSizeNeeded = 0; for(int i = 0; i < gemm_count; i++) @@ -1386,7 +1388,7 @@ void testing_matmul_with_bias(const Arguments& arg, = stride_a[i] == 0 ? lda[i] * A_col[i] * num_batches[i] : stride_a[i] * num_batches[i]; size_dA[i] = size_A[i]; stride_da[i] = stride_a[i]; - if(arg.swizzle_a && isSwizzleSupported(TiA)) + if(do_swizzle_a) { //TODO: // 1. support stride_a is 0 when swizzle @@ -1491,7 +1493,7 @@ void testing_matmul_with_bias(const Arguments& arg, CHECK_HIPBLASLT_ERROR( hipblasLtMatrixLayoutCreate(&(matD[i]), arg.d_type, M[i], N[i], ldd[i])); - if(arg.swizzle_a && isSwizzleSupported(TiA)) + if(do_swizzle_a) { hipblasLtOrder_t orderA = orderForDatatype(TiA); CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutSetAttribute( @@ -1834,9 +1836,9 @@ void testing_matmul_with_bias(const Arguments& arg, dA[i].buf(), A_row[i], A_col[i], - (arg.swizzle_a) ? A_row[i] : lda[i], + (do_swizzle_a) ? A_row[i] : lda[i], TiA, - (arg.swizzle_a) ? A_row[i] * A_col[i] : stride_a[i], + (do_swizzle_a) ? A_row[i] * A_col[i] : stride_a[i], num_batches[i]); #ifdef HIPBLASLT_USE_ROCROLLER } @@ -1907,7 +1909,7 @@ void testing_matmul_with_bias(const Arguments& arg, CHECK_HIP_ERROR(broadcast(dB[i], block_count)); CHECK_HIP_ERROR(broadcast(dC[i], block_count)); - if(arg.unit_check || arg.norm_check || arg.allclose_check || arg.swizzle_a) + if(arg.unit_check || arg.norm_check || arg.allclose_check || do_swizzle_a) { CHECK_HIP_ERROR(synchronize(hA[i], dA[i], @@ -1916,12 +1918,12 @@ void testing_matmul_with_bias(const Arguments& arg, A_col[i], lda[i], realDataTypeSize(TiA), - arg.swizzle_a)); + do_swizzle_a)); CHECK_HIP_ERROR(synchronize(hB[i], dB[i])); CHECK_HIP_ERROR(synchronize(hC[i], dC[i])); } - if(arg.swizzle_a && isSwizzleSupported(TiA)) + if(do_swizzle_a) { HipHostBuffer tmp(TiA, size_dA[i]); swizzle_tensor_type(tmp, hA[i], TiA, arg, num_batches[i], M[i], K[i], lda[i], false); @@ -2435,6 +2437,12 @@ void testing_matmul_with_bias(const Arguments& arg, extproblemtype.setTypeC(arg.c_type); extproblemtype.setTypeD(arg.d_type); extproblemtype.setTypeCompute(arg.compute_type); + + if(do_swizzle_a) + { + hipblasLtOrder_t orderA = orderForDatatype(TiA); + extproblemtype.setOrderA(orderA); + } } else if(arg.grouped_gemm) { @@ -2545,7 +2553,7 @@ void testing_matmul_with_bias(const Arguments& arg, ldb[0], ldc[0], ldd[0], - stride_a[0], + (do_swizzle_a) ? stride_da[0] : stride_a[0], stride_b[0], stride_c[0], stride_d[0], @@ -2640,7 +2648,7 @@ void testing_matmul_with_bias(const Arguments& arg, ldb, ldc, ldd, - stride_a, + (do_swizzle_a) ? stride_da : stride_a, stride_b, stride_c, stride_d, @@ -2738,7 +2746,7 @@ void testing_matmul_with_bias(const Arguments& arg, ldb[0], ldc[0], ldd[0], - stride_a[0], + (do_swizzle_a) ? stride_da[0] : stride_a[0], stride_b[0], stride_c[0], stride_d[0], @@ -2842,7 +2850,7 @@ void testing_matmul_with_bias(const Arguments& arg, ldb, ldc, ldd, - stride_a, + (do_swizzle_a) ? stride_da : stride_a, stride_b, stride_c, stride_d, @@ -2920,7 +2928,7 @@ void testing_matmul_with_bias(const Arguments& arg, ldb[0], ldc[0], ldd[0], - stride_a[0], + (do_swizzle_a) ? stride_da[0] : stride_a[0], stride_b[0], stride_c[0], stride_d[0], @@ -3009,7 +3017,7 @@ void testing_matmul_with_bias(const Arguments& arg, ldb, ldc, ldd, - stride_a, + (do_swizzle_a) ? stride_da : stride_a, stride_b, stride_c, stride_d, diff --git a/projects/hipblaslt/library/include/hipblaslt/hipblaslt-ext.hpp b/projects/hipblaslt/library/include/hipblaslt/hipblaslt-ext.hpp index 3b7810fad43..9f6dcc6a4ee 100644 --- a/projects/hipblaslt/library/include/hipblaslt/hipblaslt-ext.hpp +++ b/projects/hipblaslt/library/include/hipblaslt/hipblaslt-ext.hpp @@ -134,6 +134,8 @@ namespace hipblaslt_ext HIPBLASLT_EXPORT void setTypeD(hipDataType type); //!< Set the D matrix datatype. HIPBLASLT_EXPORT void setTypeCompute(hipblasComputeType_t type); //!< Set the compute datatype. + HIPBLASLT_EXPORT void setOrderA(hipblasLtOrder_t order); //!< Set the A martix data order. + HIPBLASLT_EXPORT void setOrderB(hipblasLtOrder_t order); //!< Set the B matrix data order. HIPBLASLT_EXPORT hipblasOperation_t getOpA() const; //!< The A matrix transpose. HIPBLASLT_EXPORT hipblasOperation_t getOpB() const; //!< The B matrix transpose. @@ -142,6 +144,8 @@ namespace hipblaslt_ext HIPBLASLT_EXPORT hipDataType getTypeC() const; //!< The C matrix datatype. HIPBLASLT_EXPORT hipDataType getTypeD() const; //!< The D matrix datatype. HIPBLASLT_EXPORT hipblasComputeType_t getTypeCompute() const; //!< The compute datatype. + HIPBLASLT_EXPORT hipblasLtOrder_t getOrderA() const; //!< The A matrix data order. + HIPBLASLT_EXPORT hipblasLtOrder_t getOrderB() const; //!< The B matrix data order. }; [[deprecated("GemmProblemTypeV2 is deprecated, use GemmProblemType instead.")]] diff --git a/projects/hipblaslt/library/src/amd_detail/hipblaslt-ext.cpp b/projects/hipblaslt/library/src/amd_detail/hipblaslt-ext.cpp index 913d179a1ed..f8f0d9d862b 100644 --- a/projects/hipblaslt/library/src/amd_detail/hipblaslt-ext.cpp +++ b/projects/hipblaslt/library/src/amd_detail/hipblaslt-ext.cpp @@ -83,6 +83,8 @@ namespace hipblaslt_ext hipDataType type_c; //!< The C matrix datatype. hipDataType type_d; //!< The D matrix datatype. hipblasComputeType_t type_compute; //!< The compute datatype. + hipblasLtOrder_t order_a; //!< The A martix data layout order + hipblasLtOrder_t order_b; //!< The B martix data layout order }; GemmProblemType::GemmProblemType() @@ -106,6 +108,11 @@ namespace hipblaslt_ext pimpl->type_c = typeC; pimpl->type_d = typeD; pimpl->type_compute = typeCompute; + + // default value of order is COL despite of opA/B, + // currently only swizzle cases use the variables + pimpl->order_a = HIPBLASLT_ORDER_COL; + pimpl->order_b = HIPBLASLT_ORDER_COL; } GemmProblemType::~GemmProblemType() = default; @@ -170,6 +177,16 @@ namespace hipblaslt_ext pimpl->type_compute = type; } + void GemmProblemType::setOrderA(hipblasLtOrder_t order) + { + pimpl->order_a = order; + } + + void GemmProblemType::setOrderB(hipblasLtOrder_t order) + { + pimpl->order_b = order; + } + hipblasOperation_t GemmProblemType::getOpA() const { return pimpl->op_a; @@ -205,6 +222,16 @@ namespace hipblaslt_ext return pimpl->type_compute; } + hipblasLtOrder_t GemmProblemType::getOrderA() const + { + return pimpl->order_a; + } + + hipblasLtOrder_t GemmProblemType::getOrderB() const + { + return pimpl->order_b; + } + class GemmEpilogue::GemmEpilogueImpl { public: @@ -307,12 +334,12 @@ namespace hipblaslt_ext void GemmEpilogue::setAct0(float act0) { - pimpl->act0 = act0; + pimpl->act0 = act0; } void GemmEpilogue::setAct1(float act1) { - pimpl->act1 = act1; + pimpl->act1 = act1; } hipblasLtEpilogue_t GemmEpilogue::getMode() const @@ -372,12 +399,12 @@ namespace hipblaslt_ext float GemmEpilogue::getAct0() { - return pimpl->act0; + return pimpl->act0; } float GemmEpilogue::getAct1() { - return pimpl->act1; + return pimpl->act1; } class GemmTuning::GemmTuningImpl diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/include/rocblaslt-types.h b/projects/hipblaslt/library/src/amd_detail/rocblaslt/include/rocblaslt-types.h index f730c6b04d5..444298805bf 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/include/rocblaslt-types.h +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/include/rocblaslt-types.h @@ -634,6 +634,8 @@ namespace rocblaslt hipDataType type_c; hipDataType type_d; rocblaslt_compute_type type_compute; + hipblasLtOrder_t order_a; + hipblasLtOrder_t order_b; }; class RocGemmEpilogueV2 diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp index 9d8dc35e63b..44255d44c69 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp @@ -41,6 +41,38 @@ inline bool isValidOrderForDatatype(hipDataType datatype, hipblasLtOrder_t order return true; } +/******************************************************************************* + * Calculate Default Swizzle-A-Batched-Stride + ******************************************************************************/ +inline void setDefaultSwizzledBatchedStride(const hipblasLtOrder_t& matOrder, + const uint64_t matUnrollDim, + const uint64_t matTileDim, + int64_t& batch_stride) +{ + size_t MiM = 16, MiK = 0, MiKv = 0, PackK = 0; + if(matOrder == HIPBLASLT_ORDER_COL16_4R8) + { + //f16 + MiK = 16; + MiKv = 4; + PackK = 16 / MiKv / 2; + } + else if(matOrder == HIPBLASLT_ORDER_COL16_4R16) + { + //f8 + MiK = 32; + MiKv = 8; + PackK = 16 / MiKv / 1; + } + else + return; + + size_t K_block = MiK * PackK; + //align to k for swizzleK and to m for 16 + batch_stride = ((matTileDim + MiM - 1) / MiM) * MiM * ((matUnrollDim + K_block - 1) / K_block) + * K_block; +} + /******************************************************************************* * Validate Matmul Swizzle Arguments ******************************************************************************/ diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp index 5cbe959bce4..2ddecf7c0b4 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp @@ -80,32 +80,6 @@ inline void assignAlphaBeta1(const rocblaslt_compute_type& compute_type, void* a } } -inline void setDefaultSwizzledBatchedStride(const rocblaslt_matrix_layout& matLayout, - int64_t& batch_stride) -{ - size_t MiM = 16, MiK = 0, MiKv = 0, PackK = 0; - if(matLayout->order == HIPBLASLT_ORDER_COL16_4R8) - { - //f16 - MiK = 16; - MiKv = 4; - PackK = 16 / MiKv / 2; - } - else if(matLayout->order == HIPBLASLT_ORDER_COL16_4R16) - { - //f8 - MiK = 32; - MiKv = 8; - PackK = 16 / MiKv / 1; - } - else - return; - - size_t K_block = MiK * PackK; - //align to k for swizzleK and to m for 16 - batch_stride = ((matLayout->n + MiM - 1) / MiM) * MiM * ((matLayout->m + K_block - 1) / K_block) - * K_block; -} inline void heuristicResult_copy(rocblaslt_matmul_heuristic_result* heuristicResultsDest, rocblaslt_matmul_heuristic_result* heuristicResultsSrc, @@ -357,8 +331,9 @@ RocblasltContractionProblem construct_rocblaslt_problem(rocblaslt_handle if(swizzleA && matA->batch_stride == 0) { - //If batch_stride has never been assigned for swizzle, set it to the default value - setDefaultSwizzledBatchedStride(matA, matA->batch_stride); + // If batch_stride has never been assigned for swizzle, set it to the default value + // Swizzle is TN: For MatA, ->m (leading Dim) is unrollDim, ->n is tileDim + setDefaultSwizzledBatchedStride(matA->order, matA->m, matA->n, matA->batch_stride); } rocblaslt_status isValid = rocblaslt_matmul_valid_args(matmul_descr, diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp index 10ca8fc1324..f4fad921ac1 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp @@ -67,6 +67,8 @@ rocblaslt_status rocblaslt_matmul_impl(const rocblaslt_handle handle, bool gradient = false; bool swizzleA = matA->order != HIPBLASLT_ORDER_COL && matA->order != HIPBLASLT_ORDER_ROW; bool swizzleB = matB->order != HIPBLASLT_ORDER_COL && matB->order != HIPBLASLT_ORDER_ROW; + // no need to do batched swizzle stride check here, it was done in set-problem stage + rocblaslt_status isValid = rocblaslt_matmul_valid_args(matmul_descr, A, B, @@ -249,6 +251,14 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl(const rocblaslt_handle bool gradient = false; bool swizzleA = matA->order != HIPBLASLT_ORDER_COL && matA->order != HIPBLASLT_ORDER_ROW; bool swizzleB = matB->order != HIPBLASLT_ORDER_COL && matB->order != HIPBLASLT_ORDER_ROW; + + if(swizzleA && matA->batch_stride == 0) + { + // If batch_stride has never been assigned for swizzle, set it to the default value + // Swizzle is TN: For MatA, ->m (leading Dim) is unrollDim, ->n is tileDim + setDefaultSwizzledBatchedStride(matA->order, matA->m, matA->n, matA->batch_stride); + } + rocblaslt_status isValid = rocblaslt_matmul_valid_args(matmul_descr, A, B, @@ -864,6 +874,18 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl_2(const rocblaslt_handle handle, bool strided_batch = true; bool grouped_gemm = false; + bool swizzleA + = problemtype.order_a != HIPBLASLT_ORDER_COL && problemtype.order_a != HIPBLASLT_ORDER_ROW; + bool swizzleB + = problemtype.order_b != HIPBLASLT_ORDER_COL && problemtype.order_b != HIPBLASLT_ORDER_ROW; + + if(swizzleA && batch_stride_a == 0) + { + // If batch_stride has never been assigned for swizzle, set it to the default value + // Swizzle is TN: For MatA, unrollDim is k (Leading Dim), tileDim is m + setDefaultSwizzledBatchedStride(problemtype.order_a, k, m, batch_stride_a); + } + auto status = validateMatmulArgs(m, n, k, @@ -1009,9 +1031,8 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl_2(const rocblaslt_handle handle, rocEpilogue.act1, 0, handle->Synchronizer, - /*TODO: support C++ API */ - false, - false}; + swizzleA, + swizzleB}; return gemmCreate(problem, gemmData, gemmCount); } @@ -1128,6 +1149,8 @@ rocblaslt_status rocblaslt_groupedgemm_create_cpp_impl_2(const rocblaslt_handle hipDataType type_b = problemtype[0].type_b; hipDataType type_c = problemtype[0].type_c; hipDataType type_d = problemtype[0].type_d; + hipblasLtOrder_t orderA = problemtype[0].order_a; + hipblasLtOrder_t orderB = problemtype[0].order_b; std::vector A_vec, B_vec, C_vec, alpha_vec, beta_vec; std::vector D_vec, E_vec, amaxD_vec; @@ -1266,6 +1289,9 @@ rocblaslt_status rocblaslt_groupedgemm_create_cpp_impl_2(const rocblaslt_handle bool strided_batch = true; bool grouped_gemm = true; + bool swizzleA = orderA != HIPBLASLT_ORDER_COL && orderA != HIPBLASLT_ORDER_ROW; + bool swizzleB = orderB != HIPBLASLT_ORDER_COL && orderB != HIPBLASLT_ORDER_ROW; + std::vector problems; for(int i = 0; i < m.size(); i++) { @@ -1334,9 +1360,8 @@ rocblaslt_status rocblaslt_groupedgemm_create_cpp_impl_2(const rocblaslt_handle rocEpilogue[iIdx].act1, 0, handle->Synchronizer, - /*TODO: support grouped gemm */ - false, - false}); + swizzleA, + swizzleB}); } return groupedGemmCreate(problems, gemmData, gemmCount); }