diff --git a/projects/hipblaslt/tensilelite/include/Tensile/AMDGPU.hpp b/projects/hipblaslt/tensilelite/include/Tensile/AMDGPU.hpp index b1f09712f98..8caec03f87c 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/AMDGPU.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/AMDGPU.hpp @@ -228,6 +228,7 @@ namespace TensileLite int computeUnitCount = 0; int skDynamicGrid = 6; int skDynamicWGM = 0; + int fixedWGM = std::numeric_limits::max(); /* Fixed WGM override. The numeric_limits max default doubles as the "unset" sentinel: where this field is read, only values in [-1024, 1024] are applied. NOTE(review): std::numeric_limits appears without a template argument in this view - presumably <int>; confirm against the real file. */ int skMaxCUs = 0; int skGridMultiplier = 1; int skFixedGrid = 0; @@ -263,6 +264,13 @@ namespace TensileLite return value; } + /* Returns the TENSILE_FIXED_WGM override. The environment variable is read and parsed on the first call only (function-local statics), so later changes to it are not observed. When the variable is unset the numeric_limits max sentinel is returned. NOTE(review): std::atoi yields 0 for non-numeric text, and 0 lies inside the range the consumer accepts, so a malformed value is presumably applied as WGM 0 - confirm this is intended. */ const int getFixedWGM() const + { + static const char* envStr = std::getenv("TENSILE_FIXED_WGM"); + static const int value = (envStr == NULL ? std::numeric_limits::max() : std::atoi(envStr)); + return value; + } + const int getSKMaxCUs() const { static const char* envStr = std::getenv("TENSILE_STREAMK_MAX_CUS"); diff --git a/projects/hipblaslt/tensilelite/src/AMDGPU.cpp b/projects/hipblaslt/tensilelite/src/AMDGPU.cpp index 799076bfb1a..e93e72c8b3c 100644 --- a/projects/hipblaslt/tensilelite/src/AMDGPU.cpp +++ b/projects/hipblaslt/tensilelite/src/AMDGPU.cpp @@ -53,6 +53,7 @@ namespace TensileLite , deviceName(name) , skDynamicGrid(getSKDynamicGrid()) , skDynamicWGM(getSKDynamicWGM()) + , fixedWGM(getFixedWGM()) /* initialised from the TENSILE_FIXED_WGM getter, like the neighbouring members */ , skMaxCUs(getSKMaxCUs()) , skGridMultiplier(getSKGridMultiplier()) , skFixedGrid(getSKFixedGrid()) diff --git a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp index 3b8a678adf6..680559f03ad 100644 --- a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp +++ b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp @@ -1204,7 +1204,11 @@ namespace TensileLite if(sizeMapping.streamK != 0) { AMDGPU const* pAMDGPU = dynamic_cast(&hardware); - if(pAMDGPU->skDynamicWGM == 1) + /* An in-range fixedWGM replaces defaultWGM and takes precedence over the skDynamicWGM branch that follows; the unset sentinel is out of range and falls through to it. NOTE(review): pAMDGPU is the result of a dynamic_cast and is dereferenced without a null check in the lines visible here. */ if(pAMDGPU->fixedWGM >= -1024 && pAMDGPU->fixedWGM <= 1024) + { + defaultWGM = pAMDGPU->fixedWGM; + } + else if(pAMDGPU->skDynamicWGM ==
1) { hip::HipAMDGPU const* hipAMDGPU = dynamic_cast(&hardware);