79 changes: 2 additions & 77 deletions cmake/patches/cutlass/cutlass_3.5.0.patch
@@ -1,80 +1,5 @@
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
index 4c80f549..5ad610c8 100644
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
@@ -189,6 +189,7 @@ struct AttentionKernel {

// Scale
accum_t scale = 0.0;
+ accum_t softcap = 0.0;

// Dimensions/strides
int32_t head_dim = 0;
@@ -221,6 +222,8 @@ struct AttentionKernel {
int32_t num_batches = 0;
int32_t num_heads = 0;

+ bool use_smooth_softmax = false;
+
// dropout
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
@@ -818,6 +821,15 @@ struct AttentionKernel {
accum =
cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
}
+
+ // apply softcap if applicable
+ if (p.softcap > 0.0) {
+ accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(1.0 / p.softcap, accum);
+ for (int i = 0; i < accum.size(); ++i) {
+ accum[i] = cutlass::fast_tanh(accum[i]);
+ }
+ accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.softcap, accum);
+ }

// apply attention bias if applicable
if (kSupportsBias && p.attn_bias_ptr != nullptr) {
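For reference, the removed hunk above implements logit soft-capping: scores are squashed through tanh so they stay inside (-softcap, softcap), via accum = softcap * tanh(accum / softcap). A minimal host-side sketch of the same transform, assuming a plain float buffer (apply_softcap is an illustrative name, not part of the patch):

#include <cmath>
#include <vector>

// Sketch of the tanh soft-cap from the removed hunk:
// x = softcap * tanh(x / softcap), which bounds x to (-softcap, softcap).
void apply_softcap(std::vector<float>& logits, float softcap) {
  if (softcap <= 0.0f) return;  // mirrors the `p.softcap > 0.0` guard above
  for (float& x : logits) {
    x = softcap * std::tanh(x / softcap);
  }
}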
@@ -897,7 +909,8 @@ struct AttentionKernel {
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
- kSupportsBias ? 1.0f : p.scale);
+ kSupportsBias ? 1.0f : p.scale,
+ p.use_smooth_softmax);

// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
@@ -1166,7 +1179,8 @@ struct AttentionKernel {
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
- float scaling) {
+ float scaling,
+ bool use_smooth_softmax) {
/* Iterates on the accumulator and corresponding position on result matrix

(1) Update `mi[r]` to the max value of the row `r`
@@ -1257,7 +1271,7 @@ struct AttentionKernel {
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
- [&](int accum_m) { mi_row = mi[accum_m]; },
+ [&](int accum_m) { mi_row = mi[accum_m];},
[&](int accum_m, int accum_n, int idx) {
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
@@ -1294,7 +1308,7 @@ struct AttentionKernel {
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
- s_prime[id] = total_row;
+ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
}
}
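The s_prime change above is the smooth-softmax denominator: the kernel tracks logits in base 2, so exp2f(-mi[id]) is the weight of a virtual key with logit 0, and each probability becomes exp(x_i - m) / (exp(-m) + sum_j exp(x_j - m)). A standalone base-e sketch of the same math (smooth_softmax is an illustrative name; the actual kernel works in base 2 and block by block):

#include <algorithm>
#include <cmath>
#include <vector>

// Sketch of "smooth softmax": a softmax whose denominator carries one extra
// exp(-m) term, i.e. a virtual attention sink with logit 0, matching
// s_prime[id] = total_row + exp2f(-mi[id]) above. Assumes a non-empty row.
std::vector<float> smooth_softmax(const std::vector<float>& logits) {
  const float m = *std::max_element(logits.begin(), logits.end());
  float denom = std::exp(-m);  // the virtual zero-logit key
  for (float x : logits) denom += std::exp(x - m);
  std::vector<float> probs;
  probs.reserve(logits.size());
  for (float x : logits) probs.push_back(std::exp(x - m) / denom);
  return probs;
}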

diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
-index 964d2ff3..676ba768 100644
+index 964d2ff3..b366bc14 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
@@ -97,4 +22,4 @@ index 964d2ff3..676ba768 100644
+#endif
#else
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
-#endif
+#endif
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -422,7 +422,7 @@ Status EfficientAttention(
p.seqstart_q_ptr = nullptr;
p.seqstart_k_ptr = nullptr;
} else {
- p.seqlen_k_ptr = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
+ p.seqlen_k_ptr = reinterpret_cast<const int32_t*>(data.mask_index);
p.seqstart_q_ptr = p.seqlen_k_ptr + parameters.batch_size;
p.seqstart_k_ptr = p.seqlen_k_ptr + 2 * parameters.batch_size + 1;
}
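For context, the hunk above reads three views out of the single data.mask_index buffer: batch_size key lengths, then batch_size + 1 cumulative query offsets, then batch_size + 1 cumulative key offsets. A hedged host-side sketch of that packing, inferred purely from the pointer offsets shown (pack_mask_index is an illustrative helper, not ONNX Runtime API):

#include <cstdint>
#include <vector>

// Sketch: pack per-batch key lengths plus cumulative q/k start offsets into
// one buffer, matching the offsets used above:
//   seqlen_k   at index 0                   (batch_size entries)
//   seqstart_q at index batch_size          (batch_size + 1 entries)
//   seqstart_k at index 2 * batch_size + 1  (batch_size + 1 entries)
// Assumes seqlens_q and seqlens_k both hold batch_size entries.
std::vector<int32_t> pack_mask_index(const std::vector<int32_t>& seqlens_q,
                                     const std::vector<int32_t>& seqlens_k) {
  const size_t batch_size = seqlens_k.size();
  std::vector<int32_t> buf;
  buf.insert(buf.end(), seqlens_k.begin(), seqlens_k.end());
  int32_t q_off = 0, k_off = 0;
  buf.push_back(q_off);
  for (size_t b = 0; b < batch_size; ++b) buf.push_back(q_off += seqlens_q[b]);
  buf.push_back(k_off);
  for (size_t b = 0; b < batch_size; ++b) buf.push_back(k_off += seqlens_k[b]);
  return buf;
}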
@@ -10,7 +10,7 @@
#endif

#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "41_fused_multi_head_attention/kernel_forward.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h"

namespace onnxruntime {
namespace contrib {
@@ -223,6 +223,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
}

p.use_smooth_softmax = params.use_smooth_softmax;
+ p.window_size = params.local_window_size;
}

auto kernel_fn = attention_kernel_batched_impl<Attention>;
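The new p.window_size wiring suggests sliding-window (local) attention: judging by the name local_window_size, each query attends only to the most recent window_size keys. A hedged sketch of the usual masking predicate (the vendored kernel's exact rule is not shown in this diff; attends is an illustrative name):

// Sketch of a causal sliding-window mask: query position q may attend key
// position k only if k is not in the future and lies within the last
// window_size positions; window_size <= 0 is treated as "unlimited".
inline bool attends(int q, int k, int window_size) {
  if (k > q) return false;  // causal constraint
  if (window_size > 0 && q - k >= window_size) {
    return false;           // outside the local window
  }
  return true;
}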