-
Notifications
You must be signed in to change notification settings - Fork 294
Fix for type mismatch: AIHPBLAS-1465 #6177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
22ed25b
Fix for AIHPBLAS-1465
pdhirajkumarprasad 0424464
Revert "Fix for AIHPBLAS-1465"
pdhirajkumarprasad 806165b
Correct fix for AIHPBLAS-1465: Add validation for UseBias and UseScal…
pdhirajkumarprasad 2668ceb
Merge branch 'develop' into users/dhirajp/AIHPBLAS-1465
pdhirajkumarprasad 57b2980
Revert "Correct fix for AIHPBLAS-1465: Add validation for UseBias and…
pdhirajkumarprasad 232b3b1
added check in solution.py
pdhirajkumarprasad File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,15 +108,29 @@ def validateParameterTypes(state, srcFile=""): | |
| are different Python types and produce different msgpack wire types, | ||
| which causes ``std::bad_cast`` at C++ deserialization time. | ||
|
|
||
| Instead of raising on the first mismatch, mismatches are collected into | ||
| the module-level ``_typeMismatchCollector`` dict. Call | ||
| ``printTypeMismatchSummary()`` at the end of the build to emit a | ||
| consolidated warning. | ||
| For critical parameters (UseBias, UseScaleAlphaVec), exits immediately | ||
| with a clear error message on type mismatch. | ||
|
|
||
| For other parameters, mismatches are collected into the module-level | ||
| ``_typeMismatchCollector`` dict. Call ``printTypeMismatchSummary()`` | ||
| at the end of the build to emit a consolidated warning. | ||
|
|
||
| Args: | ||
| state: The solution state dict (parameter name -> value). | ||
| srcFile: The YAML source file path, included in warning messages. | ||
| """ | ||
| # Input parameters to check - fail immediately on type mismatch | ||
| inputParamToCheck = { | ||
| # Integer parameters | ||
| "UseBias", "UseScaleAlphaVec", "MXBlockA", "MXBlockB", | ||
| # Boolean parameters | ||
| "UseBeta", "UseE", "Gradient", "UseScaleCD", | ||
| "HighPrecisionAccumulate", "SilentHighPrecisionAccumulate", | ||
| "StridedBatched", "GroupedGemm", | ||
| "UseInitialStridesAB", "UseInitialStridesCD", | ||
| "AllowNoFreeDims", "TileAwareSelection" | ||
| } | ||
|
|
||
| for key, value in state.items(): | ||
| if key not in _expectedParamTypes or key in _skipTypeCheck: | ||
| continue | ||
|
|
@@ -125,6 +139,146 @@ def validateParameterTypes(state, srcFile=""): | |
| # Use type() not isinstance() so that bool and int are distinguished | ||
| if actualType not in expectedTypes: | ||
| expectedStr = " or ".join(sorted(t.__name__ for t in expectedTypes)) | ||
|
|
||
| # For input parameters to check, exit immediately with clear error | ||
| if key in inputParamToCheck: | ||
| errorMsg = [ | ||
| "", | ||
| "=" * 80, | ||
| "ERROR: Invalid type for parameter '{}'".format(key), | ||
| ] | ||
| if srcFile: | ||
| errorMsg.append("File: {}".format(srcFile)) | ||
| errorMsg.extend([ | ||
| "=" * 80, | ||
| " Expected type: {}".format(expectedStr), | ||
| " Actual type: {}".format(actualType.__name__), | ||
| " Value: {}".format(repr(value)), | ||
| "", | ||
| ]) | ||
| if key == "UseBias": | ||
| errorMsg.extend([ | ||
| " UseBias must be an integer (0, 1, 2, or 3):", | ||
| " 0 = no bias", | ||
| " 1 = bias vector on M direction", | ||
| " 2 = bias vector on N direction", | ||
| " 3 = bias vector on both M and N directions", | ||
| "", | ||
| ]) | ||
| elif key == "UseScaleAlphaVec": | ||
| errorMsg.extend([ | ||
| " UseScaleAlphaVec must be an integer (0, 1, 2, or 3):", | ||
| " 0 = disabled", | ||
| " 1 = alpha vector on M direction", | ||
|
Comment on lines
+142
to
+172
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. We cannot have this kind of code here. |
||
| " 2 = alpha vector on N direction", | ||
| " 3 = alpha vector on both M and N directions", | ||
| "", | ||
| ]) | ||
| elif key == "MXBlockA": | ||
| errorMsg.extend([ | ||
| " MXBlockA must be an integer (0, 16, or 32):", | ||
| " 0 = disabled", | ||
| " 16, 32 = MX block sizes for matrix A", | ||
| "", | ||
| ]) | ||
| elif key == "MXBlockB": | ||
| errorMsg.extend([ | ||
| " MXBlockB must be an integer (0, 16, or 32):", | ||
| " 0 = disabled", | ||
| " 16, 32 = MX block sizes for matrix B", | ||
| "", | ||
| ]) | ||
| elif key == "UseBeta": | ||
| errorMsg.extend([ | ||
| " UseBeta must be a boolean (False or True):", | ||
| " False = beta is not used (beta = 0)", | ||
| " True = beta is used in GEMM (D = alpha*A*B + beta*C)", | ||
| "", | ||
| ]) | ||
| elif key == "UseE": | ||
| errorMsg.extend([ | ||
| " UseE must be a boolean (False or True):", | ||
| " False = auxiliary output matrix E is not used", | ||
| " True = auxiliary output matrix E is used", | ||
| "", | ||
| ]) | ||
| elif key == "Gradient": | ||
| errorMsg.extend([ | ||
| " Gradient must be a boolean (False or True):", | ||
| " False = gradient computation is not used", | ||
| " True = gradient computation is used", | ||
| "", | ||
| ]) | ||
| elif key == "UseScaleCD": | ||
| errorMsg.extend([ | ||
| " UseScaleCD must be a boolean (False or True):", | ||
| " False = scaling for C and D matrices is not used", | ||
| " True = scaling for C and D matrices is used", | ||
| "", | ||
| ]) | ||
| elif key == "HighPrecisionAccumulate": | ||
| errorMsg.extend([ | ||
| " HighPrecisionAccumulate must be a boolean (False or True):", | ||
| " False = standard precision accumulation", | ||
| " True = high precision accumulation", | ||
| "", | ||
| ]) | ||
| elif key == "SilentHighPrecisionAccumulate": | ||
| errorMsg.extend([ | ||
| " SilentHighPrecisionAccumulate must be a boolean (False or True):", | ||
| " False = warnings enabled for precision downcasts", | ||
| " True = silent mode for precision downcasts", | ||
| "", | ||
| ]) | ||
| elif key == "StridedBatched": | ||
| errorMsg.extend([ | ||
| " StridedBatched must be a boolean (False or True):", | ||
| " False = not strided batched", | ||
| " True = strided batched GEMM", | ||
| "", | ||
| ]) | ||
| elif key == "GroupedGemm": | ||
| errorMsg.extend([ | ||
| " GroupedGemm must be a boolean (False or True):", | ||
| " False = not grouped GEMM", | ||
| " True = grouped GEMM", | ||
| "", | ||
| ]) | ||
| elif key == "UseInitialStridesAB": | ||
| errorMsg.extend([ | ||
| " UseInitialStridesAB must be a boolean (False or True):", | ||
| " False = do not use initial strides for A and B matrices", | ||
| " True = use initial strides for A and B matrices", | ||
| "", | ||
| ]) | ||
| elif key == "UseInitialStridesCD": | ||
| errorMsg.extend([ | ||
| " UseInitialStridesCD must be a boolean (False or True):", | ||
| " False = do not use initial strides for C and D matrices", | ||
| " True = use initial strides for C and D matrices", | ||
| "", | ||
| ]) | ||
| elif key == "AllowNoFreeDims": | ||
| errorMsg.extend([ | ||
| " AllowNoFreeDims must be a boolean (False or True):", | ||
| " False = do not allow problems with no free dimensions", | ||
| " True = allow problems with no free dimensions", | ||
| "", | ||
| ]) | ||
| elif key == "TileAwareSelection": | ||
| errorMsg.extend([ | ||
| " TileAwareSelection must be a boolean (False or True):", | ||
| " False = disable tile-aware solution selection", | ||
| " True = enable tile-aware solution selection", | ||
| "", | ||
| ]) | ||
| errorMsg.extend([ | ||
| "Please fix the YAML configuration file.", | ||
| "=" * 80, | ||
| ]) | ||
| printExit("\n".join(errorMsg)) | ||
|
|
||
| # For other parameters, collect for summary | ||
| collectorKey = (key, actualType.__name__, expectedStr) | ||
| if collectorKey not in _typeMismatchCollector: | ||
| _typeMismatchCollector[collectorKey] = { | ||
|
|
@@ -316,6 +470,8 @@ def __init__( | |
| self._state = {} | ||
| # problem type | ||
| if "ProblemType" in config: | ||
| # Validate ProblemType parameters before constructing the ProblemType object | ||
| validateParameterTypes(config["ProblemType"], srcFile=srcName) | ||
| self["ProblemType"] = ProblemType(config["ProblemType"], printIndexAssignmentInfo) | ||
| else: | ||
| self["ProblemType"] = ProblemType.FromDefaultConfig(printIndexAssignmentInfo) | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry this doesn't seem correct to me that this fix requires adding all of these values to ValidParameters.py.