diff --git a/library/src/blas3/Tensile/gemm.hpp b/library/src/blas3/Tensile/gemm.hpp index 696afc7e6..ccfc024ca 100644 --- a/library/src/blas3/Tensile/gemm.hpp +++ b/library/src/blas3/Tensile/gemm.hpp @@ -383,24 +383,6 @@ hipError_t call_tensile(const T* alpha, rocblas_int sizeL, rocblas_handle handle) { - // Currently alpha and beta can only be single values as - // tensile does not support arrays for scalars yet. -#ifndef NDEBUG - std::cout << "Solution Name: " - << tensileGetSolutionName(trans_a, - trans_b, - strideC1, - strideC2, - strideA1, - strideA2, - strideB1, - strideB2, - sizeI, - sizeJ, - sizeK, - sizeL) - << std::endl; -#endif // Collect alpha / beta (either from host or device). // Tensile doesn't support arrays of scalars for now, so we must handle diff --git a/library/src/tensile_host.cpp b/library/src/tensile_host.cpp index 495abf083..defec3590 100644 --- a/library/src/tensile_host.cpp +++ b/library/src/tensile_host.cpp @@ -61,52 +61,48 @@ auto create_gemm_contraction_problem_strided_batched(rocblas_operation trans_a, { auto transposeA = trans_a != rocblas_operation_none; auto transposeB = trans_b != rocblas_operation_none; + auto dt = tensile_datatype; - auto dt = tensile_datatype; + Tensile::ContractionProblem::FreeIndices free{2}; + Tensile::ContractionProblem::BoundIndices bound{1}; + Tensile::ContractionProblem::BatchIndices batch{1}; - Tensile::ContractionProblem::FreeIndices free(2); - Tensile::ContractionProblem::BoundIndices bound(1); - Tensile::ContractionProblem::BatchIndices batch(1); - - free[0].isA=true; + free[0].isA = true; free[0].i = free[0].c = free[0].d = 0; - free[1].isA=false; + free[1].isA = false; free[1].i = free[1].c = free[1].d = 1; batch[0].a = batch[0].b = batch[0].c = batch[0].d = 2; - Tensile::TensorDescriptor a, b, c, d; + Tensile::TensorDescriptor a, b; if(transposeA) { - a = Tensile::TensorDescriptor(dt, {k, m, batchSize}, {1, ld_a, stride_a}); - free[0].i = 1; + a = {dt, {k, m, batchSize}, {1, ld_a, stride_a}}; + free[0].i = 1; bound[0].a = 0; } else { - a = Tensile::TensorDescriptor(dt, {m, k, batchSize}, {1, ld_a, stride_a}); - free[0].i = 0; + a = {dt, {m, k, batchSize}, {1, ld_a, stride_a}}; + free[0].i = 0; bound[0].a = 1; } if(transposeB) { - b = Tensile::TensorDescriptor(dt, {n, k, batchSize}, {1, ld_b, stride_b}); - free[1].i = 0; + b = {dt, {n, k, batchSize}, {1, ld_b, stride_b}}; + free[1].i = 0; bound[0].b = 1; } else { - b = Tensile::TensorDescriptor(dt, {k, n, batchSize}, {1, ld_b, stride_b}); - free[1].i = 1; + b = {dt, {k, n, batchSize}, {1, ld_b, stride_b}}; + free[1].i = 1; bound[0].b = 0; } - c = Tensile::TensorDescriptor(dt, {m, n, batchSize}, {1, ld_c, stride_c}); - d = Tensile::TensorDescriptor(dt, {m, n, batchSize}, {1, ld_c, stride_c}); - - Tensile::TensorOps nop; + Tensile::TensorDescriptor c{dt, {m, n, batchSize}, {1, ld_c, stride_c}}; Tensile::TensorOps aops; if(is_complex && trans_a == rocblas_operation_conjugate_transpose) @@ -116,8 +112,8 @@ auto create_gemm_contraction_problem_strided_batched(rocblas_operation trans_a, if(is_complex && trans_b == rocblas_operation_conjugate_transpose) bops = {Tensile::TensorOp::Type::ComplexConjugate}; - Tensile::ContractionProblem problem(a, aops, b, bops, c, nop, d, nop, - free, batch, bound, value_category(beta)); + Tensile::ContractionProblem problem{ + a, aops, b, bops, c, {}, c, {}, free, batch, bound, value_category(beta)}; return problem; } @@ -140,6 +136,7 @@ auto create_gemm_contraction_problem(rocblas_operation trans_a, { auto transposeA = trans_a != rocblas_operation_none; auto transposeB = trans_b != rocblas_operation_none; + auto dt = tensile_datatype; Tensile::ContractionProblem::FreeIndices freeIndex(2); Tensile::ContractionProblem::BoundIndex boundIndex; @@ -150,7 +147,6 @@ auto create_gemm_contraction_problem(rocblas_operation trans_a, freeIndex[1].i = freeIndex[1].c = freeIndex[1].d = 1; Tensile::TensorDescriptor a, b; - auto dt = tensile_datatype; if(transposeA) {