@@ -543,6 +543,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_relu[2];
     vk_pipeline pipeline_tanh[2];
     vk_pipeline pipeline_sigmoid[2];
+    vk_pipeline pipeline_hardsigmoid[2];
+    vk_pipeline pipeline_hardswish[2];
 
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
@@ -3324,6 +3326,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(relu)
     CREATE_UNARY(tanh)
     CREATE_UNARY(sigmoid)
+    CREATE_UNARY(hardsigmoid)
+    CREATE_UNARY(hardswish)
 #undef CREATE_UNARY
 
 #define CREATE_GLU(name)  \
@@ -7656,6 +7660,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_SIGMOID:
                 return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
+            case GGML_UNARY_OP_HARDSIGMOID:
+                return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
+            case GGML_UNARY_OP_HARDSWISH:
+                return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
             default:
                 break;
         }
@@ -10330,6 +10338,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_SIGMOID:
+        case GGML_UNARY_OP_HARDSIGMOID:
+        case GGML_UNARY_OP_HARDSWISH:
             break;
         default:
             return false;
@@ -10711,6 +10721,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_SIGMOID:
+        case GGML_UNARY_OP_HARDSIGMOID:
+        case GGML_UNARY_OP_HARDSWISH:
             ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
             break;
         default:
@@ -10955,6 +10967,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_SIGMOID:
+        case GGML_UNARY_OP_HARDSIGMOID:
+        case GGML_UNARY_OP_HARDSWISH:
             buf = tensor->buffer;
             break;
         default:
@@ -12105,6 +12119,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_SIGMOID:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_HARDSWISH:
                     return ggml_is_contiguous(op->src[0]) &&
                            (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                            (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -12921,6 +12937,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_UNARY_OP_SIGMOID:
             tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
             break;
+        case GGML_UNARY_OP_HARDSIGMOID:
+            tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);
+            break;
+        case GGML_UNARY_OP_HARDSWISH:
+            tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
+            break;
         default:
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");
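
For reference, `ggml_hardsigmoid` and `ggml_hardswish` (used above in the check-results path as the CPU baseline) follow the standard hard-sigmoid/hard-swish definitions, so the new F32/F16 Vulkan pipelines are expected to match them elementwise. The sketch below is only an illustration of that scalar math; the `ref_*` helper names are hypothetical and not part of this patch.

// Minimal scalar reference, assuming the standard definitions:
//   hardsigmoid(x) = min(1, max(0, (x + 3) / 6))
//   hardswish(x)   = x * hardsigmoid(x)
#include <algorithm>

static float ref_hardsigmoid(float x) {
    return std::min(1.0f, std::max(0.0f, (x + 3.0f) / 6.0f));
}

static float ref_hardswish(float x) {
    return x * ref_hardsigmoid(x);
}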