Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,18 +188,18 @@ Status ComputeMatMul(ComputeContext* context,

TensorShape output_shape = helper.OutputShape();

const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2];
// check if A is a batch of vectors (batch is not 1, M is 1) and B is a matrix (batch is 1)
if (batchA != 1 && dim_output_outer == 1 && batchB == 1) {
// optimization for batched vector matrix multiplication
// dimensions of A: [1,`batchA`,K]
TensorShapeVector dims_a = {1, batchA, helper.K()};
// When B is a matrix (batch is 1), we fold batchA into the M dimension for better
// performance (e.g., [2,3,5] → [1,6,5]).
if (batchA != 1 && batchB == 1) {
// dimensions of A after folding batch into M: [1, `batchA`*M, K]
Comment thread
guschmue marked this conversation as resolved.
int64_t batchAndM = a_shape.SizeToDimension(a_shape.NumDimensions() - 1);
TensorShapeVector dims_a = {1, batchAndM, helper.K()};
// dimensions of B: [1,K,N]
TensorShapeVector dims_b = {1, helper.K(), helper.N()};

a_shape = TensorShape(dims_a);
b_shape = TensorShape(dims_b);
output_shape = {1, batchA, helper.N()};
output_shape = {1, batchAndM, helper.N()};
}

// helpful dimension variables
Expand Down
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ Status ApplyMatMulIntel(ComputeContext& context,

TensorShape output_shape = helper.OutputShape();

const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2];
// check if A is a batch of vectors (batch is not 1, M is 1) and B is a matrix (batch is 1)
if (batchA != 1 && dim_output_outer == 1 && batchB == 1) {
// optimization for batched vector matrix multiplication
// dimensions of A: [1,`batchA`,K]
TensorShapeVector dims_a = {1, batchA, helper.K()};
// When B is a matrix (batch is 1), we fold batchA into the M dimension for better
// performance (e.g., [2,3,5] → [1,6,5]).
if (batchA != 1 && batchB == 1) {
// dimensions of A after folding batch into M: [1, `batchA`*M, K]
int64_t batchAndM = a_shape.SizeToDimension(a_shape.NumDimensions() - 1);
TensorShapeVector dims_a = {1, batchAndM, helper.K()};
// dimensions of B: [1,K,N]
TensorShapeVector dims_b = {1, helper.K(), helper.N()};

a_shape = TensorShape(dims_a);
b_shape = TensorShape(dims_b);
output_shape = {1, batchA, helper.N()};
output_shape = {1, batchAndM, helper.N()};
}

// helpful dimension variables
Expand Down
32 changes: 32 additions & 0 deletions onnxruntime/test/providers/cpu/math/matmul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,38 @@ std::vector<MatMulTestData<T>> GenerateTestCases() {
// clang-format on
})});

test_cases.push_back(
{"test 3D tensors with batchA = 3, M = 2, N = 3",
{3, 2, 8},
{1, 8, 3},
{3, 2, 3},
real_expected_vals({
// clang-format off
420, 448, 476,
1092, 1184, 1276,
1764, 1920, 2076,
2436, 2656, 2876,
3108, 3392, 3676,
3780, 4128, 4476,
// clang-format on
})});

test_cases.push_back(
{"test 3D tensors with batchA = 3, M = 2, N = 4",
{3, 2, 8},
{1, 8, 4},
{3, 2, 4},
real_expected_vals({
// clang-format off
560, 588, 616, 644,
1456, 1548, 1640, 1732,
2352, 2508, 2664, 2820,
3248, 3468, 3688, 3908,
4144, 4428, 4712, 4996,
5040, 5388, 5736, 6084,
// clang-format on
})});

test_cases.push_back(
{"test 4D tensors with M = 1",
{2, 3, 1, 8},
Expand Down
Loading