Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4197,9 +4197,53 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}

// Component-wise product of two dfloat2 values.
// With GGML_CUDA_F16 defined the pair is multiplied with the __hmul2
// intrinsic; otherwise the two components are multiplied as plain floats.
static __device__ __forceinline__ dfloat2 dfmul2(dfloat2 a, dfloat2 b) {
#ifndef GGML_CUDA_F16
    const float px = a.x * b.x;
    const float py = a.y * b.y;
    return make_float2(px, py);
#else
    const dfloat2 prod = __hmul2(a, b);
    return prod;
#endif
}

// Converts a dfloat2 into a float2.
// With GGML_CUDA_F16 defined the value is widened through __half22float2;
// otherwise it is returned unchanged.
static __device__ __forceinline__ float2 dfloat22float2(dfloat2 a) {
#ifndef GGML_CUDA_F16
    return a;
#else
    const float2 widened = __half22float2(a);
    return widened;
#endif
}

// Scales a pair of already-centred 4-bit quants by the block scale `d` and
// returns the result as a float2.
//
// The dfloat2 operands are built with conversion intrinsics rather than by
// assigning to .x/.y: review feedback on this change reports that HIP/ROCm
// exposes the x and y members of a half2 as shorts, so member-wise assignment
// of an integer would not produce the intended half values in the
// GGML_CUDA_F16 build.
static __device__ __forceinline__ float2 dequantize_pair_q4_0(const int q0, const int q1, const dfloat d) {
#ifdef GGML_CUDA_F16
    const dfloat2 q     = __floats2half2_rn((float) q0, (float) q1); // small integers, exactly representable
    const dfloat2 scale = __half2half2(d);
#else
    const dfloat2 q     = make_float2((float) q0, (float) q1);
    const dfloat2 scale = make_float2(d, d);
#endif
    return dfloat22float2(dfmul2(q, scale));
}

// Dequantizes q4_0 data into floats, four output values per thread.
//
// vx: quantized input, read as an array of block_q4_0.
// y:  output floats; written with float2 stores
//     (NOTE(review): assumes y is at least 8-byte aligned - confirm with callers).
// k:  number of output values. A thread with i*4 < k always writes four
//     values, so k is expected to be a multiple of 4.
//
// Launch layout: 1D grid and 1D blocks; global thread i handles bytes
// 2*iqs and 2*iqs+1 of block ib's qs array, where QK4_0/4 threads share one
// quantized block.
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ y, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i*4 >= k) {
        return;
    }

    const int ib  = i/(QK4_0/4); // index of the quantized block
    const int iqs = i%(QK4_0/4); // index of this thread's byte pair within the block

    const block_q4_0 * x = (const block_q4_0 *) vx;
    const uchar2 qs = *(const uchar2 *)(x[ib].qs + iqs*2);
    const dfloat d = x[ib].d;

    // Low nibbles: stored in the first half of the output block.
    const float2 v0 = dequantize_pair_q4_0((int)(qs.x & 0xf) - 8, (int)(qs.y & 0xf) - 8, d);
    *(float2 *)(y + ib*QK4_0 + iqs*2) = v0;

    // High nibbles: stored QK4_0/2 elements further on.
    const float2 v1 = dequantize_pair_q4_0((int)(qs.x >> 4) - 8, (int)(qs.y >> 4) - 8, d);
    *(float2 *)(y + ib*QK4_0 + QK4_0/2 + iqs*2) = v1;
}

// Launches dequantize_block_q4_0 on `stream` to dequantize k q4_0 values
// from vx into y.
//
// Each kernel thread produces four output values, so the grid is sized from
// k/4 (rounded up to whole blocks of CUDA_DEQUANTIZE_BLOCK_SIZE threads) and
// k must be a multiple of 4.
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    // Check the precondition before anything is enqueued on the stream.
    GGML_ASSERT(k % 4 == 0);
    const int num_blocks = (k/4 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block_q4_0<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}

static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
Expand Down