Fused ROPE and reshape cache kernel #229

Open · wants to merge 21 commits into base: main

Changes from 1 commit
union with float4
Aleksandr Malyshev committed Dec 4, 2024
commit d14c054ec4006eded4bdc5241d92cce127c009ba
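
This commit changes vec_t from a bare scalar_t array to a union that overlays the array with two float4 values, so whole-vector copies can be issued as a pair of 16-byte loads/stores instead of width scalar moves. A minimal sketch of the layout idea, not the exact PR code (the struct name and the static_assert are illustrative assumptions):

// Sketch: overlay the scalar payload with two float4 lanes so copies move
// 16 bytes at a time. Assumes sizeof(scalar_t) * width == 2 * sizeof(float4),
// e.g. __half with width = 16 or float with width = 8.
#include <hip/hip_vector_types.h>  // provides float4 under HIP

template <typename scalar_t, int width>
struct __align__(16) vec_layout_sketch {
  static_assert(sizeof(scalar_t) * width == 2 * sizeof(float4),
                "payload must fill exactly two float4 lanes");
  union {
    float4 fdata[2];       // bulk 128-bit view, used for copies
    scalar_t data[width];  // element view, used for arithmetic
  } uvec;
};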
36 changes: 23 additions & 13 deletions csrc/rocm/fused_rope_and_reshape_cache.cu
@@ -12,6 +12,7 @@
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_vector_types.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include "quantization/fp8/amd/hip_float8.h"
@@ -27,6 +28,12 @@ using __nv_bfloat162 = __hip_bfloat162;
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

#if __cplusplus
#if defined(_MSC_VER)
static_assert(false);
#endif
#endif

namespace {

template <typename scalar_t, int width>
@@ -43,64 +50,67 @@ __device__ void apply_rope(scalar_t* __restrict__ arr_ptr,

template <typename scalar_t, int width>
struct __align__(16) vec_t {
scalar_t data[width];
union {
float4 fdata[2];
scalar_t data[width];
} uvec;

__device__ vec_t() = default;
__device__ vec_t(const scalar_t (& _data)[width]){
#pragma unroll
for (int i = 0; i < width; ++i) data[i] = _data[i];
uvec.fdata[0] = *reinterpret_cast<const float4 *>(&_data[0]);
uvec.fdata[1] = *reinterpret_cast<const float4 *>(&_data[width / 2]);
}
__device__ vec_t(const vec_t<scalar_t, width>& other) {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] = other.data[i];
uvec.fdata[0] = other.uvec.fdata[0];
uvec.fdata[1] = other.uvec.fdata[1];
}

__device__ vec_t operator*(const vec_t& other) const {
vec_t<scalar_t, width> tmp{*this};
#pragma unroll
for (int i = 0; i < width; ++i) tmp.data[i] *= other.data[i];
for (int i = 0; i < width; ++i) tmp.uvec.data[i] *= other.uvec.data[i];
return tmp;
}

__device__ vec_t operator*(const float& scale) const {
vec_t<scalar_t, width> tmp{*this};
#pragma unroll
for (int i = 0; i < width; ++i) tmp.data[i] *= scale;
for (int i = 0; i < width; ++i) tmp.uvec.data[i] *= scale;
return tmp;
}

__device__ vec_t operator+(const vec_t& other) const {
vec_t<scalar_t, width> tmp{*this};
#pragma unroll
for (int i = 0; i < width; ++i) tmp.data[i] += other.data[i];
for (int i = 0; i < width; ++i) tmp.uvec.data[i] += other.uvec.data[i];
return tmp;
}

__device__ vec_t operator-(const vec_t& other) const {
vec_t<scalar_t, width> tmp{*this};
#pragma unroll
for (int i = 0; i < width; ++i) tmp.data[i] -= other.data[i];
for (int i = 0; i < width; ++i) tmp.uvec.data[i] -= other.uvec.data[i];
return tmp;
}

__device__ vec_t<scalar_t, width>& operator=(const vec_t& other) {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] = other.data[i];
for (int i = 0; i < width; ++i) uvec.data[i] = other.uvec.data[i];
return *this;
}

__device__ vec_t<scalar_t, width>& operator+=(const vec_t& other) {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
for (int i = 0; i < width; ++i) uvec.data[i] += other.uvec.data[i];
return *this;
}

__device__ scalar_t& operator [](const size_t& idx) {
return data[idx];
return uvec.data[idx];
}

__device__ scalar_t operator [](const size_t& idx) const {
return data[idx];
return uvec.data[idx];
}

friend
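
For context, a hedged usage sketch of how a vec_t-style wrapper of this shape is typically driven from a kernel: global memory is read and written in 32-byte chunks through the aligned struct, while the overloaded operators do the element-wise math. The kernel name, launch shape, and alignment assumptions below are illustrative and not part of this PR.

// Illustrative only: scale an array in width-element chunks through a
// vec_t-like wrapper. Assumes the vec_t defined in this file, that n is a
// multiple of width, and that in/out are 16-byte aligned.
template <typename scalar_t, int width>
__global__ void scale_kernel(scalar_t* __restrict__ out,
                             const scalar_t* __restrict__ in,
                             float scale, int n) {
  using vec = vec_t<scalar_t, width>;
  const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * width;
  if (idx + width <= n) {
    // Copy construction goes through the float4 view: two 128-bit loads.
    vec v = *reinterpret_cast<const vec*>(in + idx);
    // operator* walks the scalar view element by element.
    vec r = v * scale;
    // operator= writes the result back through the scalar view.
    *reinterpret_cast<vec*>(out + idx) = r;
  }
}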