diff --git a/projects/hipblaslt/clients/common/include/testing_matmul.hpp b/projects/hipblaslt/clients/common/include/testing_matmul.hpp index 196cdf4a8d5..e037dc11060 100644 --- a/projects/hipblaslt/clients/common/include/testing_matmul.hpp +++ b/projects/hipblaslt/clients/common/include/testing_matmul.hpp @@ -55,6 +55,7 @@ #include #include #include +#include #include extern "C" __global__ void flush_icache() @@ -3731,144 +3732,125 @@ void testing_matmul_with_bias(const Arguments& arg, if(arg.algo_method == 2) { - std::vector tmpAlgo; heuristicResult.clear(); heuristicTuningIndex.clear(); - int algoIndexCount = 0; - int algoIndexInc = 100; - while(1) + const bool indicesAreDiscovered = (arg.solution_index == -1); + + std::vector validIndices; + auto discoverValidIndices = [&]() { + std::vector allAlgos; + EXPECT_HIPBLAS_STATUS(hipblaslt_ext::getAllAlgos(handle, + gemmType, + transA, + transB, + arg.a_type, + arg.b_type, + arg.c_type, + arg.d_type, + arg.compute_type, + allAlgos), + HIPBLAS_STATUS_SUCCESS); + validIndices.reserve(allAlgos.size()); + for(auto& a : allAlgos) + { + validIndices.push_back(hipblaslt_ext::getIndexFromAlgo(a.algo)); + } + }; + + if(indicesAreDiscovered) { - std::vector algoIndex; - std::vector tmpAlgo; - bool foundAlgo = false; - if(arg.solution_index == -1) - { - // Get algos by index - // In real cases, the user can use the saved algo index to get the algorithm. - // isAlgoSupported is not necessary if the user is sure that the algo supports the problem. - algoIndex.resize(algoIndexInc); - std::iota(std::begin(algoIndex), std::end(algoIndex), algoIndexCount); - algoIndexCount += algoIndexInc; - } - else - { - // Specify the index - algoIndex.resize(1); - algoIndex[0] = arg.solution_index; - } + discoverValidIndices(); + } - // INVALID_VALUE means some indices exceeded the pool size; valid algos are still returned in tmpAlgo - bool lastBatch = (HIPBLAS_STATUS_INVALID_VALUE - == hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); - if(tmpAlgo.empty()) - { - break; - } - returnedAlgoCount = tmpAlgo.size(); + bool selectionWasAttempted = false; - if(!do_grouped_gemm) - { - if(arg.use_ext && batchMode != HIPBLASLT_BATCH_MODE_POINTER_ARRAY) + auto searchForSupportedAlgoViaIndexAPI = [&]() { + constexpr size_t batchSize = 100; + size_t batchStart = 0; + bool explicitConsumed = false; + + auto nextBatchOfIndices = [&]() -> std::optional> { + if(indicesAreDiscovered) { - if(arg.use_ext_setproblem) + if(batchStart >= validIndices.size()) { - for(int32_t b = 0; b < block_count; b++) - CHECK_HIPBLASLT_ERROR(gemmVec[b].setProblem(M[0], - N[0], - K[0], - num_batches[0], - lda[0], - ldb[0], - ldc[0], - ldd[0], - stride_da[0], - stride_db[0], - stride_c[0], - stride_d[0], - extepilogue[0], - extinputs[b][0], - extproblemtype)); - } - else - { - for(int32_t b = 0; b < block_count; b++) - CHECK_HIPBLASLT_ERROR(gemmVec[b].setProblem( - matmul[b][0], - alpha_in[0], - (dA[0].as()) + b * size_dA[0] * realDataTypeSize(TiA), - matA[0], - (dB[0].as()) + b * size_dB[0] * realDataTypeSize(TiB), - matB[0], - &h_beta[0], - (dC[0].as()) + b * size_C[0] * realDataTypeSize(To), - matC[0], - ((*dDp)[0].as()) + b * size_D[0] * realDataTypeSize(To), - matD[0])); - } - for(int j = 0; j < returnedAlgoCount; j++) - { - for(size_t t = 0; t < tuningVec.size(); t++) - { - size_t tmpWorkspaceSize = 0; - if(gemmVec[0].isAlgoSupported( - tmpAlgo[j].algo, tuningVec[t], tmpWorkspaceSize) - == HIPBLAS_STATUS_SUCCESS) - { - if(tmpWorkspaceSize <= max_workspace_size) - { - heuristicResult.push_back(tmpAlgo[j]); - heuristicTuningIndex.push_back(t); - workspace_size = std::max(workspace_size, tmpWorkspaceSize); - foundAlgo = true; - } - } - } - CHECK_RETURNED_WORKSPACE_SIZE(workspace_size, max_workspace_size); - if(foundAlgo) - break; + return std::nullopt; } + const size_t batchEnd + = std::min(batchStart + batchSize, validIndices.size()); + std::vector batch(validIndices.begin() + batchStart, + validIndices.begin() + batchEnd); + batchStart = batchEnd; + return batch; } - else + if(explicitConsumed) + { + return std::nullopt; + } + explicitConsumed = true; + return std::vector{arg.solution_index}; + }; + + auto fetchAlgosForBatch + = [&](std::vector& batch, + std::vector& candidates) { + candidates.clear(); + const auto status + = hipblaslt_ext::getAlgosFromIndex(handle, batch, candidates); + if(indicesAreDiscovered) + { + EXPECT_HIPBLAS_STATUS(status, HIPBLAS_STATUS_SUCCESS); + } + }; + + auto configureExtGemmForCurrentProblem = [&]() { + if(arg.use_ext_setproblem) { - for(int j = 0; j < returnedAlgoCount; j++) + for(int32_t b = 0; b < block_count; b++) { - for(size_t t = 0; t < 1; t++) // CAPI not supported yet - { - size_t tmpWorkspaceSize = 0; - if(hipblaslt_ext::matmulIsAlgoSupported(handle, - matmul[0][0], - alpha_in[0], - matA[0], - matB[0], - &h_beta[0], - matC[0], - matD[0], - tmpAlgo[j].algo, - tmpWorkspaceSize) - == HIPBLAS_STATUS_SUCCESS) - { - if(tmpWorkspaceSize <= max_workspace_size) - { - heuristicResult.push_back(tmpAlgo[j]); - heuristicTuningIndex.push_back(t); - workspace_size = std::max(workspace_size, tmpWorkspaceSize); - foundAlgo = true; - break; - } - } - } - CHECK_RETURNED_WORKSPACE_SIZE(workspace_size, max_workspace_size); + CHECK_HIPBLASLT_ERROR(gemmVec[b].setProblem(M[0], + N[0], + K[0], + num_batches[0], + lda[0], + ldb[0], + ldc[0], + ldd[0], + stride_da[0], + stride_db[0], + stride_c[0], + stride_d[0], + extepilogue[0], + extinputs[b][0], + extproblemtype)); } + return; } - } - else - { + for(int32_t b = 0; b < block_count; b++) + { + CHECK_HIPBLASLT_ERROR(gemmVec[b].setProblem( + matmul[b][0], + alpha_in[0], + (dA[0].as()) + b * size_dA[0] * realDataTypeSize(TiA), + matA[0], + (dB[0].as()) + b * size_dB[0] * realDataTypeSize(TiB), + matB[0], + &h_beta[0], + (dC[0].as()) + b * size_C[0] * realDataTypeSize(To), + matC[0], + ((*dDp)[0].as()) + b * size_D[0] * realDataTypeSize(To), + matD[0])); + } + }; + + auto configureGroupedGemmForCurrentProblem = [&]() { if(arg.use_ext_setproblem) { auto num_batches_64 = std::vector{num_batches.begin(), num_batches.end()}; for(int32_t b = 0; b < block_count; b++) + { CHECK_HIPBLASLT_ERROR(groupedGemmVec[b].setProblem(M, N, K, @@ -3884,62 +3866,241 @@ void testing_matmul_with_bias(const Arguments& arg, extepilogue, extinputs[b], extproblemtype)); + } + return; } - else + std::vector h_alpha_void, h_beta_void; + for(size_t i = 0; i < h_alpha.size(); i++) { - std::vector h_alpha_void, h_beta_void; - for(size_t i = 0; i < h_alpha.size(); i++) - { - h_alpha_void.push_back(&h_alpha[i]); - h_beta_void.push_back(&h_beta[i]); - } - for(int32_t b = 0; b < block_count; b++) - CHECK_HIPBLASLT_ERROR(groupedGemmVec[b].setProblem(matmul[b], - h_alpha_void, - da[b], - matA, - db[b], - matB, - h_beta_void, - dc[b], - matC, - dd[b], - matD)); + h_alpha_void.push_back(&h_alpha[i]); + h_beta_void.push_back(&h_beta[i]); } - - for(int j = 0; j < returnedAlgoCount; j++) + for(int32_t b = 0; b < block_count; b++) { - for(size_t t = 0; t < tuningVec.size(); t++) - { - size_t tmpWorkspaceSize = 0; - if(groupedGemmVec[0].isAlgoSupported( - tmpAlgo[j].algo, tuningVec[t], tmpWorkspaceSize) - == HIPBLAS_STATUS_SUCCESS) - { - if(tmpWorkspaceSize <= max_workspace_size) - { - heuristicResult.push_back(tmpAlgo[j]); - heuristicTuningIndex.push_back(t); - workspace_size = std::max(workspace_size, tmpWorkspaceSize); - foundAlgo = true; - } - } - } - CHECK_RETURNED_WORKSPACE_SIZE(workspace_size, max_workspace_size); - if(foundAlgo) - break; + CHECK_HIPBLASLT_ERROR(groupedGemmVec[b].setProblem(matmul[b], + h_alpha_void, + da[b], + matA, + db[b], + matB, + h_beta_void, + dc[b], + matC, + dd[b], + matD)); + } + }; + + auto collectTuningsForFirstViableAlgo + = [&](std::vector& candidates, + auto& gemmObject, + bool& foundAlgo) { + foundAlgo = false; + for(int j = 0; j < returnedAlgoCount; j++) + { + for(size_t t = 0; t < tuningVec.size(); t++) + { + size_t tmpWorkspaceSize = 0; + if(gemmObject.isAlgoSupported( + candidates[j].algo, tuningVec[t], tmpWorkspaceSize) + != HIPBLAS_STATUS_SUCCESS) + { + continue; + } + if(tmpWorkspaceSize > max_workspace_size) + { + continue; + } + heuristicResult.push_back(candidates[j]); + heuristicTuningIndex.push_back(t); + workspace_size = std::max(workspace_size, tmpWorkspaceSize); + foundAlgo = true; + } + CHECK_RETURNED_WORKSPACE_SIZE(workspace_size, max_workspace_size); + if(foundAlgo) + { + break; + } + } + }; + + auto collectAllSupportedAlgosViaCAPI + = [&](std::vector& candidates, + bool& foundAlgo) { + foundAlgo = false; + for(int j = 0; j < returnedAlgoCount; j++) + { + size_t tmpWorkspaceSize = 0; + if(hipblaslt_ext::matmulIsAlgoSupported(handle, + matmul[0][0], + alpha_in[0], + matA[0], + matB[0], + &h_beta[0], + matC[0], + matD[0], + candidates[j].algo, + tmpWorkspaceSize) + == HIPBLAS_STATUS_SUCCESS + && tmpWorkspaceSize <= max_workspace_size) + { + heuristicResult.push_back(candidates[j]); + heuristicTuningIndex.push_back(0); + workspace_size = std::max(workspace_size, tmpWorkspaceSize); + foundAlgo = true; + } + CHECK_RETURNED_WORKSPACE_SIZE(workspace_size, max_workspace_size); + } + }; + + auto trySelectFromBatch + = [&](std::vector& candidates, + bool& foundAlgo) { + returnedAlgoCount = candidates.size(); + foundAlgo = false; + if(do_grouped_gemm) + { + configureGroupedGemmForCurrentProblem(); + collectTuningsForFirstViableAlgo( + candidates, groupedGemmVec[0], foundAlgo); + return; + } + if(arg.use_ext && batchMode != HIPBLASLT_BATCH_MODE_POINTER_ARRAY) + { + configureExtGemmForCurrentProblem(); + collectTuningsForFirstViableAlgo(candidates, gemmVec[0], foundAlgo); + return; + } + collectAllSupportedAlgosViaCAPI(candidates, foundAlgo); + }; + + while(auto batch = nextBatchOfIndices()) + { + std::vector candidates; + fetchAlgosForBatch(*batch, candidates); + if(candidates.empty()) + { + break; + } + selectionWasAttempted = true; + bool foundAlgo = false; + trySelectFromBatch(candidates, foundAlgo); + if(foundAlgo) + { + break; } } + }; + + auto verifyMixedValidityContract = [&]() { + auto verifyShape = [&](const char* shapeName, + const std::vector& indices, + hipblasStatus_t expectedStatus, + size_t expectedValidCount) { + std::vector mixedOut; + std::vector indicesCopy = indices; + const auto status + = hipblaslt_ext::getAlgosFromIndex(handle, indicesCopy, mixedOut); + if(status != expectedStatus || mixedOut.size() != expectedValidCount) + { + hipblaslt_cerr + << "verifyMixedValidityContract[" << shapeName + << "]: status=" << hipblas_status_to_string(status) << " (expected " + << hipblas_status_to_string(expectedStatus) + << "), mixedOut.size()=" << mixedOut.size() << " (expected " + << expectedValidCount << ")" << std::endl; + } + EXPECT_HIPBLAS_STATUS(status, expectedStatus); +#ifdef GOOGLE_TEST + EXPECT_EQ(mixedOut.size(), expectedValidCount) << "shape: " << shapeName; +#endif + }; + + const size_t numValid = validIndices.size(); - if(arg.solution_index != -1) { - CHECK_SOLUTION_FOUND(foundAlgo); - foundAlgo = true; + std::vector shape = validIndices; + shape.push_back(std::numeric_limits::max()); + verifyShape("trailing-invalid", shape, HIPBLAS_STATUS_INVALID_VALUE, numValid); } - if(lastBatch || foundAlgo) { - break; + std::vector shape = validIndices; + shape.insert(shape.begin() + (numValid / 2), + std::numeric_limits::max()); + verifyShape("middle-invalid", shape, HIPBLAS_STATUS_INVALID_VALUE, numValid); + } + { + std::vector shape = validIndices; + std::reverse(shape.begin(), shape.end()); + verifyShape("reversed", shape, HIPBLAS_STATUS_SUCCESS, numValid); } + { + std::vector shape; + shape.reserve(2 * numValid); + for(int idx : validIndices) + { + shape.push_back(idx); + shape.push_back(idx); + } + verifyShape("duplicated", shape, HIPBLAS_STATUS_SUCCESS, 2 * numValid); + } + { + verifyShape("empty", {}, HIPBLAS_STATUS_SUCCESS, 0); + } + { + // INT_MIN is avoided here because under HIPBLASLT_USE_ROCROLLER negative + // indices route to a different (rocroller) code path; two large positive + // out-of-range values pin the all-miss return in both build configs. + const std::vector shape{std::numeric_limits::max(), + std::numeric_limits::max() - 1}; + verifyShape("all-invalid", shape, HIPBLAS_STATUS_INVALID_VALUE, 0); + } + }; + + searchForSupportedAlgoViaIndexAPI(); + + if(!indicesAreDiscovered) + { + if(!selectionWasAttempted) + { + hipblaslt_cerr + << "MatmulAlgoIndex: explicit solution_index=" << arg.solution_index + << " returned no candidates from getAlgosFromIndex() (M=" << M[0] + << " N=" << N[0] << " K=" << K[0] << " batch=" << num_batches[0] + << " transA=" << arg.transA << " transB=" << arg.transB + << " a=" << hip_datatype_to_string(arg.a_type) + << " b=" << hip_datatype_to_string(arg.b_type) + << " c=" << hip_datatype_to_string(arg.c_type) + << " d=" << hip_datatype_to_string(arg.d_type) + << " compute=" << hipblas_computetype_to_string(arg.compute_type) << ")" + << std::endl; + CHECK_SOLUTION_FOUND(0); + } + else + { + CHECK_SOLUTION_FOUND(heuristicResult.size()); + } + } + else if(!validIndices.empty()) + { + if(heuristicResult.empty()) + { + hipblaslt_cerr + << "MatmulAlgoIndex: " << validIndices.size() + << " algo indices discovered via getAllAlgos() but none produced a " + "viable algo+tuning under the workspace budget (M=" + << M[0] << " N=" << N[0] << " K=" << K[0] + << " batch=" << num_batches[0] << " transA=" << arg.transA + << " transB=" << arg.transB + << " a=" << hip_datatype_to_string(arg.a_type) + << " b=" << hip_datatype_to_string(arg.b_type) + << " c=" << hip_datatype_to_string(arg.c_type) + << " d=" << hip_datatype_to_string(arg.d_type) + << " compute=" << hipblas_computetype_to_string(arg.compute_type) << ")" + << std::endl; + CHECK_SOLUTION_FOUND(0); + } + verifyMixedValidityContract(); } } else if(arg.algo_method == 1) diff --git a/projects/hipblaslt/tensilelite/include/Tensile/MasterSolutionLibrary.hpp b/projects/hipblaslt/tensilelite/include/Tensile/MasterSolutionLibrary.hpp index fc764dc9ddf..3b619482cd2 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/MasterSolutionLibrary.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/MasterSolutionLibrary.hpp @@ -145,53 +145,65 @@ namespace TensileLite void loadLibrary(const int index) const { + // TODO(#7080): point-key + upper_bound misses on above-largest and + // gap-between-keys; switch to range-encoded mapping. auto it = libraryMapping.upper_bound(index); - if(it != libraryMapping.begin()) + if(it == libraryMapping.begin()) { - --it; - std::string filePrefix = it->second; - // load the file here directly and push the library for later use. + if(Debug::Instance().printDataInit()) { - std::lock_guard lock(solutionsGuard); - if(loadedFiles.find(filePrefix) != loadedFiles.end()) - return; + std::cout << "Index " << index << " not in this arch's mapping range" + << std::endl; } - if(Debug::Instance().printDataInit()) - std::cout << "Loading library for index " << index - << " from file: " << filePrefix << std::endl; + return; + } + --it; + std::string filePrefix = it->second; + // load the file here directly and push the library for later use. + { + std::lock_guard lock(solutionsGuard); + if(loadedFiles.find(filePrefix) != loadedFiles.end()) + { + return; + } + } + if(Debug::Instance().printDataInit()) + { + std::cout << "Loading library for index " << index + << " from file: " << filePrefix << std::endl; + } - fs::path path(libraryDirectory); - path = path / (filePrefix + suffix); + fs::path path(libraryDirectory); + path = path / (filePrefix + suffix); - auto newLibrary = LoadLibraryFile(path.string()); - auto mLibrary - = static_cast*>(newLibrary.get()); + auto newLibrary = LoadLibraryFile(path.string()); + auto mLibrary + = static_cast*>(newLibrary.get()); - using std::begin; - using std::end; + using std::begin; + using std::end; - std::lock_guard lock(solutionsGuard); - if(loadedFiles.find(filePrefix) != loadedFiles.end()) - return; - // Push to cache - indexLoadedLibraries[filePrefix] = mLibrary->library; - - std::transform(begin(mLibrary->solutions), - end(mLibrary->solutions), - std::inserter(solutions, end(solutions)), - [this, filePrefix](auto& i) { - i.second->codeObjectFilename = filePrefix + ".co"; - return i; - }); - loadedFiles.insert(filePrefix); - - if(Debug::Instance().printCodeObjectInfo()) - std::cout << "load placeholder library " << path << std::endl - << mLibrary->solutions.size() << " solutions loaded" << std::endl; + std::lock_guard lock(solutionsGuard); + if(loadedFiles.find(filePrefix) != loadedFiles.end()) + { + return; } - else + // Push to cache + indexLoadedLibraries[filePrefix] = mLibrary->library; + + std::transform(begin(mLibrary->solutions), + end(mLibrary->solutions), + std::inserter(solutions, end(solutions)), + [this, filePrefix](auto& i) { + i.second->codeObjectFilename = filePrefix + ".co"; + return i; + }); + loadedFiles.insert(filePrefix); + + if(Debug::Instance().printCodeObjectInfo()) { - std::cerr << "No library found for index " << index << std::endl; + std::cout << "load placeholder library " << path << std::endl + << mLibrary->solutions.size() << " solutions loaded" << std::endl; } }