Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

#include <rocRoller/DataTypes/DataTypes.hpp>

#include "origami/types.hpp"
#include <origami/utils.hpp>

/**
* @brief Convert rocRoller::Datatype to analytical::DataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,52 +28,55 @@
#include "gemm.hpp"
#include "runtime_args_selection.hpp"

#include "origami/streamk.hpp"
#include <origami/streamk.hpp>

const int DEFAULT_DYNAMIC_MODE = 6;

int chooseStreamKGridSize(std::shared_ptr<GemmKernel> gemm,
const RocblasltContractionProblem& prob)
{
const origami::hardware_t analytical_hardware = origami::hardware_t::get_hardware_for_device(0);

const origami::grid_selection_t DEFAULT_DYNAMIC_MODE = origami::grid_selection_t::k_split_aware;

//setting max_cu's
size_t max_cus = analytical_hardware.N_CU;
const origami::hardware_t analaytical_hardware = origami::hardware_t::get_hardware_for_device(0);

size_t elementSizeA_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeA).elementBits;
size_t elementSizeB_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeB).elementBits;
size_t elementSizeD_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeD).elementBits;
size_t elementSizeAcc = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeAcc).elementBytes;

origami::problem_t origami_problem = {
.size = {prob.m, prob.n, prob.k},
.batch = prob.batch_count,
.a_dtype = rocroller_type_to_analytical_type(gemm->params->kernelType.typeA),
.b_dtype = rocroller_type_to_analytical_type(gemm->params->kernelType.typeB),
.mi_dtype = rocroller_type_to_analytical_type(elementSizeA_bits < elementSizeB_bits ? gemm->params->kernelType.typeB : gemm->params->kernelType.typeA),
};
origami::config_t origami_config = {
.mt = {
static_cast<size_t>(gemm->params->workgroupTile.m),
static_cast<size_t>(gemm->params->workgroupTile.n),
static_cast<size_t>(gemm->params->workgroupTile.k)
},
.occupancy = gemm->occupancy,
.workspace_size = prob.workspaceSize,
.workspace_size_per_elem_c = elementSizeAcc,
};

auto reduction_type = origami::streamk::select_reduction(origami_problem,
analytical_hardware,
origami_config,
DEFAULT_DYNAMIC_MODE);
origami::data_type_t dataType;
if (elementSizeA_bits < elementSizeB_bits)
dataType = rocroller_type_to_analytical_type(gemm->params->kernelType.typeB);
else
dataType = rocroller_type_to_analytical_type(gemm->params->kernelType.typeA);

origami_config.reduction_strategy = reduction_type;
auto reduction_type = origami::streamk::select_reduction(prob.m, prob.n, prob.k, prob.batch_count,
gemm->params->workgroupTile.m, gemm->params->workgroupTile.n, gemm->params->workgroupTile.k, analaytical_hardware, DEFAULT_DYNAMIC_MODE);
// Override reduction type to tree reduction for now.
// When Parallel reduction is available, this line can be removed
reduction_type = origami::streamk::reduction_type::Tree;

auto result = origami::streamk::select_grid_size(origami_problem,
analytical_hardware,
origami_config,
DEFAULT_DYNAMIC_MODE,
max_cus);
auto result = origami::streamk::select_grid(prob.m,
prob.n,
prob.k,
prob.batch_count,
prob.trans_a == HIPBLAS_OP_T,
prob.trans_b == HIPBLAS_OP_T,
elementSizeA_bits,
elementSizeB_bits,
elementSizeD_bits,
dataType,
prob.workspaceSize,
gemm->params->workgroupTile.m,
gemm->params->workgroupTile.n,
gemm->params->workgroupTile.k,
gemm->params->machineInstruction.m,
gemm->params->machineInstruction.n,
gemm->params->machineInstruction.k,
DEFAULT_WGM,
elementSizeAcc,
gemm->occupancy,
analaytical_hardware,
DEFAULT_DYNAMIC_MODE,
reduction_type);

return result;
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include "runtime_args_selection.hpp"
#include "solution_selection.hpp"

#include "origami/origami.hpp"
#include <origami/utils.hpp>

const int MAX_BITS_WORKGROUPTILE_M = 8;
const int MAX_BITS_WORKGROUPTILE_N = 8;
Expand All @@ -44,9 +44,7 @@ const int USE_WORKGROUP_MAPPING_K_SIZE = 4096;
* compile-time known.
*/

constexpr size_t possibleTileSizesCount = 34;

constexpr std::array<WorkGroupTileSize, possibleTileSizesCount> possibleTileSizes = {{
constexpr std::array<WorkGroupTileSize, 34> possibleTileSizes = {{
{256, 256, 128},
{256, 192, 128},
{256, 128, 128},
Expand Down Expand Up @@ -84,10 +82,10 @@ const int USE_WORKGROUP_MAPPING_K_SIZE = 4096;
}};

template <rocRoller::DataType typeA, rocRoller::DataType typeB>
auto generateTileList() {
std::array<origami::config_t, possibleTileSizesCount> tileList{};
constexpr auto generateTileList() {
std::array<origami::tile_tuple, possibleTileSizes.size()> tileList{};

for (size_t i = 0; i < possibleTileSizesCount; ++i) {
for (size_t i = 0; i < possibleTileSizes.size(); ++i) {
const auto& wgt = possibleTileSizes[i];
auto MI = pickMI(typeA, typeB, wgt);

Expand All @@ -98,33 +96,27 @@ auto generateTileList() {

int unroll = preferredUnrolling(typeA, typeB, wgt);

origami::config_t origami_config = {
.mt = {
static_cast<size_t>(wgt.m),
static_cast<size_t>(wgt.n),
static_cast<size_t>(wgtk * unroll)
},
.mi = {
static_cast<size_t>(MI.m),
static_cast<size_t>(MI.n),
static_cast<size_t>(MI.k)
},
.occupancy = 1,
.cache_hints_a = 0,
.cache_hints_b = 0,
};

tileList[i] = origami_config;
int non_temporal_a = 0;
int non_temporal_b = 0;

tileList[i] = std::make_tuple(
wgt.m, wgt.n, wgtk * unroll,
MI.m, MI.n, MI.k,
1, // occupancy
DEFAULT_WGM,
non_temporal_a,
non_temporal_b
);
}

return tileList;
}

using TileListGeneratorFn = std::vector<origami::config_t>(*)();
using TileListGeneratorFn = std::vector<origami::tile_tuple>(*)();

template <rocRoller::DataType A, rocRoller::DataType B>
std::vector<origami::config_t> generateTileListWrapper() {
auto arr = generateTileList<A, B>();
std::vector<origami::tile_tuple> generateTileListWrapper() {
constexpr auto arr = generateTileList<A, B>();
return {arr.begin(), arr.end()};
}

Expand Down Expand Up @@ -152,7 +144,7 @@ const std::map<std::pair<rocRoller::DataType, rocRoller::DataType>, TileListGene
INSTANTIATE_TILE_LIST_FOR(FP6)
};

std::vector<origami::config_t> getTileListForKernelType(KernelType kernelType)
std::vector<origami::tile_tuple> getTileListForKernelType(KernelType kernelType)
{
auto key = std::make_pair(kernelType.typeA, kernelType.typeB);
auto it = tileListGenerators.find(key);
Expand All @@ -178,42 +170,43 @@ std::vector<SolutionIndexParameters> chooseSolutionIndexParameters(
{
std::vector<SolutionIndexParameters> params;

std::vector<origami::config_t> origami_config_list = getTileListForKernelType(kernelType);
std::vector<origami::tile_tuple> tile_list = getTileListForKernelType(kernelType);

size_t elementSizeA_bits = rocRoller::DataTypeInfo::Get(kernelType.typeA).elementBits;
size_t elementSizeB_bits = rocRoller::DataTypeInfo::Get(kernelType.typeB).elementBits;

const origami::hardware_t analytical_hardware = origami::hardware_t::get_hardware_for_device(0);

origami::problem_t origami_problem = {
.size = {prob.m, prob.n, prob.k},
.batch = prob.batch_count,
.a_transpose = (prob.trans_a == hipblasOperation_t::HIPBLAS_OP_T) ? origami::transpose_t::T : origami::transpose_t::N,
.b_transpose = (prob.trans_b == hipblasOperation_t::HIPBLAS_OP_T) ? origami::transpose_t::T : origami::transpose_t::N,
.a_dtype = rocroller_type_to_analytical_type(kernelType.typeA),
.b_dtype = rocroller_type_to_analytical_type(kernelType.typeB),
.mi_dtype = rocroller_type_to_analytical_type(elementSizeA_bits < elementSizeB_bits ? kernelType.typeB : kernelType.typeA),
.a_mx_block_size = kernelType.scaleTypeA.blockRowSize * kernelType.scaleTypeA.blockColSize,
.b_mx_block_size = kernelType.scaleTypeB.blockRowSize * kernelType.scaleTypeB.blockColSize,
};

int defaultWGM = std::ceil(std::sqrt(analytical_hardware.N_CU / analytical_hardware.NUM_XCD));
for (auto& config : origami_config_list) {
config.workgroup_mapping = defaultWGM;
}

auto prediction_result = origami::rank_configs(
origami_problem,
analytical_hardware,
origami_config_list
);

for(auto const& result : prediction_result)
size_t elementSizeC_bits = rocRoller::DataTypeInfo::Get(kernelType.typeC).elementBits;

origami::data_type_t dataType;
if (elementSizeA_bits < elementSizeB_bits)
dataType = rocroller_type_to_analytical_type(kernelType.typeB);
else
dataType = rocroller_type_to_analytical_type(kernelType.typeA);

const origami::hardware_t analaytical_hardware = origami::hardware_t::get_hardware_for_device(0);

int WGM = std::sqrt(std::floor(analaytical_hardware.N_CU / analaytical_hardware.NUM_XCD));

auto selected_tiles = origami::select_best_macro_tile_size(
prob.m,
prob.n,
prob.k,
prob.batch_count,
prob.trans_a == hipblasOperation_t::HIPBLAS_OP_T,
prob.trans_b == hipblasOperation_t::HIPBLAS_OP_T,
analaytical_hardware,
tile_list,
elementSizeA_bits,
elementSizeB_bits,
elementSizeC_bits,
dataType,
kernelType.scaleTypeA.blockRowSize * kernelType.scaleTypeA.blockColSize, //Handle A vs B block size.
0.8,
false,
WGM);

for(auto const& selected_tile : selected_tiles)
{
auto mt_m = static_cast<int>(result.config.mt.m);
auto mt_n = static_cast<int>(result.config.mt.n);
auto mt_k = static_cast<int>(result.config.mt.k);
WorkGroupTileSize wgt{mt_m, mt_n, mt_k};
WorkGroupTileSize wgt{(int)std::get<1>(selected_tile), (int)std::get<2>(selected_tile), (int)std::get<3>(selected_tile)};
int unrollAmount = preferredUnrolling(kernelType.typeA, kernelType.typeB, wgt);
wgt.k /= unrollAmount;

Expand Down Expand Up @@ -249,7 +242,7 @@ std::vector<SolutionIndexParameters> chooseSolutionIndexParameters(
size_t numTilesN = prob.n / wgt.n;
size_t numTiles = numTilesM * numTilesN * prob.batch_count;
auto isF6 = (kernelType.typeA == rocRoller::DataType::FP6 || kernelType.typeA == rocRoller::DataType::BF6 || kernelType.typeB == rocRoller::DataType::FP6 || kernelType.typeB == rocRoller::DataType::BF6);
if(numTiles < analytical_hardware.N_CU && !isF6)
if(numTiles < analaytical_hardware.N_CU && !isF6)
{
params.back().streamK = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
#include <Tensile/Task.hpp>
#include <Tensile/Utils.hpp>

#include "origami/origami.hpp"
#include "origami/streamk.hpp"
#include <origami/streamk.hpp>

#define TENSILE_COMMON_KERNEL_ARGS_SIZE 16

Expand Down Expand Up @@ -167,7 +166,7 @@ namespace TensileLite

struct StreamKSettings
{
origami::reduction_t reduction = origami::reduction_t::tree;
origami::streamk::reduction_type reduction = origami::streamk::reduction_type::Tree;
size_t grid = 0;
};

Expand All @@ -184,7 +183,7 @@ namespace TensileLite
using Problem = ContractionProblemGemm;
using Inputs = ContractionInputs;
using GroupedInputs = ContractionGroupedInputs;
using ParamsCache = CacheMap<std::pair<int32_t, int32_t>, Problem>;
using ParamsCache = CacheMap<std::pair<int32_t, uint32_t>, Problem>;

/**
* Indicate a solution is equally or estimatedly matched.
Expand Down Expand Up @@ -219,11 +218,6 @@ namespace TensileLite
}
virtual bool isFallbackForHW(Hardware const&) const;

bool isStreamK() const
{
return sizeMapping.streamK > 0;
}

//! Estimates based on problem size, solution tile, and machine hardware
//! charz:
struct StaticPerformanceModel
Expand Down Expand Up @@ -296,8 +290,8 @@ namespace TensileLite
void calculateGrid(dim3& workGroupSize,
dim3& numWorkGroups,
ContractionSolution::Problem const& problem) const;
origami::reduction_t getSKReduction(Problem const& problem, Hardware const& hardware) const;
size_t getSKGrid(Problem const& problem, Hardware const& hardware, size_t tiles, origami::reduction_t reductionStrat) const;
origami::streamk::reduction_type getSKReduction(Problem const& problem, Hardware const& hardware) const;
size_t getSKGrid(Problem const& problem, Hardware const& hardware, size_t tiles, origami::streamk::reduction_type& reductionStrat) const;
size_t partialTileSize(size_t skGrid) const;

static float computeGranularity(float x);
Expand Down Expand Up @@ -572,9 +566,9 @@ namespace TensileLite
uint32_t magicNumber(int magicDivAlg, uint32_t x, uint32_t* magicShift) const;
uint32_t smallMagicNumber(uint32_t x) const;

std::pair<int32_t, int32_t> calculateAutoWGM(Problem const& problem,
Hardware const* hardware,
uint32_t skgrid) const;
std::pair<int32_t, uint32_t> calculateAutoWGM(Problem const& problem,
Hardware const* hardware,
uint32_t skgrid) const;
uint32_t calculateAutoGSU(Problem const& problem, Hardware const* hardware) const;
};

Expand Down
Loading
Loading