2 changes: 1 addition & 1 deletion src/op/gemm.cc
@@ -371,7 +371,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
                         *as_const_int(B->shape[dim_B - 1]),
                         false, trans_B ? 2 : 1));
   } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
-             TargetIsSM120(T.target)) {
+             TargetIsSM100(T.target) || TargetIsSM120(T.target)) {
     auto fragment =
         makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
     results.Set(C, fragment->BindThreadRange(thread_range));
7 changes: 7 additions & 0 deletions src/target/utils.cc
@@ -53,6 +53,13 @@ bool TargetIsHopper(Target target) {
   return arch >= 90 && arch < 100;
 }
 
+bool TargetIsSM100(Target target) {
+  if (!TargetIsCuda(target))
+    return false;
+  int arch = GetArchInt(target);
+  return arch >= 100 && arch < 120;
+}
+
 bool TargetIsSM120(Target target) {
   if (!TargetIsCuda(target))
     return false;
1 change: 1 addition & 0 deletions src/target/utils.h
@@ -19,6 +19,7 @@ bool TargetIsVolta(Target target);
 bool TargetIsTuring(Target target);
 bool TargetIsAmpere(Target target);
 bool TargetIsHopper(Target target);
+bool TargetIsSM100(Target target);
 bool TargetIsSM120(Target target);
 bool TargetIsCDNA(Target target);
 
2 changes: 2 additions & 0 deletions src/tl_templates/cuda/gemm.h
@@ -1,6 +1,8 @@
 #pragma once
 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200))
 #include "gemm_sm120.h"
+#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1000))
+#include "gemm_sm100.h"
 #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
 #include "gemm_sm90.h"
 #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890))
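
For context, a sketch of how this dispatch resolves (an illustration, not part of the PR): __CUDA_ARCH_LIST__ is nvcc's list of target architectures, and the walkthrough below assumes a single-arch build for sm_100, where the macro expands to the single value 1000.

// Illustration only -- assumes nvcc was invoked for exactly one
// architecture (e.g. -gencode arch=compute_100,code=sm_100), so
// __CUDA_ARCH_LIST__ expands to 1000:
//   1000 >= 1200 ? no  -> gemm_sm120.h skipped
//   1000 >= 1000 ? yes -> gemm_sm100.h included (the branch this PR adds)
// Architectures below 1000 (Hopper, Ada, Ampere, ...) fall through to the
// later #elif branches unchanged.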
38 changes: 2 additions & 36 deletions src/tl_templates/cuda/gemm_mma.h
@@ -10,7 +10,7 @@
#include "common.h"
#include "cuda_fp8.h"

namespace cute {
namespace cute::tl_mma {

template <typename A_type, typename B_type, typename C_type, int num_warp_m,
int num_warp_n, int N>
@@ -421,38 +421,4 @@ class GemmTensorOp {
 }
 };
 
-} // namespace cute
-
-namespace tl {
-
-template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
-          int offset_b, typename A_type, typename B_type, typename C_type>
-CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
-  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, clear_accum, lda, ldb, offset_a,
-                                 offset_b, A_type, B_type, C_type>;
-  MMA::body(pA, pB, accum);
-}
-
-template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
-          int offset_b, typename A_type, typename B_type, typename C_type>
-CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
-  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, clear_accum, lda, ldb, offset_a,
-                                 offset_b, A_type, B_type, C_type>;
-  MMA::body_rs(pA, pB, accum);
-}
-
-template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
-          int offset_b, typename A_type, typename B_type, typename C_type>
-CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
-  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, clear_accum, lda, ldb, offset_a,
-                                 offset_b, A_type, B_type, C_type>;
-  MMA::body_sr(pA, pB, accum);
-}
-
-} // namespace tl
+} // namespace cute::tl_mma
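
For orientation, a hypothetical call site (tile sizes, leading dimensions, and the include path are all invented for illustration, not taken from this PR): after this refactor the implementation class lives in cute::tl_mma, while kernels keep calling the tl:: wrappers that each per-arch header below re-defines.

#include "gemm.h" // arch-dispatch header shown above (path illustrative)

using cute::half_t;

// Hypothetical kernel: a 128x128x32 tile on a 2x2 warp grid, A and B
// untransposed, accumulating into C_frag. All numeric parameters are
// made up for illustration.
__global__ void example_kernel(half_t *A_s, half_t *B_s, float *C_frag) {
  tl::gemm_ss<128, 128, 32, /*num_warp_m=*/2, /*num_warp_n=*/2,
              /*trans_A=*/false, /*trans_B=*/false, /*clear_accum=*/false,
              /*lda=*/32, /*ldb=*/128, /*offset_a=*/0, /*offset_b=*/0,
              half_t, half_t, float>(A_s, B_s, C_frag);
}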
40 changes: 40 additions & 0 deletions src/tl_templates/cuda/gemm_sm100.h
@@ -0,0 +1,40 @@
#pragma once

#include "gemm_mma.h"
💡 Verification agent

🧩 Analysis chain

Confirm CUTLASS_DEVICE availability via gemm_mma.h

If gemm_mma.h doesn't include cutlass/cutlass.h (or otherwise define CUTLASS_DEVICE), these wrappers may fail to compile under some translation-unit include orders. Please verify.

Run:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Locate gemm_mma.h and verify it (or its includes) define CUTLASS_DEVICE
fd -a gemm_mma.h
rg -n -C2 'CUTLASS_DEVICE' $(fd -a gemm_mma.h || true)

# As a fallback, ensure cutlass/cutlass.h is pulled somewhere above these wrappers
rg -n -C2 'include\s*<cutlass/cutlass\.h>' src || true

Length of output: 189


Let's drill into gemm_mma.h’s include chain and locate the CUTLASS_DEVICE macro definition:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Locate gemm_mma.h
file=$(fd -a gemm_mma.h)

echo "Inspecting includes in $file:"
rg -n '^#\s*include' "$file"

echo -e "\nSearching for CUTLASS_DEVICE definition in the repo:"
rg -n '#\s*define\s*CUTLASS_DEVICE' -C2 .

echo -e "\nSearching for any inclusion of cutlass/cutlass.h (quotes or brackets):"
rg -n 'include\s*[<"]cutlass/cutlass.h[>"]' -C2 .

Length of output: 763


Ensure CUTLASS_DEVICE is defined in gemm_mma.h
The gemm_sm100.h wrappers pull in gemm_mma.h (via line 3) but we didn’t find any CUTLASS_DEVICE definition or an #include <cutlass/cutlass.h> in its include chain. Without this, builds can break depending on translation-unit include order.

Please update src/tl_templates/cuda/gemm_mma.h to guarantee the macro is available, for example by adding at the top:

#include <cutlass/cutlass.h>   // for CUTLASS_DEVICE

or by otherwise defining CUTLASS_DEVICE before it’s used in the wrappers.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_mma.h (top of file), ensure CUTLASS_DEVICE is
always defined by adding an include of Cutlass's main header before any use of
CUTLASS_DEVICE (e.g., add #include <cutlass/cutlass.h> as the first include) or
alternatively add a guarded fallback definition (wrap a #ifndef CUTLASS_DEVICE /
#define CUTLASS_DEVICE __host__ __device__ / #endif) so the macro is available
regardless of translation-unit include order.
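
A minimal sketch of the suggested fix, assuming gemm_mma.h currently receives CUTLASS_DEVICE only transitively; the fallback definition follows the reviewer's __host__ __device__ suggestion rather than CUTLASS's exact macro:

// Near the top of src/tl_templates/cuda/gemm_mma.h -- a sketch of the
// review's suggestion, not the merged change.

// Option 1: include the canonical definition directly.
#include <cutlass/cutlass.h> // defines CUTLASS_DEVICE

// Option 2: guarded fallback, so the header stands alone regardless of
// translation-unit include order (definition as suggested by the reviewer).
#ifndef CUTLASS_DEVICE
#define CUTLASS_DEVICE __host__ __device__
#endif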


namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}

} // namespace tl
37 changes: 37 additions & 0 deletions src/tl_templates/cuda/gemm_sm120.h
@@ -1,3 +1,40 @@
#pragma once

#include "gemm_mma.h"

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}

} // namespace tl
37 changes: 37 additions & 0 deletions src/tl_templates/cuda/gemm_sm80.h
@@ -1,3 +1,40 @@
#pragma once

#include "gemm_mma.h"

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}

} // namespace tl
37 changes: 37 additions & 0 deletions src/tl_templates/cuda/gemm_sm89.h
@@ -5,3 +5,40 @@
#include "cuda_fp8.h"

#include "gemm_mma.h"

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}

} // namespace tl