Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions ggml/src/ggml-cuda/fwht.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows,
float reg[el_w];
const int lane = threadIdx.x;

ggml_cuda_pdl_sync();
#pragma unroll
for (int i = 0; i < el_w; ++i) {
reg[i] = src[i * warp_size + lane] * scale;
Expand Down Expand Up @@ -57,18 +58,18 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows,
}
}

void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_shape(src, dst));
GGML_ASSERT(ggml_is_contiguous(src));
GGML_ASSERT(ggml_is_contiguous(dst));
if (!ggml_is_contiguous(src) || !ggml_is_contiguous(dst)) {
return false;
}
const int n = src->ne[0];
const int64_t rows = ggml_nrows(src);

const float * src_d = (const float *) src->data;
float * dst_d = (float *) dst->data;

const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
GGML_ASSERT(n % warp_size == 0);
const int rows_per_block = 4;

const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block;
Expand All @@ -83,26 +84,18 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src,

switch (n) {
case 64:
{
ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale);
return true;
case 128:
{
ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale);
return true;
case 256:
{
ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale);
return true;
case 512:
{
ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale);
break;
}
ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale);
return true;
default:
GGML_ABORT("fatal error");
return false;
}
}
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/fwht.cuh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "common.cuh"

void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst);
// Returns whether the Fast Walsh-Hadamard transform could be used.
bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst);
4 changes: 1 addition & 3 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2596,9 +2596,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;

const int32_t hint = ggml_get_op_params_i32(dst, 1);
if (hint == GGML_HINT_SRC0_IS_HADAMARD) {
GGML_ASSERT(!split);
ggml_cuda_op_fwht(ctx, src1, dst);
if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) {
return;
}

Expand Down
Loading