Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions projects/hipblaslt/library/include/hipblaslt/hipblaslt.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@
#define HIPBLASLT_OPERATION_INVALID static_cast<hipblasOperation_t>(0)
#define ROCBLASLT_COMPUTE_TYPE_INVALID static_cast<rocblaslt_compute_type>(255)

#define HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT \
static_assert(false, "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT is deprecated and not supported. Please set HIPBLASLT_MATMUL_DESC_A_SCALE_MODE as HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F instead.")
#define HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT \
static_assert(false, "HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT is deprecated and not supported. Please set HIPBLASLT_MATMUL_DESC_B_SCALE_MODE as HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F instead.")

/*! \ingroup types_module
* \brief Specify the enum type to set the postprocessing options for the epilogue.
*/
Expand Down Expand Up @@ -187,8 +192,6 @@ typedef enum {
HIPBLASLT_MATMUL_DESC_B_SCALE_MODE = 32, /**<Scaling mode that defines how the matrix scaling factor for matrix B is interpreted. See hipblasLtMatmulMatrixScale_t */
HIPBLASLT_MATMUL_DESC_COMPUTE_INPUT_TYPE_A_EXT = 100, /**<Compute input A types. Defines the data type used for the input A of matrix multiply. */
HIPBLASLT_MATMUL_DESC_COMPUTE_INPUT_TYPE_B_EXT, /**<Compute input B types. Defines the data type used for the input B of matrix multiply. */
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, /**<Equivalent to HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER but in vector. Default value: NULL Type: void* /const void* */
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, /**<Equivalent to HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER but in vector. Default value: NULL Type: void* /const void* */
HIPBLASLT_MATMUL_DESC_MAX,
} hipblasLtMatmulDescAttributes_t;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,6 @@ typedef enum rocblaslt_matmul_desc_attributes_
ROCBLASLT_MATMUL_DESC_B_SCALE_MODE = 32,
ROCBLASLT_MATMUL_DESC_COMPUTE_INPUT_TYPE_A_EXT = 100,
ROCBLASLT_MATMUL_DESC_COMPUTE_INPUT_TYPE_B_EXT,
ROCBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT,
ROCBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT,
ROCBLASLT_MATMUL_DESC_MAX,
} rocblaslt_matmul_desc_attributes;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1035,11 +1035,8 @@ rocblaslt_status rocblaslt_matmul_desc_set_attribute(rocblaslt_matmul_desc
return rocblaslt_status_invalid_value;
}
break;
case ROCBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT:
matmulDesc->scaleAType = RocblasltContractionProblem::ScalingFormat::Vector;
case ROCBLASLT_MATMUL_DESC_A_SCALE_POINTER:
if(matmulAttr == ROCBLASLT_MATMUL_DESC_A_SCALE_POINTER
&& matmulDesc->scaleAType == RocblasltContractionProblem::ScalingFormat::None)
if(matmulDesc->scaleAType == RocblasltContractionProblem::ScalingFormat::None)
{
matmulDesc->scaleAType = RocblasltContractionProblem::ScalingFormat::Scalar;
}
Expand Down Expand Up @@ -1087,11 +1084,8 @@ rocblaslt_status rocblaslt_matmul_desc_set_attribute(rocblaslt_matmul_desc
return rocblaslt_status_invalid_value;
}
break;
case ROCBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT:
matmulDesc->scaleBType = RocblasltContractionProblem::ScalingFormat::Vector;
case ROCBLASLT_MATMUL_DESC_B_SCALE_POINTER:
if(matmulAttr == ROCBLASLT_MATMUL_DESC_B_SCALE_POINTER
&& matmulDesc->scaleBType == RocblasltContractionProblem::ScalingFormat::None)
if(matmulDesc->scaleBType == RocblasltContractionProblem::ScalingFormat::None)
{
matmulDesc->scaleBType = RocblasltContractionProblem::ScalingFormat::Scalar;
}
Expand Down Expand Up @@ -1351,7 +1345,6 @@ rocblaslt_status rocblaslt_matmul_desc_get_attribute(rocblaslt_matmul_desc
memcpy(buf, &matmulDesc->bias, sizeof(void*));
break;
case ROCBLASLT_MATMUL_DESC_A_SCALE_POINTER:
case ROCBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT:
if(sizeWritten)
*sizeWritten = sizeof(void*);
if(sizeInBytes < sizeof(void*))
Expand Down Expand Up @@ -1397,7 +1390,6 @@ rocblaslt_status rocblaslt_matmul_desc_get_attribute(rocblaslt_matmul_desc
}
break;
case ROCBLASLT_MATMUL_DESC_B_SCALE_POINTER:
case ROCBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT:
if(sizeWritten)
*sizeWritten = sizeof(void*);
if(sizeInBytes < sizeof(void*))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,6 @@ const char* rocblaslt_matmul_desc_attributes_to_string(rocblaslt_matmul_desc_att
return "MATMUL_DESC_AMAX_D_POINTER";
case ROCBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE:
return "MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE";
case ROCBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT:
return "MATMUL_DESC_A_SCALE_POINTER_VEC";
case ROCBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT:
return "MATMUL_DESC_B_SCALE_POINTER_VEC";
case ROCBLASLT_MATMUL_DESC_COMPUTE_INPUT_TYPE_A_EXT:
return "MATMUL_DESC_COMPUTE_INPUT_TYPE_A_EXT";
case ROCBLASLT_MATMUL_DESC_COMPUTE_INPUT_TYPE_B_EXT:
Expand Down
Loading