diff --git a/projects/hipblaslt/clients/gtest/matmul_gtest.yaml b/projects/hipblaslt/clients/gtest/matmul_gtest.yaml index d8abe82a4db..bcac68806d1 100644 --- a/projects/hipblaslt/clients/gtest/matmul_gtest.yaml +++ b/projects/hipblaslt/clients/gtest/matmul_gtest.yaml @@ -855,7 +855,6 @@ Tests: K: [128, 129] lda: [129, 2440] batch_count: [1, 5] - stride_a: 0 transA: T transB: N alpha: 1 diff --git a/projects/hipblaslt/clients/include/testing_matmul.hpp b/projects/hipblaslt/clients/include/testing_matmul.hpp index f68355abfb0..4db594fb7ad 100644 --- a/projects/hipblaslt/clients/include/testing_matmul.hpp +++ b/projects/hipblaslt/clients/include/testing_matmul.hpp @@ -1376,28 +1376,34 @@ void testing_matmul_with_bias(const Arguments& arg, do_batched[i] = (arg.batch_count > 1); num_batches[i] = (do_batched[i] ? arg.batch_count : 1); - stride_a[i] = (do_batched[i] && arg.stride_a[i] >= lda[i] * A_col[i]) ? arg.stride_a[i] - : lda[i] * A_col[i]; + stride_a[i] = do_batched[i] ? arg.stride_a[i] : lda[i] * A_col[i]; stride_b[i] = do_batched[i] ? arg.stride_b[i] : ldb[i] * B_col[i]; stride_c[i] = do_batched[i] ? arg.stride_c[i] : ldc[i] * N[i]; stride_d[i] = do_batched[i] ? arg.stride_c[i] : ldd[i] * N[i]; stride_e[i] = do_batched[i] ? arg.stride_e[i] : lde[i] * N[i]; - size_A[i] = stride_a[i] * num_batches[i]; + size_A[i] + = stride_a[i] == 0 ? lda[i] * A_col[i] * num_batches[i] : stride_a[i] * num_batches[i]; size_dA[i] = size_A[i]; stride_da[i] = stride_a[i]; if(arg.swizzle_a && isSwizzleSupported(TiA)) { + //TODO: + // 1. support stride_a is 0 when swizzle + // 2. support --any_stride for hipblaslt-bench when swizzle + // 3. support arbitrary stride_a for both hipblaslt-bench and hipblaslt-test when swizzled + stride_a[i] = lda[i] * A_col[i]; stride_da[i] = 0; size_t MiM = 16, MiK = 0, __ = 0, PackK = 0; calculateKforSwizzling(TiA, arg, MiK, __, PackK); size_t K_block = MiK * PackK; size_t stride_swizzle = ((M[i] + MiM - 1) / MiM) * MiM * ((K[i] + K_block - 1) / K_block) * K_block; - if(do_batched[i] && arg.stride_a[i] >= (int64_t)stride_swizzle) + if(do_batched[i] && stride_a[i] > 0 && stride_a[i] != lda[i] * A_col[i]) { - stride_da[i] = arg.stride_a[i]; - stride_swizzle = (size_t)arg.stride_a[i]; + hipblaslt_cerr << "Error stride_a: swizzle_a does not yet support arbitrary stride_a!" << std::endl; + stride_da[i] = stride_a[i]; + stride_swizzle = (size_t)stride_a[i]; } size_dA[i] = num_batches[i] * stride_swizzle; }