@@ -549,7 +549,44 @@ ggml_metal_device_t ggml_metal_device_init(void) {
549549 dev->props .has_tensor = false ;
550550 }
551551
552- // try to compile a dummy tensor kernel to determine if the tensor API is supported for bfloat
552+ // double-check that the tensor API compiles
553+ if (dev->props .has_tensor ) {
554+ const char * src_tensor_f16 = " \n "
555+ " #include <metal_stdlib> \n "
556+ " #include <metal_tensor> \n "
557+ " #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n "
558+ " \n "
559+ " using namespace metal; \n "
560+ " using namespace mpp::tensor_ops; \n "
561+ " \n "
562+ " kernel void dummy_kernel( \n "
563+ " tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n "
564+ " tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n "
565+ " uint2 tgid [[threadgroup_position_in_grid]]) \n "
566+ " { \n "
567+ " auto tA = A.slice(0, (int)tgid.y); \n "
568+ " auto tB = B.slice((int)tgid.x, 0); \n "
569+ " \n "
570+ " matmul2d< \n "
571+ " matmul2d_descriptor(8, 8, dynamic_extent), \n "
572+ " execution_thread> mm; \n "
573+ " \n "
574+ " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), half>(); \n "
575+ " \n "
576+ " (void)cT; \n "
577+ " }" ;
578+
579+ GGML_LOG_INFO (" %s : testing tensor API for f16 support\n " , __func__);
580+ ggml_metal_library_t lib = ggml_metal_library_init_from_source (dev, src_tensor_f16, false );
581+ if (lib == NULL ) {
582+ GGML_LOG_WARN (" %s : - the tensor API is not supported in this environment - disabling\n " , __func__);
583+ dev->props .has_tensor = false ;
584+ } else {
585+ ggml_metal_library_free (lib);
586+ }
587+ }
588+
589+ // try to compile a dummy kernel to determine if the tensor API is supported for bfloat
553590 if (dev->props .has_tensor && dev->props .has_bfloat ) {
554591 const char * src_tensor_bf16 = " \n "
555592 " #include <metal_stdlib> \n "
@@ -559,24 +596,20 @@ ggml_metal_device_t ggml_metal_device_init(void) {
559596 " using namespace metal; \n "
560597 " using namespace mpp::tensor_ops; \n "
561598 " \n "
562- " kernel void bfloat_dummy_kernel ( \n "
599+ " kernel void dummy_kernel ( \n "
563600 " tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n "
564601 " tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n "
565602 " uint2 tgid [[threadgroup_position_in_grid]]) \n "
566603 " { \n "
567- " // Create slices for this threadgroup (no real computation performed). \n "
568604 " auto tA = A.slice(0, (int)tgid.y); \n "
569605 " auto tB = B.slice((int)tgid.x, 0); \n "
570606 " \n "
571- " // Minimal matmul descriptor: 8×8 tile with dynamic K dimension. \n "
572607 " matmul2d< \n "
573608 " matmul2d_descriptor(8, 8, dynamic_extent), \n "
574609 " execution_thread> mm; \n "
575610 " \n "
576- " // Obtain a cooperative destination tensor of bfloat type. \n "
577611 " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), bfloat>(); \n "
578612 " \n "
579- " // Silence “unused variable” warnings. \n "
580613 " (void)cT; \n "
581614 " }" ;
582615
0 commit comments