gpu: nvidia: Add support for cublaslt matmul #1972

Merged 1 commit on Sep 19, 2024
48 changes: 48 additions & 0 deletions cmake/FindcublasLt.cmake
@@ -0,0 +1,48 @@
# ===============================================================================
# Copyright 2020-2024 Intel Corporation
# Copyright 2020-2024 Codeplay Software Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# ===============================================================================

find_package(CUDA 10.0 REQUIRED)
find_package(Threads REQUIRED)

find_path(
CUBLASLT_INCLUDE_DIR "cublasLt.h"
HINTS ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES include)

find_library(CUDA_DRIVER_LIBRARY cuda)

find_library(
CUBLAS_LIBRARY cublasLt
HINTS ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 bin)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
cublasLt REQUIRED_VARS CUBLASLT_INCLUDE_DIR CUDA_INCLUDE_DIRS CUBLAS_LIBRARY
CUDA_LIBRARIES CUDA_DRIVER_LIBRARY)

if(NOT TARGET cublasLt::cublasLt)
add_library(cublasLt::cublasLt SHARED IMPORTED)
set_target_properties(
cublasLt::cublasLt
PROPERTIES IMPORTED_LOCATION ${CUBLAS_LIBRARY}
INTERFACE_INCLUDE_DIRECTORIES
"${CUBLASLT_INCLUDE_DIR};${CUDA_INCLUDE_DIRS}"
INTERFACE_LINK_LIBRARIES
"Threads::Threads;${CUDA_DRIVER_LIBRARY};${CUDA_LIBRARIES}"
INTERFACE_COMPILE_DEFINITIONS CUDA_NO_HALF)
endif()
3 changes: 2 additions & 1 deletion cmake/SYCL.cmake
@@ -88,9 +88,10 @@ endmacro()
if(DNNL_SYCL_CUDA)
suppress_warnings_for_nvidia_target()
find_package(cuBLAS REQUIRED)
+ find_package(cublasLt REQUIRED)
find_package(cuDNN REQUIRED)

adjust_headers_priority("cuBLAS::cuBLAS;cuDNN::cuDNN")
adjust_headers_priority("cuBLAS::cuBLAS;cuDNN::cuDNN;cublasLt::cublasLt")
add_definitions_with_host_compiler("-DCUDA_NO_HALF")

list(APPEND EXTRA_SHARED_LIBS cuBLAS::cuBLAS cuDNN::cuDNN)
1 change: 1 addition & 0 deletions include/oneapi/dnnl/dnnl.hpp
@@ -1217,6 +1217,7 @@ struct memory : public handle<dnnl_memory_t> {
AB16b64a2b = dnnl_AB16b64a2b,
Ab4a = dnnl_Ab4a,
Ab8a = dnnl_Ab8a,
Ab32a = dnnl_Ab32a,
Abc16a = dnnl_Abc16a,
ABc16a16b = dnnl_ABc16a16b,
ABc4a4b = dnnl_ABc4a4b,
1 change: 1 addition & 0 deletions include/oneapi/dnnl/dnnl_types.h
@@ -1037,6 +1037,7 @@ typedef enum {
dnnl_bcad,
dnnl_cabd,
dnnl_dabc,
dnnl_Ab32a,

/// Just a sentinel, not real memory format tag. Must be changed after new
/// format tag is added.
2 changes: 2 additions & 0 deletions src/common/c_types_map.hpp
@@ -233,6 +233,7 @@ const format_kind_t sparse = static_cast<format_kind_t>(4);
const format_kind_t internal_only_start = (format_kind_t)(1 << 8);
const format_kind_t wino = internal_only_start;
const format_kind_t rnn_packed = (format_kind_t)(internal_only_start + 1);
const format_kind_t cublaslt_blocked = (format_kind_t)(internal_only_start + 2);
} // namespace format_kind

#ifdef DNNL_EXPERIMENTAL_PROFILING
@@ -372,6 +373,7 @@ const format_tag_t aCB16b64c4b = dnnl_aCB16b64c4b;

const format_tag_t Ab4a = dnnl_Ab4a;
const format_tag_t Ab8a = dnnl_Ab8a;
const format_tag_t Ab32a = dnnl_Ab32a;
const format_tag_t Abc16a = dnnl_Abc16a;
const format_tag_t ABc16a16b = dnnl_ABc16a16b;
const format_tag_t ABc4a2b = dnnl_ABc4a2b;
4 changes: 3 additions & 1 deletion src/common/dnnl_debug.cpp
@@ -49,7 +49,9 @@ const char *dnnl_fmt_kind2str(dnnl_format_kind_t v) {
#ifdef DNNL_EXPERIMENTAL_SPARSE
if (v == dnnl_format_kind_sparse) return "sparse";
#endif
- if (v == format_kind::wino || v == format_kind::rnn_packed) return "opaque";
+ if (v == format_kind::wino || v == format_kind::rnn_packed
+ || v == format_kind::cublaslt_blocked)
+ return "opaque";
if (v == dnnl_format_kind_max) return "max";
assert(!"unknown fmt_kind");
return "unknown fmt_kind";
1 change: 1 addition & 0 deletions src/common/dnnl_debug_autogenerated.cpp
@@ -927,6 +927,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_AcdeB4b8a4b) return "AcdeB4b8a4b";
if (v == dnnl_Ab4a) return "Ab4a";
if (v == dnnl_Ab8a) return "Ab8a";
if (v == dnnl_Ab32a) return "Ab32a";
if (v == dnnl_BA4b4a) return "BA4b4a";
if (v == dnnl_BA8b4a) return "BA8b4a";
if (v == dnnl_BA2a24b) return "BA2a24b";
1 change: 1 addition & 0 deletions src/common/memory_desc.cpp
@@ -722,6 +722,7 @@ status_t dnnl_memory_desc_query(
case query::format_kind:
switch ((int)md->format_kind) {
case format_kind::rnn_packed:
case format_kind::cublaslt_blocked:
case format_kind::wino:
*(format_kind_t *)result = format_kind::opaque;
break;
11 changes: 10 additions & 1 deletion src/common/memory_desc.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
- * Copyright 2022-2023 Intel Corporation
+ * Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,6 +23,8 @@
namespace dnnl {
namespace impl {

enum class cublaslt_memory_format_t { col32_2r_4r4 };

// Winograd-specific formats
enum class wino_memory_format_t {
// Undefined memory format, used for empty memory descriptors.
@@ -135,6 +137,11 @@ struct rnn_packed_desc_t {
size_t size;
};

struct cublaslt_blocked_desc_t {
cublaslt_memory_format_t cublaslt_format;
size_t size;
};

struct sparse_desc_t {
static constexpr int max_metadata_types = 2;
// Each encoding defines the number of handles it requires and their
@@ -289,6 +296,8 @@ struct dnnl_memory_desc : public dnnl::impl::c_compatible {
dnnl::impl::wino_desc_t wino_desc;
// Tensor of packed weights for RNN.
dnnl::impl::rnn_packed_desc_t rnn_packed_desc;
// Description of the data layout for memory formats used in cublasLt IMMA kernels.
dnnl::impl::cublaslt_blocked_desc_t cublaslt_blocked_desc;
// Description of the sparse encodings.
dnnl::impl::sparse_desc_t sparse_desc;
// ... other descriptions possible
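Note: the new `cublaslt_memory_format_t::col32_2r_4r4` mirrors cuBLASLt's IMMA-oriented `CUBLASLT_ORDER_COL32_2R_4R4` ordering. As a sketch of what that ordering means at the API level (the helper name, dimensions, data type, and leading-dimension choice below are illustrative assumptions, not code from this PR), such a layout would be requested roughly like so:

#include <cstdint>
#include <cublasLt.h>

// Illustrative sketch only (not part of this PR): requesting the
// IMMA-friendly COL32_2R_4R4 ordering for an int8 matrix through the
// public cuBLASLt API. The leading dimension follows NVIDIA's IGEMM
// sample, which uses 32 * (rows rounded up to a multiple of 32) for
// this order; error handling is omitted.
cublasLtMatrixLayout_t make_col32_2r_4r4_layout(uint64_t rows, uint64_t cols) {
    cublasLtMatrixLayout_t layout = nullptr;
    const uint64_t rows32 = (rows + 31) / 32 * 32;
    cublasLtMatrixLayoutCreate(&layout, CUDA_R_8I, rows, cols,
            static_cast<int64_t>(32 * rows32));
    cublasLtOrder_t order = CUBLASLT_ORDER_COL32_2R_4R4;
    cublasLtMatrixLayoutSetAttribute(
            layout, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order));
    return layout;
}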
1 change: 1 addition & 0 deletions src/common/memory_desc_wrapper.cpp
@@ -192,6 +192,7 @@ status_t memory_desc_wrapper::compute_blocking(

C(Ab4a, {0, 1}, {4}, {0});
C(Ab8a, {0, 1}, {8}, {0});
C(Ab32a, {0, 1}, {32}, {0});

C(BA4b4a, {1, 0}, {4, 4}, {1, 0});
C(BA8b4a, {1, 0}, {8, 4}, {1, 0});
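Note: the new `C(Ab32a, {0, 1}, {32}, {0})` entry describes a 2D layout whose outer `a` dimension is split into 32-wide blocks, with `a % 32` as the innermost index. A minimal sketch of the implied element offset, assuming a logical A x B tensor with no padding (the function name is illustrative):

#include <cstdint>

// Minimal sketch (name illustrative): element offset implied by the Ab32a
// tag for a logical A x B tensor without padding. The outer 'a' dimension
// is blocked by 32; a % 32 is the innermost, fastest-varying index.
int64_t ab32a_offset(int64_t a, int64_t b, int64_t B) {
    return (a / 32) * B * 32 + b * 32 + (a % 32);
}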
15 changes: 13 additions & 2 deletions src/common/memory_desc_wrapper.hpp
@@ -67,6 +67,9 @@ struct memory_desc_wrapper : public c_compatible {
bool is_rnn_packed_desc() const {
return format_kind() == format_kind::rnn_packed;
}
bool is_cublaslt_blocked_desc() const {
return format_kind() == format_kind::cublaslt_blocked;
}
bool is_sparse_desc() const { return format_kind() == format_kind::sparse; }

const blocking_desc_t &blocking_desc() const {
@@ -82,6 +85,10 @@
assert(is_rnn_packed_desc());
return md_->format_desc.rnn_packed_desc;
}
const cublaslt_blocked_desc_t &cublaslt_blocked_desc() const {
assert(is_cublaslt_blocked_desc());
return md_->format_desc.cublaslt_blocked_desc;
}

const sparse_desc_t &sparse_desc() const {
assert(is_sparse_desc());
@@ -224,7 +231,8 @@ struct memory_desc_wrapper : public c_compatible {
return 0;

if (utils::one_of(format_kind(), format_kind::blocked,
- format_kind::wino, format_kind::rnn_packed)
+ format_kind::wino, format_kind::rnn_packed,
+ format_kind::cublaslt_blocked)
&& index != 0) {
return 0;
}
@@ -235,6 +243,8 @@
return wino_desc().size;
} else if (is_rnn_packed_desc()) {
return rnn_packed_desc().size;
} else if (is_cublaslt_blocked_desc()) {
return cublaslt_blocked_desc().size;
} else if (is_blocking_desc()) {
if (offset0() != 0) return 0;

@@ -581,7 +591,8 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,

if (one_of(format_kind(), format_kind::undef, format_kind::any))
return false;
- if (is_wino_desc() || is_rnn_packed_desc()) return false;
+ if (is_wino_desc() || is_rnn_packed_desc() || is_cublaslt_blocked_desc())
+ return false;

const int ds = dim_start;
const auto &blk = blocking_desc();
7 changes: 7 additions & 0 deletions src/common/memory_tracking.hpp
@@ -253,10 +253,14 @@ enum {
key_lnorm_tmp_diff_ss,
key_lnorm_reduction,
key_matmul_dst_in_acc_dt,
key_matmul_lt_algo_scratch,
key_matmul_lt_block_c,
key_matmul_src_trans,
key_matmul_wei_trans,
key_matmul_dst_trans,
key_matmul_dst_cast_acc,
key_matmul_lt_src_scale,
key_matmul_lt_wei_scale,
key_matmul_sparse_tmp_ptr,
key_pool_dst_bf16cvt,
key_pool_dst_plain2blocked_cvt,
@@ -282,6 +286,9 @@
key_reorder_rnn_weights_reduction,
key_reorder_rnn_weights_transposition,
key_reorder_rnn_weights_xf16_cvt,
key_reorder_cublaslt_src_float,
key_reorder_cublaslt_dst_float,
key_reorder_cublaslt_generic,
key_rnn_space,
key_rnn_bf32_attention_trans,
key_rnn_bf32_wei_layer_trans,
7 changes: 7 additions & 0 deletions src/common/primitive_hashing.cpp
@@ -139,6 +139,13 @@ size_t get_md_hash(const memory_desc_t &md) {
seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale);
seed = hash_combine(seed, md.format_desc.wino_desc.size);
break;
case format_kind::cublaslt_blocked:
seed = hash_combine(seed,
static_cast<size_t>(md.format_desc.cublaslt_blocked_desc
.cublaslt_format));
Contributor: Do you skip hashing size intentionally?

Contributor: No, not on purpose. Thank you for picking this up; adding the change.

seed = hash_combine(
seed, (md.format_desc.cublaslt_blocked_desc.size));
break;
case format_kind::rnn_packed:
seed = hash_combine(seed,
static_cast<size_t>(md.format_desc.rnn_packed_desc.format));
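Note: `hash_combine` above is oneDNN's boost-style seed mixer; a rough stand-in for illustration only (the real helper is defined elsewhere in src/common):

#include <cstddef>
#include <functional>

// Approximation of the hash_combine used above (boost-style mixing), shown
// only so the seeding pattern in the new case label is self-explanatory.
size_t hash_combine(size_t seed, size_t value) {
    return seed ^ (std::hash<size_t>()(value) + 0x9e3779b9 + (seed << 6)
            + (seed >> 2));
}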
4 changes: 4 additions & 0 deletions src/common/reorder.hpp
@@ -29,6 +29,10 @@ status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
engine_t *engine, const memory_desc_t *src_md,
const memory_desc_t *dst_md, const primitive_attr_t *attr = nullptr);

status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
engine_t *engine, const memory_desc_t *src_md, engine_t *src_engine,
const memory_desc_t *dst_md, engine_t *dst_engine,
const primitive_attr_t *attr = nullptr);
} // namespace impl
} // namespace dnnl

5 changes: 5 additions & 0 deletions src/common/serialization.cpp
@@ -91,6 +91,11 @@ void serialize_md(serialization_stream_t &sstream, const memory_desc_t &md) {
sstream.write(&md.format_desc.wino_desc.adj_scale);
sstream.write(&md.format_desc.wino_desc.size);
break;
case format_kind::cublaslt_blocked:
sstream.write(
&md.format_desc.cublaslt_blocked_desc.cublaslt_format);
sstream.write(&md.format_desc.cublaslt_blocked_desc.size);
break;
case format_kind::rnn_packed:
sstream.write(&md.format_desc.rnn_packed_desc.format);
sstream.write(&md.format_desc.rnn_packed_desc.n_parts);
4 changes: 4 additions & 0 deletions src/common/type_helpers.hpp
@@ -352,6 +352,10 @@ inline bool wino_desc_is_equal(const wino_desc_t &lhs, const wino_desc_t &rhs) {
&& lhs.ic2_block == rhs.ic2_block && lhs.oc2_block == rhs.oc2_block
&& lhs.r == rhs.r;
}
inline bool cublaslt_blocked_desc_is_equal(const cublaslt_blocked_desc_t &lhs,
const cublaslt_blocked_desc_t &rhs) {
return lhs.cublaslt_format == rhs.cublaslt_format && lhs.size == rhs.size;
}

inline bool rnn_packed_desc_is_equal(
const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) {
9 changes: 9 additions & 0 deletions src/common/verbose.cpp
@@ -394,6 +394,14 @@ std::string rnn_flags2str(unsigned flags) {
return s;
}

std::string cublasltfmt2str(const memory_desc_t *md) {
if (md->format_desc.cublaslt_blocked_desc.cublaslt_format
== cublaslt_memory_format_t::col32_2r_4r4) {
return ":col32_2r_4r4";
}
return "";
}

std::ostream &operator<<(std::ostream &ss, const memory_extra_desc_t &extra) {
using namespace memory_extra_flags;

@@ -514,6 +522,7 @@ std::string md2fmt_str(
case format_kind::blocked:
ss << ":" << md2fmt_tag_str(md) << ":" << md2fmt_strides_str(md);
break;
case format_kind::cublaslt_blocked: ss << cublasltfmt2str(md); break;
case format_kind::wino:
case format_kind::rnn_packed:
case format_kind::opaque: ss << "::"; break;
18 changes: 14 additions & 4 deletions src/gpu/generic/sycl/binary_kernels.hpp
@@ -91,7 +91,13 @@ struct binary_kernel_vec_t {
== src1_mem.md().strides()[i]);
}
}
- if (!any_broadcast && conf_.post_ops.get_post_op() == 0
+
+ const bool is_blocked_fmt = conf_.src0_md.inner_nblks() > 0
+ || conf_.src1_md.inner_nblks() > 0
+ || conf_.dst_md.inner_nblks() > 0;
+
+ if (!any_broadcast && !is_blocked_fmt
+ && conf_.post_ops.get_post_op() == 0
&& sg_base_idx + (sg.get_local_range()[0] * conf_.block_size)
< conf_.wk_size
&& is_same_tag) {
@@ -114,8 +120,12 @@
for (int i = 0; i < conf_.block_size; i++) {
int idx = base_idx + i;
if (idx < conf_.wk_size) {
- for (int i = 0; i < max_supported_ndims; i++) {
- off_dst[i] = idx / strides[i] % dims[i];
+ auto l_offset = idx;
+ for (int i = 0; i < conf_.ndims; i++) {
+ const int d = conf_.ndims - 1 - i;
+ const dim_t cur_dim = conf_.dst_md.dims()[d];
+ off_dst[d] = l_offset % cur_dim;
+ l_offset = l_offset / cur_dim;
}

for (int i = 0; i < max_supported_ndims; i++) {
@@ -133,7 +143,7 @@

acc = conf_.post_ops.apply(
acc, dst_, idx, po_args_, off_dst);
- dst_mem.store(acc, idx);
+ dst_mem.store_md(acc, off_dst);
}
}
}
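Note: the replaced indexing divided the flat index by dense strides, which only holds for plain layouts; the new loop de-linearizes the index against the destination's logical dims, and `store_md` then maps those logical offsets onto the (possibly blocked) physical layout. A standalone sketch of that de-linearization (names are illustrative):

#include <cstdint>

// Standalone sketch of the de-linearization the kernel now performs: a flat
// work-item index becomes per-dimension logical offsets against the
// destination's logical dims, with the innermost dimension varying fastest.
void delinearize(int64_t idx, const int64_t *dims, int ndims, int64_t *off) {
    for (int d = ndims - 1; d >= 0; --d) {
        off[d] = idx % dims[d];
        idx /= dims[d];
    }
}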
2 changes: 2 additions & 0 deletions src/gpu/generic/sycl/ref_binary.cpp
@@ -65,6 +65,8 @@ status_t ref_binary_t::init(impl::engine_t *engine) {

status_t ref_binary_t::execute(const exec_ctx_t &ctx) const {

ctx.zero_pad_output(DNNL_ARG_TO);

parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
binary_kernel_vec_t binary_kernel(pd()->conf_, cgh, ctx);

5 changes: 4 additions & 1 deletion src/gpu/generic/sycl/ref_binary.hpp
@@ -99,8 +99,11 @@ struct ref_binary_t : public gpu::generic::sycl::primitive_t {
using namespace format_tag;

for (const auto &mdw : {src0, src1, dst}) {
- if (!mdw.is_plain()) { return false; }
+ if (!(mdw.is_plain() || mdw.matches_tag(format_tag::Ab32a)
+ || mdw.matches_tag(format_tag::aBc32b)))
+ return false;
}

return true;
}
};