Skip to content
119 changes: 108 additions & 11 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ enum vk_conv_shapes {
CONV_SHAPE_128x128,
CONV_SHAPE_64x32,
CONV_SHAPE_32x256,
CONV_SHAPE_64x128,
CONV_SHAPE_COUNT,
};

Expand All @@ -415,6 +416,7 @@ vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = {
{ 128, 128, 16 }, // CONV_SHAPE_128x128
{ 64, 32, 32 }, // CONV_SHAPE_64x32
{ 32, 256, 16 }, // CONV_SHAPE_32x256
{ 64, 128, 16 }, // CONV_SHAPE_64x128
};

enum dmmv_wg_sizes {
Expand Down Expand Up @@ -450,14 +452,16 @@ struct vk_fa_pipeline_state {
};

struct vk_conv2d_pipeline_state {
vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)
: s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}
vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH, uint32_t aligned)
: s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH), aligned(aligned) {}

uint32_t s0, s1, p0, p1, d0, d1, KW, KH;
// when set, shader can skip K/CRS/NPQ bounds checks and address clamps
uint32_t aligned;

bool operator<(const vk_conv2d_pipeline_state &b) const {
return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) <
std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH);
return std::tie(s0, s1, p0, p1, d0, d1, KW, KH, aligned) <
std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH, b.aligned);
}
};

Expand Down Expand Up @@ -4824,7 +4828,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

// conv2d, conv_transpose_2d
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
uint32_t conv2d_WG_SIZE = 256;
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
uint32_t conv2d_TS_K = (s == CONV_SHAPE_64x32) ? 4 : 8;
uint32_t conv2d_SHMEM_PAD = 4;
Expand Down Expand Up @@ -4863,18 +4868,77 @@ static void ggml_vk_load_shaders(vk_device& device) {
conv2d_BS.CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
}

uint32_t conv2d_shmem_req =
(conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
// cm1 is used only when cm2 is unavailable; capped at 64x128 (due to shared memory size).
// Requires 16x16x16 f16-acc since that's the fragment shape hard-coded in the shader.
// Subgroup size must be 32 or 64 (to keep WG_SIZE sane) and we need
// subgroup_size_control to force the driver to actually use it.
bool conv2d_use_cm1 = false;
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
conv2d_use_cm1 = !device->coopmat2 &&
device->coopmat_support && device->coopmat_support_16x16x16_f16acc &&
device->subgroup_size_control &&
(device->subgroup_size == 32 || device->subgroup_size == 64) &&
s != CONV_SHAPE_128x128;
#endif

const uint32_t conv2d_cm1_shmem_pad = 8;

auto shmem_req = [&](uint32_t pad, bool csh_store, bool fp16_shmem) {
const uint32_t elem_size = fp16_shmem ? (uint32_t)sizeof(uint16_t) : (uint32_t)sizeof(float);
const uint32_t csh_elems = csh_store ? conv2d_BS.K * conv2d_BS.NPQ : 0u;
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
};

// coopmat1 needs to store the output through shared memory, so check up front
// whether it'll fit and disable it before applying coopmat1 parameters.
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
conv2d_use_cm1 = false;
}

uint32_t conv2d_WM = 16, conv2d_WN = 16; // cm1 subgroup tile, ignored otherwise
if (conv2d_use_cm1) {
conv2d_SHMEM_PAD = conv2d_cm1_shmem_pad;
// 16x16x16 fragments; pick WM/WN to keep WG_SIZE at 256
// (i.e. 8 subgroups for sg=32, 4 subgroups for sg=64).
const bool sg64 = (device->subgroup_size == 64);
switch (s) {
case CONV_SHAPE_64x32: conv2d_WM = sg64 ? 32 : 16; conv2d_WN = 16; break;
case CONV_SHAPE_64x128: conv2d_WM = 32; conv2d_WN = sg64 ? 64 : 32; break;
case CONV_SHAPE_32x256: conv2d_WM = sg64 ? 16 : 32; conv2d_WN = sg64 ? 128 : 32; break;
default: break;
}
const uint32_t warps_M = conv2d_BS.K / conv2d_WM;
const uint32_t warps_N = conv2d_BS.NPQ / conv2d_WN;
conv2d_WG_SIZE = warps_M * warps_N * device->subgroup_size;
}

// stage cm2 accumulator through shmem for coalesced global stores;
// skipped on 128x128 where the extra Csh footprint hurts occupancy.
// cm1 always uses the staged path.
uint32_t conv2d_csh_store = (device->coopmat2 && s != CONV_SHAPE_128x128) ? 1u : 0u;
if (conv2d_use_cm1) {
conv2d_csh_store = 1;
}

// shmem is fp16 on cm2/cm1 (matches Csh), fp32 on scalar
const bool conv2d_use_fp16_shmem = device->coopmat2 || conv2d_use_cm1;

// shrink CRS if the non-cm1 config still doesn't fit
if (device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_SHMEM_PAD, conv2d_csh_store, conv2d_use_fp16_shmem)) {
GGML_ASSERT(!conv2d_use_cm1);
conv2d_BS.CRS = 8;
if (use_collectives) {
conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS);
}
conv2d_csh_store = 0;
}

std::array<uint32_t, 3> wg_denoms = { conv2d_BS.K, 1, 1 };
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };

// cm1 needs a fixed subgroup width to match the WG_SIZE we computed
const uint32_t conv2d_required_subgroup_size = conv2d_use_cm1 ? device->subgroup_size : 0;

#define CREATE_CONV(name, type_suffix, spv_suffix) \
for (auto &c : device->pipeline_##name##type_suffix[s]) { \
const vk_conv2d_pipeline_state &state = c.first; \
Expand All @@ -4887,10 +4951,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
spec_constants_cpy.push_back(state.d1); \
spec_constants_cpy.push_back(state.KW); \
spec_constants_cpy.push_back(state.KH); \
spec_constants_cpy.push_back(state.aligned); \
spec_constants_cpy.push_back(conv2d_csh_store); \
spec_constants_cpy.push_back(conv2d_WM); \
spec_constants_cpy.push_back(conv2d_WN); \
ggml_vk_create_pipeline( \
device, c.second, #name #type_suffix, \
name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives || conv2d_required_subgroup_size, conv2d_required_subgroup_size); \
}
#define CREATE_CONVS(spv_suffix) \
CREATE_CONV(conv2d, _f32, spv_suffix) \
Expand All @@ -4901,6 +4969,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->coopmat2) {
CREATE_CONVS(_cm2)
} else
#endif
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (conv2d_use_cm1) {
CREATE_CONVS(_cm1)
} else
#endif
if (conv2d_UNROLL) {
CREATE_CONVS(_unroll)
Expand Down Expand Up @@ -9346,10 +9419,23 @@ static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, u
// so small convolutions will still choose a smaller tile.
const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;

if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {
// 128x128 isn't used with cm1 due to shared memory size; fall through to a smaller tile.
bool allow_128x128 = true;
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (!ctx->device->coopmat2 && ctx->device->coopmat_support && ctx->device->coopmat_support_16x16x16_f16acc) {
allow_128x128 = false;
}
#endif

if (allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {
return CONV_SHAPE_128x128;
} else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) {
return CONV_SHAPE_32x256;
} else if (K <= 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) {
return CONV_SHAPE_64x128;
} else if (!allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) {
// cm1 fallback for large K when 128x128 isn't available
return CONV_SHAPE_64x128;
} else {
return CONV_SHAPE_64x32;
}
Expand Down Expand Up @@ -9876,7 +9962,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0;
uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1;
uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1;
vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);

// tile-aligned shapes let the shader skip bounds checks
const uint32_t Cin = (uint32_t)src1->ne[2];
const uint32_t CRS = Cin * KW * KH;
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
const uint32_t aligned = ((K % BS_K == 0) &&
(CRS % BS_CRS == 0) &&
(NPQ % BS_NPQ == 0)) ? 1u : 0u;

vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH, aligned);

std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
if (op == GGML_OP_CONV_2D) {
Expand Down
Loading
Loading