Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion projects/hipblaslt/clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ int run_bench_test(Arguments& arg,
int64_t min_stride_c = arg.ldc[i] * arg.N[i];
int64_t min_stride_d = arg.ldd[i] * arg.N[i];
int64_t min_stride_e = arg.lde[i] * arg.N[i];
if(!any_stride && arg.stride_a[i] < min_stride_a && !arg.swizzle_a)
if(!any_stride && arg.stride_a[i] < min_stride_a)
{
//hipblaslt_cout << "hipblaslt-bench INFO: stride_a < min_stride_a, set stride_a = "
// << min_stride_a << std::endl;
Expand Down
5 changes: 5 additions & 0 deletions projects/hipblaslt/clients/gtest/matmul_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ Definitions:
- { transA: C, transB: N }
- { transA: C, transB: C }

- &swizzleA_strideA_support
- {}
- { stride_a: 0 }

Tests:
- name: matmul_bad_arg
Expand Down Expand Up @@ -850,6 +853,7 @@ Tests:
category: pre_checkin
function:
matmul: *real_precisions_swizzleA_support
arguments: *swizzleA_strideA_support
M: [128, 129]
N: [128, 129]
K: [128, 129]
Expand All @@ -867,6 +871,7 @@ Tests:
category: pre_checkin
function:
matmul: *real_precisions_swizzleA_support
arguments: *swizzleA_strideA_support
M: [128, 129]
N: [128, 129]
K: [128, 129]
Expand Down
22 changes: 10 additions & 12 deletions projects/hipblaslt/clients/include/testing_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1390,22 +1390,19 @@ void testing_matmul_with_bias(const Arguments& arg,
stride_da[i] = stride_a[i];
if(do_swizzle_a)
{
//TODO:
// 1. support stride_a is 0 when swizzle
// 2. support --any_stride for hipblaslt-bench when swizzle
// 3. support arbitrary stride_a for both hipblaslt-bench and hipblaslt-test when swizzled
stride_a[i] = lda[i] * A_col[i];
stride_da[i] = 0;
size_t MiM = 16, MiK = 0, __ = 0, PackK = 0;
calculateKforSwizzling(TiA, arg, MiK, __, PackK);
size_t K_block = MiK * PackK;
size_t stride_swizzle
int64_t stride_swizzle
= ((M[i] + MiM - 1) / MiM) * MiM * ((K[i] + K_block - 1) / K_block) * K_block;
if(do_batched[i] && stride_a[i] > 0 && stride_a[i] != lda[i] * A_col[i])
if(do_batched[i] && stride_a[i] != 0)
{
hipblaslt_cerr << "Error stride_a: swizzle_a does not yet support arbitrary stride_a!" << std::endl;
stride_da[i] = stride_a[i];
stride_swizzle = (size_t)stride_a[i];
stride_da[i] = stride_swizzle;

//TODO: support arbitrary stride_a for both hipblaslt-bench and hipblaslt-test when swizzled
if(stride_a[i] != lda[i] * A_col[i] && stride_a[i] != stride_swizzle)
hipblaslt_cerr << "Warning: swizzle_a does not yet support arbitrary stride_a!"
<< std::endl;
}
size_dA[i] = num_batches[i] * stride_swizzle;
}
Expand Down Expand Up @@ -1838,7 +1835,8 @@ void testing_matmul_with_bias(const Arguments& arg,
A_col[i],
(do_swizzle_a) ? A_row[i] : lda[i],
TiA,
(do_swizzle_a) ? A_row[i] * A_col[i] : stride_a[i],
(do_swizzle_a && stride_a[i] != 0) ? A_row[i] * A_col[i]
: stride_a[i],
num_batches[i]);
#ifdef HIPBLASLT_USE_ROCROLLER
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,38 +41,6 @@ inline bool isValidOrderForDatatype(hipDataType datatype, hipblasLtOrder_t order
return true;
}

/*******************************************************************************
* Calculate Default Swizzle-A-Batched-Stride
******************************************************************************/
inline void setDefaultSwizzledBatchedStride(const hipblasLtOrder_t& matOrder,
const uint64_t matUnrollDim,
const uint64_t matTileDim,
int64_t& batch_stride)
{
size_t MiM = 16, MiK = 0, MiKv = 0, PackK = 0;
if(matOrder == HIPBLASLT_ORDER_COL16_4R8)
{
//f16
MiK = 16;
MiKv = 4;
PackK = 16 / MiKv / 2;
}
else if(matOrder == HIPBLASLT_ORDER_COL16_4R16)
{
//f8
MiK = 32;
MiKv = 8;
PackK = 16 / MiKv / 1;
}
else
return;

size_t K_block = MiK * PackK;
//align to k for swizzleK and to m for 16
batch_stride = ((matTileDim + MiM - 1) / MiM) * MiM * ((matUnrollDim + K_block - 1) / K_block)
* K_block;
}

/*******************************************************************************
* Validate Matmul Swizzle Arguments
******************************************************************************/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ inline void assignAlphaBeta1(const rocblaslt_compute_type& compute_type, void* a
}
}


inline void heuristicResult_copy(rocblaslt_matmul_heuristic_result* heuristicResultsDest,
rocblaslt_matmul_heuristic_result* heuristicResultsSrc,
size_t& maxWorkSpaceBytes,
Expand Down Expand Up @@ -329,13 +328,6 @@ RocblasltContractionProblem construct_rocblaslt_problem(rocblaslt_handle
bool swizzleA = matA->order != HIPBLASLT_ORDER_COL && matA->order != HIPBLASLT_ORDER_ROW;
bool swizzleB = matB->order != HIPBLASLT_ORDER_COL && matB->order != HIPBLASLT_ORDER_ROW;

if(swizzleA && matA->batch_stride == 0)
{
// If batch_stride has never been assigned for swizzle, set it to the default value
// Swizzle is TN: For MatA, ->m (leading Dim) is unrollDim, ->n is tileDim
setDefaultSwizzledBatchedStride(matA->order, matA->m, matA->n, matA->batch_stride);
}

rocblaslt_status isValid = rocblaslt_matmul_valid_args(matmul_descr,
Comment thread
mengzcai marked this conversation as resolved.
Outdated
dummy_ptr,
dummy_ptr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,6 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl(const rocblaslt_handle
bool swizzleA = matA->order != HIPBLASLT_ORDER_COL && matA->order != HIPBLASLT_ORDER_ROW;
bool swizzleB = matB->order != HIPBLASLT_ORDER_COL && matB->order != HIPBLASLT_ORDER_ROW;

if(swizzleA && matA->batch_stride == 0)
{
// If batch_stride has never been assigned for swizzle, set it to the default value
// Swizzle is TN: For MatA, ->m (leading Dim) is unrollDim, ->n is tileDim
setDefaultSwizzledBatchedStride(matA->order, matA->m, matA->n, matA->batch_stride);
}

rocblaslt_status isValid = rocblaslt_matmul_valid_args(matmul_descr,
A,
B,
Expand Down Expand Up @@ -879,13 +872,6 @@ rocblaslt_status rocblaslt_gemm_create_cpp_impl_2(const rocblaslt_handle handle,
bool swizzleB
= problemtype.order_b != HIPBLASLT_ORDER_COL && problemtype.order_b != HIPBLASLT_ORDER_ROW;

if(swizzleA && batch_stride_a == 0)
{
// If batch_stride has never been assigned for swizzle, set it to the default value
// Swizzle is TN: For MatA, unrollDim is k (Leading Dim), tileDim is m
setDefaultSwizzledBatchedStride(problemtype.order_a, k, m, batch_stride_a);
}

auto status = validateMatmulArgs(m,
n,
k,
Expand Down