Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion csrc/cpu/cpu_attn_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,8 @@ class AttentionMainLoop {
if (sliding_window_left != -1) {
pos = std::max(pos, curr_token_pos - sliding_window_left);
}
return pos;
// Clamp to tile end to avoid OOB when window starts past the tile
return std::min(pos, kv_tile_end_pos);
}();

int32_t right_kv_pos = [&]() {
Expand Down
19 changes: 17 additions & 2 deletions csrc/cpu/cpu_attn_neon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include "cpu_attn_impl.hpp"
#include <arm_neon.h>
#include <type_traits>
#ifdef ARM_BF16_SUPPORT
#include "cpu_attn_neon_bfmmla.hpp"
#endif
namespace cpu_attention {

namespace {
Expand Down Expand Up @@ -57,7 +60,7 @@ FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
#endif
}

// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with NEON FMLAs
// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with ASIMD FMLAs
// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
// #FMLAs = (K // 4) * (4 * 2 * M)
// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
Expand Down Expand Up @@ -381,6 +384,18 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
}
}
};

#ifdef ARM_BF16_SUPPORT
// For BF16 on Arm, reuse the BFMMLA kernels with 32-token alignment.
template <int64_t head_dim>
class AttentionImpl<ISA::NEON, c10::BFloat16, head_dim>
: public AttentionImplNEONBFMMLA<BLOCK_SIZE_ALIGNMENT, ISA::NEON,
head_dim> {};
#endif
} // namespace cpu_attention

#endif // #ifndef CPU_ATTN_NEON_HPP
#undef BLOCK_SIZE_ALIGNMENT
#undef HEAD_SIZE_ALIGNMENT
#undef MAX_Q_HEAD_NUM_PER_ITER

#endif // #ifndef CPU_ATTN_ASIMD_HPP
Loading