Skip to content

Commit

Permalink
updat
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriPlyakhin committed Jan 31, 2025
1 parent 0ffd40a commit d6e6847
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 54 deletions.
7 changes: 7 additions & 0 deletions sycl/test-e2e/Matrix/Inputs/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,10 @@ void matrix_print(unsigned int rows, unsigned int cols, T *mat) {
std::cout << "\n";
}
}

template <typename T, layout Layout> constexpr int vnni_factor() {
if constexpr (Layout != layout::ext_intel_packed)
return 1;
static_assert(sizeof(T) <= 4 && "Unsupported type in vnni_factor().");
return 4 / sizeof(T);
}
45 changes: 36 additions & 9 deletions sycl/test-e2e/Matrix/Inputs/joint_matrix_out_bounds_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,20 @@
#include <iostream>
#include <sycl/usm.hpp>

template <typename Tab, size_t K, layout B_layout, unsigned int vnniFactor>
class mult;
template <typename Tab, size_t K, layout B_layout> class mult;

template <typename T1, typename T2, size_t M, size_t N, size_t K, size_t TM,
size_t TN, size_t TK, layout A_layout, layout B_layout,
unsigned int vnniFactor>
size_t TN, size_t TK, layout A_layout, layout B_layout>
void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {

// Add one iteration for the out of bounds dpas instruction
size_t NDRangeM = M / TM + (((M % TM) != 0) ? 1 : 0);
size_t NDRangeN = N / TN;
size_t sg_size = get_sg_size<mult<T2, K, B_layout, vnniFactor>>(q);
size_t sg_size = get_sg_size<mult<T2, K, B_layout>>(q);
std::cout << "SG size: " << sg_size << " ";

q.submit([&](handler &cgh) {
cgh.parallel_for<mult<T2, K, B_layout, vnniFactor>>(
cgh.parallel_for<mult<T2, K, B_layout>>(
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
[=](nd_item<2> spmd_item)
#ifdef SG_SZ
Expand Down Expand Up @@ -72,6 +70,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
// bounds-checked load where width and height are added
// params order: Stride, Height, Width, CoordX, CoordY
if constexpr (B_layout != layout::col_major) {
constexpr unsigned int vnniFactor = vnni_factor<T2, B_layout>();
ext::intel::experimental::matrix::joint_matrix_load_checked(
sg, sub_b, pB, N * vnniFactor, K / vnniFactor,
N * vnniFactor, k / vnniFactor,
Expand All @@ -94,7 +93,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {

template <typename Tab, typename Tc, size_t MATRIX_M, size_t MATRIX_N,
size_t MATRIX_K, size_t TM, size_t TN, size_t TK, layout A_layout,
layout B_layout, unsigned int vnniFactor>
layout B_layout>
void test() {
std::cout << MATRIX_M << "x" << MATRIX_N << "x" << MATRIX_K << ", " << TM
<< "x" << TN << "x" << TK << ": ";
Expand Down Expand Up @@ -129,13 +128,14 @@ void test() {

if constexpr (B_layout == layout::ext_intel_packed) {
Tab *vnniB = malloc_shared<Tab>(MATRIX_K * MATRIX_N, q);
matrix_vnni(MATRIX_K, MATRIX_N, B, vnniB, vnniFactor);
matrix_vnni(MATRIX_K, MATRIX_N, B, vnniB, vnni_factor<Tab, B_layout>());
Tab *tmp = B;
B = vnniB;
free(tmp, q);
}

matrix_multiply<Tc, Tab, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK, A_layout, B_layout, vnniFactor>(C, A, B, q);
matrix_multiply<Tc, Tab, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK, A_layout,
B_layout>(C, A, B, q);
assert(matrix_compare(MATRIX_M, MATRIX_N, C, D));
std::cout << "passed" << std::endl;

Expand All @@ -144,3 +144,30 @@ void test() {
free(C, q);
free(D, q);
}

template <layout A_layout, layout B_layout> void test_all() {
std::cout << "bf16: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16, A_layout,
B_layout>();
std::cout << "half: ";
test<half, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16, A_layout,
B_layout>();
std::cout << "int8: ";
test<int8_t, int32_t, 1024 + 14, 1024, 1024 + 24, 8, 16, 32, A_layout,
B_layout>();

// unaligned k:
std::cout << "bf16: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16, A_layout,
B_layout>();
std::cout << "half: ";
test<half, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16, A_layout,
B_layout>();

// row major A fails, so disabled. CMPLRLLVM-65239
if constexpr (A_layout != layout::row_major) {
std::cout << "int8: ";
test<int8_t, int32_t, 1024 + 14, 1024, 1024 + 14, 8, 16, 32, A_layout,
B_layout>();
}
}
15 changes: 4 additions & 11 deletions sycl/test-e2e/Matrix/SG32/joint_matrix_out_bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,8 @@
#include "joint_matrix_out_bounds_impl.hpp"

int main() {
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::row_major, layout::row_major, 1>();
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::row_major, layout::ext_intel_packed, 2>();

// unaligned k:
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::row_major, layout::row_major, 1>();
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::row_major, layout::ext_intel_packed, 2>();

std::cout << "A row major, B row major:\n";
test_all<layout::row_major, layout::row_major>();
std::cout << "A row major, B packed:\n";
test_all<layout::row_major, layout::ext_intel_packed>();
}
18 changes: 4 additions & 14 deletions sycl/test-e2e/Matrix/joint_matrix_out_bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,8 @@
#include "joint_matrix_out_bounds_impl.hpp"

int main() {
std::cout << "bf16 A row major, B row major: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::row_major, layout::row_major, 1>();
std::cout << "bf16 A row major, B packed: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::row_major, layout::ext_intel_packed, 2>();

// unaligned k:
std::cout << "bf16 A row major, B row major: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::row_major, layout::row_major, 1>();
std::cout << "bf16 A row major, B packed: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::row_major, layout::ext_intel_packed, 2>();
std::cout << "A row major, B row major:\n";
test_all<layout::row_major, layout::row_major>();
std::cout << "A row major, B packed:\n";
test_all<layout::row_major, layout::ext_intel_packed>();
}
22 changes: 2 additions & 20 deletions sycl/test-e2e/Matrix/joint_matrix_out_bounds_colmajor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,6 @@
#include "joint_matrix_out_bounds_impl.hpp"

int main() {
std::cout << "bf16 A col major, B col major: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::col_major, layout::col_major, 1>();
std::cout << "half A col major, B col major: ";
test<half, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16, layout::col_major,
layout::col_major, 1>();
std::cout << "int8 A col major, B col major: ";
test<int8_t, int32_t, 1024 + 14, 1024, 1024 + 24, 8, 16, 32,
layout::col_major, layout::col_major, 1>();

// unaligned k:
std::cout << "bf16 A col major, B col major: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::col_major, layout::col_major, 1>();
std::cout << "half A col major, B col major: ";
test<half, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16, layout::col_major,
layout::col_major, 1>();
std::cout << "int8 A col major, B col major: ";
test<int8_t, int32_t, 1024 + 14, 1024, 1024 + 14, 8, 16, 32,
layout::col_major, layout::col_major, 1>();
std::cout << "A col major, B col major:\n";
test_all<layout::col_major, layout::col_major>();
}

0 comments on commit d6e6847

Please sign in to comment.