Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,75 @@ def makeValidMatrixInstructions():
# 1: Use TDM for A
# 2: Use TDM for B
# 3: Use TDM for both A and B
"TDMInst": [0, 1, 2, 3]
"TDMInst": [0, 1, 2, 3],
# Bias support for GEMM operations
# 0: No bias
# 1: Bias vector on M direction
# 2: Bias vector on N direction
# 3: Bias vector on both M and N directions
"UseBias": [0, 1, 2, 3],
# Alpha vector scaling support
# 0: Disabled
# 1: Alpha vector on M direction
# 2: Alpha vector on N direction
# 3: Alpha vector on both M and N directions
"UseScaleAlphaVec": [0, 1, 2, 3],
# MX (microscaling) block size for matrix A
# 0: Disabled
# 16, 32: Valid MX block sizes
"MXBlockA": [0, 16, 32],
# MX (microscaling) block size for matrix B
# 0: Disabled
# 16, 32: Valid MX block sizes
"MXBlockB": [0, 16, 32],
# Enable beta scaling in GEMM operation (D = alpha*A*B + beta*C)
# False: beta is not used (beta = 0)
# True: beta is used
"UseBeta": [False, True],
# Enable auxiliary output matrix E
# False: E output is not used
# True: E output is used
"UseE": [False, True],
# Enable gradient computation
# False: gradient computation is not used
# True: gradient computation is used
"Gradient": [False, True],
# Enable scaling for C and D matrices
# False: scaling for C and D is not used
Comment on lines +962 to +996
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry this doesn't seem correct to me that this fix requires adding all of these values to ValidParameters.py.

# True: scaling for C and D is used
"UseScaleCD": [False, True],
# Enable high precision accumulate
# False: standard precision accumulation
# True: high precision accumulation
"HighPrecisionAccumulate": [False, True],
# Enable silent high precision accumulate (no warning if downcast occurs)
# False: warnings enabled for precision downcasts
# True: silent mode for precision downcasts
"SilentHighPrecisionAccumulate": [False, True],
# Enable strided batched GEMM
# False: not strided batched
# True: strided batched GEMM
"StridedBatched": [False, True],
# Enable grouped GEMM
# False: not grouped GEMM
# True: grouped GEMM
"GroupedGemm": [False, True],
# Use initial strides for A and B matrices
# False: do not use initial strides for A and B
# True: use initial strides for A and B
"UseInitialStridesAB": [False, True],
# Use initial strides for C and D matrices
# False: do not use initial strides for C and D
# True: use initial strides for C and D
"UseInitialStridesCD": [False, True],
# Allow problems with no free dimensions
# False: do not allow problems with no free dimensions
# True: allow problems with no free dimensions
"AllowNoFreeDims": [False, True],
# Enable tile-aware solution selection
# False: disable tile-aware selection
# True: enable tile-aware selection
"TileAwareSelection": [False, True]
}

newMIValidParameters = {
Expand Down
164 changes: 160 additions & 4 deletions projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,29 @@ def validateParameterTypes(state, srcFile=""):
are different Python types and produce different msgpack wire types,
which causes ``std::bad_cast`` at C++ deserialization time.

Instead of raising on the first mismatch, mismatches are collected into
the module-level ``_typeMismatchCollector`` dict. Call
``printTypeMismatchSummary()`` at the end of the build to emit a
consolidated warning.
For critical parameters (UseBias, UseScaleAlphaVec), exits immediately
with a clear error message on type mismatch.

For other parameters, mismatches are collected into the module-level
``_typeMismatchCollector`` dict. Call ``printTypeMismatchSummary()``
at the end of the build to emit a consolidated warning.

Args:
state: The solution state dict (parameter name -> value).
srcFile: The YAML source file path, included in warning messages.
"""
# Input parameters to check - fail immediately on type mismatch
inputParamToCheck = {
# Integer parameters
"UseBias", "UseScaleAlphaVec", "MXBlockA", "MXBlockB",
# Boolean parameters
"UseBeta", "UseE", "Gradient", "UseScaleCD",
"HighPrecisionAccumulate", "SilentHighPrecisionAccumulate",
"StridedBatched", "GroupedGemm",
"UseInitialStridesAB", "UseInitialStridesCD",
"AllowNoFreeDims", "TileAwareSelection"
}

for key, value in state.items():
if key not in _expectedParamTypes or key in _skipTypeCheck:
continue
Expand All @@ -125,6 +139,146 @@ def validateParameterTypes(state, srcFile=""):
# Use type() not isinstance() so that bool and int are distinguished
if actualType not in expectedTypes:
expectedStr = " or ".join(sorted(t.__name__ for t in expectedTypes))

# For input parameters to check, exit immediately with clear error
if key in inputParamToCheck:
errorMsg = [
"",
"=" * 80,
"ERROR: Invalid type for parameter '{}'".format(key),
]
if srcFile:
errorMsg.append("File: {}".format(srcFile))
errorMsg.extend([
"=" * 80,
" Expected type: {}".format(expectedStr),
" Actual type: {}".format(actualType.__name__),
" Value: {}".format(repr(value)),
"",
])
if key == "UseBias":
errorMsg.extend([
" UseBias must be an integer (0, 1, 2, or 3):",
" 0 = no bias",
" 1 = bias vector on M direction",
" 2 = bias vector on N direction",
" 3 = bias vector on both M and N directions",
"",
])
elif key == "UseScaleAlphaVec":
errorMsg.extend([
" UseScaleAlphaVec must be an integer (0, 1, 2, or 3):",
" 0 = disabled",
" 1 = alpha vector on M direction",
Comment on lines +142 to +172
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. We cannot have this kind of code here.

" 2 = alpha vector on N direction",
" 3 = alpha vector on both M and N directions",
"",
])
elif key == "MXBlockA":
errorMsg.extend([
" MXBlockA must be an integer (0, 16, or 32):",
" 0 = disabled",
" 16, 32 = MX block sizes for matrix A",
"",
])
elif key == "MXBlockB":
errorMsg.extend([
" MXBlockB must be an integer (0, 16, or 32):",
" 0 = disabled",
" 16, 32 = MX block sizes for matrix B",
"",
])
elif key == "UseBeta":
errorMsg.extend([
" UseBeta must be a boolean (False or True):",
" False = beta is not used (beta = 0)",
" True = beta is used in GEMM (D = alpha*A*B + beta*C)",
"",
])
elif key == "UseE":
errorMsg.extend([
" UseE must be a boolean (False or True):",
" False = auxiliary output matrix E is not used",
" True = auxiliary output matrix E is used",
"",
])
elif key == "Gradient":
errorMsg.extend([
" Gradient must be a boolean (False or True):",
" False = gradient computation is not used",
" True = gradient computation is used",
"",
])
elif key == "UseScaleCD":
errorMsg.extend([
" UseScaleCD must be a boolean (False or True):",
" False = scaling for C and D matrices is not used",
" True = scaling for C and D matrices is used",
"",
])
elif key == "HighPrecisionAccumulate":
errorMsg.extend([
" HighPrecisionAccumulate must be a boolean (False or True):",
" False = standard precision accumulation",
" True = high precision accumulation",
"",
])
elif key == "SilentHighPrecisionAccumulate":
errorMsg.extend([
" SilentHighPrecisionAccumulate must be a boolean (False or True):",
" False = warnings enabled for precision downcasts",
" True = silent mode for precision downcasts",
"",
])
elif key == "StridedBatched":
errorMsg.extend([
" StridedBatched must be a boolean (False or True):",
" False = not strided batched",
" True = strided batched GEMM",
"",
])
elif key == "GroupedGemm":
errorMsg.extend([
" GroupedGemm must be a boolean (False or True):",
" False = not grouped GEMM",
" True = grouped GEMM",
"",
])
elif key == "UseInitialStridesAB":
errorMsg.extend([
" UseInitialStridesAB must be a boolean (False or True):",
" False = do not use initial strides for A and B matrices",
" True = use initial strides for A and B matrices",
"",
])
elif key == "UseInitialStridesCD":
errorMsg.extend([
" UseInitialStridesCD must be a boolean (False or True):",
" False = do not use initial strides for C and D matrices",
" True = use initial strides for C and D matrices",
"",
])
elif key == "AllowNoFreeDims":
errorMsg.extend([
" AllowNoFreeDims must be a boolean (False or True):",
" False = do not allow problems with no free dimensions",
" True = allow problems with no free dimensions",
"",
])
elif key == "TileAwareSelection":
errorMsg.extend([
" TileAwareSelection must be a boolean (False or True):",
" False = disable tile-aware solution selection",
" True = enable tile-aware solution selection",
"",
])
errorMsg.extend([
"Please fix the YAML configuration file.",
"=" * 80,
])
printExit("\n".join(errorMsg))

# For other parameters, collect for summary
collectorKey = (key, actualType.__name__, expectedStr)
if collectorKey not in _typeMismatchCollector:
_typeMismatchCollector[collectorKey] = {
Expand Down Expand Up @@ -316,6 +470,8 @@ def __init__(
self._state = {}
# problem type
if "ProblemType" in config:
# Validate ProblemType parameters before constructing the ProblemType object
validateParameterTypes(config["ProblemType"], srcFile=srcName)
self["ProblemType"] = ProblemType(config["ProblemType"], printIndexAssignmentInfo)
else:
self["ProblemType"] = ProblemType.FromDefaultConfig(printIndexAssignmentInfo)
Expand Down
Loading