From f0c9d692d8cc1d02837b717f8e0d5235e734f42f Mon Sep 17 00:00:00 2001 From: mengzcai Date: Wed, 6 Aug 2025 07:39:17 +0000 Subject: [PATCH] [stride when Swizzle] Accept stride_a=0, support --any_stride for bench add swizzle_arguments to gtest: defualt stride & stride0 remove setDefaultSwizzledBatchedStride() --- .../hipblaslt/clients/benchmarks/client.cpp | 2 +- .../hipblaslt/clients/gtest/matmul_gtest.yaml | 5 +++ .../clients/include/testing_matmul.hpp | 22 ++++++------- .../src/include/rocblaslt_mat_utils.hpp | 32 ------------------- .../rocblaslt/src/rocblaslt_auxiliary.cpp | 8 ----- .../rocblaslt/src/rocblaslt_mat.cpp | 14 -------- 6 files changed, 16 insertions(+), 67 deletions(-) diff --git a/projects/hipblaslt/clients/benchmarks/client.cpp b/projects/hipblaslt/clients/benchmarks/client.cpp index e29ab985078..213358f6dc6 100644 --- a/projects/hipblaslt/clients/benchmarks/client.cpp +++ b/projects/hipblaslt/clients/benchmarks/client.cpp @@ -133,7 +133,7 @@ int run_bench_test(Arguments& arg, int64_t min_stride_c = arg.ldc[i] * arg.N[i]; int64_t min_stride_d = arg.ldd[i] * arg.N[i]; int64_t min_stride_e = arg.lde[i] * arg.N[i]; - if(!any_stride && arg.stride_a[i] < min_stride_a && !arg.swizzle_a) + if(!any_stride && arg.stride_a[i] < min_stride_a) { //hipblaslt_cout << "hipblaslt-bench INFO: stride_a < min_stride_a, set stride_a = " // << min_stride_a << std::endl; diff --git a/projects/hipblaslt/clients/gtest/matmul_gtest.yaml b/projects/hipblaslt/clients/gtest/matmul_gtest.yaml index 4f12f251d77..a0bde8a4383 100644 --- a/projects/hipblaslt/clients/gtest/matmul_gtest.yaml +++ b/projects/hipblaslt/clients/gtest/matmul_gtest.yaml @@ -41,6 +41,9 @@ Definitions: - { transA: C, transB: N } - { transA: C, transB: C } + - &swizzleA_strideA_support + - {} + - { stride_a: 0 } Tests: - name: matmul_bad_arg @@ -850,6 +853,7 @@ Tests: category: pre_checkin function: matmul: *real_precisions_swizzleA_support + arguments: *swizzleA_strideA_support M: [128, 129] N: [128, 129] K: [128, 129] @@ -867,6 +871,7 @@ Tests: category: pre_checkin function: matmul: *real_precisions_swizzleA_support + arguments: *swizzleA_strideA_support M: [128, 129] N: [128, 129] K: [128, 129] diff --git a/projects/hipblaslt/clients/include/testing_matmul.hpp b/projects/hipblaslt/clients/include/testing_matmul.hpp index cea2e506f1d..b1b4ac62172 100644 --- a/projects/hipblaslt/clients/include/testing_matmul.hpp +++ b/projects/hipblaslt/clients/include/testing_matmul.hpp @@ -1390,22 +1390,19 @@ void testing_matmul_with_bias(const Arguments& arg, stride_da[i] = stride_a[i]; if(do_swizzle_a) { - //TODO: - // 1. support stride_a is 0 when swizzle - // 2. support --any_stride for hipblaslt-bench when swizzle - // 3. support arbitrary stride_a for both hipblaslt-bench and hipblaslt-test when swizzled - stride_a[i] = lda[i] * A_col[i]; - stride_da[i] = 0; size_t MiM = 16, MiK = 0, __ = 0, PackK = 0; calculateKforSwizzling(TiA, arg, MiK, __, PackK); size_t K_block = MiK * PackK; - size_t stride_swizzle + int64_t stride_swizzle = ((M[i] + MiM - 1) / MiM) * MiM * ((K[i] + K_block - 1) / K_block) * K_block; - if(do_batched[i] && stride_a[i] > 0 && stride_a[i] != lda[i] * A_col[i]) + if(do_batched[i] && stride_a[i] != 0) { - hipblaslt_cerr << "Error stride_a: swizzle_a does not yet support arbitrary stride_a!" << std::endl; - stride_da[i] = stride_a[i]; - stride_swizzle = (size_t)stride_a[i]; + stride_da[i] = stride_swizzle; + + //TODO: support arbitrary stride_a for both hipblaslt-bench and hipblaslt-test when swizzled + if(stride_a[i] != lda[i] * A_col[i] && stride_a[i] != stride_swizzle) + hipblaslt_cerr << "Warning: swizzle_a does not yet support arbitrary stride_a!" + << std::endl; } size_dA[i] = num_batches[i] * stride_swizzle; } @@ -1838,7 +1835,8 @@ void testing_matmul_with_bias(const Arguments& arg, A_col[i], (do_swizzle_a) ? A_row[i] : lda[i], TiA, - (do_swizzle_a) ? A_row[i] * A_col[i] : stride_a[i], + (do_swizzle_a && stride_a[i] != 0) ? A_row[i] * A_col[i] + : stride_a[i], num_batches[i]); #ifdef HIPBLASLT_USE_ROCROLLER } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp index 44255d44c69..9d8dc35e63b 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/rocblaslt_mat_utils.hpp @@ -41,38 +41,6 @@ inline bool isValidOrderForDatatype(hipDataType datatype, hipblasLtOrder_t order return true; } -/******************************************************************************* - * Calculate Default Swizzle-A-Batched-Stride - ******************************************************************************/ -inline void setDefaultSwizzledBatchedStride(const hipblasLtOrder_t& matOrder, - const uint64_t matUnrollDim, - const uint64_t matTileDim, - int64_t& batch_stride) -{ - size_t MiM = 16, MiK = 0, MiKv = 0, PackK = 0; - if(matOrder == HIPBLASLT_ORDER_COL16_4R8) - { - //f16 - MiK = 16; - MiKv = 4; - PackK = 16 / MiKv / 2; - } - else if(matOrder == HIPBLASLT_ORDER_COL16_4R16) - { - //f8 - MiK = 32; - MiKv = 8; - PackK = 16 / MiKv / 1; - } - else - return; - - size_t K_block = MiK * PackK; - //align to k for swizzleK and to m for 16 - batch_stride = ((matTileDim + MiM - 1) / MiM) * MiM * ((matUnrollDim + K_block - 1) / K_block) - * K_block; -} - /******************************************************************************* * Validate Matmul Swizzle Arguments ******************************************************************************/ diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp index 2ddecf7c0b4..c1bc05db44f 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_auxiliary.cpp @@ -80,7 +80,6 @@ inline void assignAlphaBeta1(const rocblaslt_compute_type& compute_type, void* a } } - inline void heuristicResult_copy(rocblaslt_matmul_heuristic_result* heuristicResultsDest, rocblaslt_matmul_heuristic_result* heuristicResultsSrc, size_t& maxWorkSpaceBytes, @@ -329,13 +328,6 @@ RocblasltContractionProblem construct_rocblaslt_problem(rocblaslt_handle bool swizzleA = matA->order != HIPBLASLT_ORDER_COL && matA->order != HIPBLASLT_ORDER_ROW; bool swizzleB = matB->order != HIPBLASLT_ORDER_COL && matB->order != HIPBLASLT_ORDER_ROW; - if(swizzleA && matA->batch_stride == 0) - { - // If batch_stride has never been assigned for swizzle, set it to the default value - // Swizzle is TN: For MatA, ->m (leading Dim) is unrollDim, ->n is tileDim - setDefaultSwizzledBatchedStride(matA->order, matA->m, matA->n, matA->batch_stride); - } - rocblaslt_status isValid = rocblaslt_matmul_valid_args(matmul_descr, dummy_ptr, dummy_ptr, diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp index f4fad921ac1..f939d2f59e2 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp @@ -252,13 +252,6 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl(const rocblaslt_handle bool swizzleA = matA->order != HIPBLASLT_ORDER_COL && matA->order != HIPBLASLT_ORDER_ROW; bool swizzleB = matB->order != HIPBLASLT_ORDER_COL && matB->order != HIPBLASLT_ORDER_ROW; - if(swizzleA && matA->batch_stride == 0) - { - // If batch_stride has never been assigned for swizzle, set it to the default value - // Swizzle is TN: For MatA, ->m (leading Dim) is unrollDim, ->n is tileDim - setDefaultSwizzledBatchedStride(matA->order, matA->m, matA->n, matA->batch_stride); - } - rocblaslt_status isValid = rocblaslt_matmul_valid_args(matmul_descr, A, B, @@ -879,13 +872,6 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl_2(const rocblaslt_handle handle, bool swizzleB = problemtype.order_b != HIPBLASLT_ORDER_COL && problemtype.order_b != HIPBLASLT_ORDER_ROW; - if(swizzleA && batch_stride_a == 0) - { - // If batch_stride has never been assigned for swizzle, set it to the default value - // Swizzle is TN: For MatA, unrollDim is k (Leading Dim), tileDim is m - setDefaultSwizzledBatchedStride(problemtype.order_a, k, m, batch_stride_a); - } - auto status = validateMatmulArgs(m, n, k,