diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py b/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py index 7e626d189c5..081becce04f 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py @@ -959,7 +959,75 @@ def makeValidMatrixInstructions(): # 1: Use TDM for A # 2: Use TDM for B # 3: Use TDM for both A and B - "TDMInst": [0, 1, 2, 3] + "TDMInst": [0, 1, 2, 3], + # Bias support for GEMM operations + # 0: No bias + # 1: Bias vector on M direction + # 2: Bias vector on N direction + # 3: Bias vector on both M and N directions + "UseBias": [0, 1, 2, 3], + # Alpha vector scaling support + # 0: Disabled + # 1: Alpha vector on M direction + # 2: Alpha vector on N direction + # 3: Alpha vector on both M and N directions + "UseScaleAlphaVec": [0, 1, 2, 3], + # MX (microscaling) block size for matrix A + # 0: Disabled + # 16, 32: Valid MX block sizes + "MXBlockA": [0, 16, 32], + # MX (microscaling) block size for matrix B + # 0: Disabled + # 16, 32: Valid MX block sizes + "MXBlockB": [0, 16, 32], + # Enable beta scaling in GEMM operation (D = alpha*A*B + beta*C) + # False: beta is not used (beta = 0) + # True: beta is used + "UseBeta": [False, True], + # Enable auxiliary output matrix E + # False: E output is not used + # True: E output is used + "UseE": [False, True], + # Enable gradient computation + # False: gradient computation is not used + # True: gradient computation is used + "Gradient": [False, True], + # Enable scaling for C and D matrices + # False: scaling for C and D is not used + # True: scaling for C and D is used + "UseScaleCD": [False, True], + # Enable high precision accumulate + # False: standard precision accumulation + # True: high precision accumulation + "HighPrecisionAccumulate": [False, True], + # Enable silent high precision accumulate (no warning if downcast occurs) + # False: warnings enabled for precision downcasts + # True: silent mode for precision downcasts + "SilentHighPrecisionAccumulate": [False, True], + # Enable strided batched GEMM + # False: not strided batched + # True: strided batched GEMM + "StridedBatched": [False, True], + # Enable grouped GEMM + # False: not grouped GEMM + # True: grouped GEMM + "GroupedGemm": [False, True], + # Use initial strides for A and B matrices + # False: do not use initial strides for A and B + # True: use initial strides for A and B + "UseInitialStridesAB": [False, True], + # Use initial strides for C and D matrices + # False: do not use initial strides for C and D + # True: use initial strides for C and D + "UseInitialStridesCD": [False, True], + # Allow problems with no free dimensions + # False: do not allow problems with no free dimensions + # True: allow problems with no free dimensions + "AllowNoFreeDims": [False, True], + # Enable tile-aware solution selection + # False: disable tile-aware selection + # True: enable tile-aware selection + "TileAwareSelection": [False, True] } newMIValidParameters = { diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index af0f2bc8c8a..f00de5ff6b8 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -108,15 +108,29 @@ def validateParameterTypes(state, srcFile=""): are different Python types and produce different msgpack wire types, which causes ``std::bad_cast`` at C++ deserialization time. - Instead of raising on the first mismatch, mismatches are collected into - the module-level ``_typeMismatchCollector`` dict. Call - ``printTypeMismatchSummary()`` at the end of the build to emit a - consolidated warning. + For critical parameters (UseBias, UseScaleAlphaVec), exits immediately + with a clear error message on type mismatch. + + For other parameters, mismatches are collected into the module-level + ``_typeMismatchCollector`` dict. Call ``printTypeMismatchSummary()`` + at the end of the build to emit a consolidated warning. Args: state: The solution state dict (parameter name -> value). srcFile: The YAML source file path, included in warning messages. """ + # Input parameters to check - fail immediately on type mismatch + inputParamToCheck = { + # Integer parameters + "UseBias", "UseScaleAlphaVec", "MXBlockA", "MXBlockB", + # Boolean parameters + "UseBeta", "UseE", "Gradient", "UseScaleCD", + "HighPrecisionAccumulate", "SilentHighPrecisionAccumulate", + "StridedBatched", "GroupedGemm", + "UseInitialStridesAB", "UseInitialStridesCD", + "AllowNoFreeDims", "TileAwareSelection" + } + for key, value in state.items(): if key not in _expectedParamTypes or key in _skipTypeCheck: continue @@ -125,6 +139,146 @@ def validateParameterTypes(state, srcFile=""): # Use type() not isinstance() so that bool and int are distinguished if actualType not in expectedTypes: expectedStr = " or ".join(sorted(t.__name__ for t in expectedTypes)) + + # For input parameters to check, exit immediately with clear error + if key in inputParamToCheck: + errorMsg = [ + "", + "=" * 80, + "ERROR: Invalid type for parameter '{}'".format(key), + ] + if srcFile: + errorMsg.append("File: {}".format(srcFile)) + errorMsg.extend([ + "=" * 80, + " Expected type: {}".format(expectedStr), + " Actual type: {}".format(actualType.__name__), + " Value: {}".format(repr(value)), + "", + ]) + if key == "UseBias": + errorMsg.extend([ + " UseBias must be an integer (0, 1, 2, or 3):", + " 0 = no bias", + " 1 = bias vector on M direction", + " 2 = bias vector on N direction", + " 3 = bias vector on both M and N directions", + "", + ]) + elif key == "UseScaleAlphaVec": + errorMsg.extend([ + " UseScaleAlphaVec must be an integer (0, 1, 2, or 3):", + " 0 = disabled", + " 1 = alpha vector on M direction", + " 2 = alpha vector on N direction", + " 3 = alpha vector on both M and N directions", + "", + ]) + elif key == "MXBlockA": + errorMsg.extend([ + " MXBlockA must be an integer (0, 16, or 32):", + " 0 = disabled", + " 16, 32 = MX block sizes for matrix A", + "", + ]) + elif key == "MXBlockB": + errorMsg.extend([ + " MXBlockB must be an integer (0, 16, or 32):", + " 0 = disabled", + " 16, 32 = MX block sizes for matrix B", + "", + ]) + elif key == "UseBeta": + errorMsg.extend([ + " UseBeta must be a boolean (False or True):", + " False = beta is not used (beta = 0)", + " True = beta is used in GEMM (D = alpha*A*B + beta*C)", + "", + ]) + elif key == "UseE": + errorMsg.extend([ + " UseE must be a boolean (False or True):", + " False = auxiliary output matrix E is not used", + " True = auxiliary output matrix E is used", + "", + ]) + elif key == "Gradient": + errorMsg.extend([ + " Gradient must be a boolean (False or True):", + " False = gradient computation is not used", + " True = gradient computation is used", + "", + ]) + elif key == "UseScaleCD": + errorMsg.extend([ + " UseScaleCD must be a boolean (False or True):", + " False = scaling for C and D matrices is not used", + " True = scaling for C and D matrices is used", + "", + ]) + elif key == "HighPrecisionAccumulate": + errorMsg.extend([ + " HighPrecisionAccumulate must be a boolean (False or True):", + " False = standard precision accumulation", + " True = high precision accumulation", + "", + ]) + elif key == "SilentHighPrecisionAccumulate": + errorMsg.extend([ + " SilentHighPrecisionAccumulate must be a boolean (False or True):", + " False = warnings enabled for precision downcasts", + " True = silent mode for precision downcasts", + "", + ]) + elif key == "StridedBatched": + errorMsg.extend([ + " StridedBatched must be a boolean (False or True):", + " False = not strided batched", + " True = strided batched GEMM", + "", + ]) + elif key == "GroupedGemm": + errorMsg.extend([ + " GroupedGemm must be a boolean (False or True):", + " False = not grouped GEMM", + " True = grouped GEMM", + "", + ]) + elif key == "UseInitialStridesAB": + errorMsg.extend([ + " UseInitialStridesAB must be a boolean (False or True):", + " False = do not use initial strides for A and B matrices", + " True = use initial strides for A and B matrices", + "", + ]) + elif key == "UseInitialStridesCD": + errorMsg.extend([ + " UseInitialStridesCD must be a boolean (False or True):", + " False = do not use initial strides for C and D matrices", + " True = use initial strides for C and D matrices", + "", + ]) + elif key == "AllowNoFreeDims": + errorMsg.extend([ + " AllowNoFreeDims must be a boolean (False or True):", + " False = do not allow problems with no free dimensions", + " True = allow problems with no free dimensions", + "", + ]) + elif key == "TileAwareSelection": + errorMsg.extend([ + " TileAwareSelection must be a boolean (False or True):", + " False = disable tile-aware solution selection", + " True = enable tile-aware solution selection", + "", + ]) + errorMsg.extend([ + "Please fix the YAML configuration file.", + "=" * 80, + ]) + printExit("\n".join(errorMsg)) + + # For other parameters, collect for summary collectorKey = (key, actualType.__name__, expectedStr) if collectorKey not in _typeMismatchCollector: _typeMismatchCollector[collectorKey] = { @@ -316,6 +470,8 @@ def __init__( self._state = {} # problem type if "ProblemType" in config: + # Validate ProblemType parameters before constructing the ProblemType object + validateParameterTypes(config["ProblemType"], srcFile=srcName) self["ProblemType"] = ProblemType(config["ProblemType"], printIndexAssignmentInfo) else: self["ProblemType"] = ProblemType.FromDefaultConfig(printIndexAssignmentInfo)