Skip to content

Commit 9af8394

Browse files
committed
cont : handle API incompatibilities
1 parent e6aa68a commit 9af8394

File tree

1 file changed

+39
-6
lines changed

1 file changed

+39
-6
lines changed

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,44 @@ ggml_metal_device_t ggml_metal_device_init(void) {
549549
dev->props.has_tensor = false;
550550
}
551551

552-
// try to compile a dummy tensor kernel to determine if the tensor API is supported for bfloat
552+
// double-check that the tensor API compiles
553+
if (dev->props.has_tensor) {
554+
const char * src_tensor_f16 = "\n"
555+
"#include <metal_stdlib> \n"
556+
"#include <metal_tensor> \n"
557+
"#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
558+
" \n"
559+
"using namespace metal; \n"
560+
"using namespace mpp::tensor_ops; \n"
561+
" \n"
562+
"kernel void dummy_kernel( \n"
563+
" tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n"
564+
" tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n"
565+
" uint2 tgid [[threadgroup_position_in_grid]]) \n"
566+
"{ \n"
567+
" auto tA = A.slice(0, (int)tgid.y); \n"
568+
" auto tB = B.slice((int)tgid.x, 0); \n"
569+
" \n"
570+
" matmul2d< \n"
571+
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
572+
" execution_thread> mm; \n"
573+
" \n"
574+
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), half>(); \n"
575+
" \n"
576+
" (void)cT; \n"
577+
"}";
578+
579+
GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__);
580+
ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_f16, false);
581+
if (lib == NULL) {
582+
GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
583+
dev->props.has_tensor = false;
584+
} else {
585+
ggml_metal_library_free(lib);
586+
}
587+
}
588+
589+
// try to compile a dummy kernel to determine if the tensor API is supported for bfloat
553590
if (dev->props.has_tensor && dev->props.has_bfloat) {
554591
const char * src_tensor_bf16 = "\n"
555592
"#include <metal_stdlib> \n"
@@ -559,24 +596,20 @@ ggml_metal_device_t ggml_metal_device_init(void) {
559596
"using namespace metal; \n"
560597
"using namespace mpp::tensor_ops; \n"
561598
" \n"
562-
"kernel void bfloat_dummy_kernel( \n"
599+
"kernel void dummy_kernel( \n"
563600
" tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
564601
" tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
565602
" uint2 tgid [[threadgroup_position_in_grid]]) \n"
566603
"{ \n"
567-
" // Create slices for this threadgroup (no real computation performed). \n"
568604
" auto tA = A.slice(0, (int)tgid.y); \n"
569605
" auto tB = B.slice((int)tgid.x, 0); \n"
570606
" \n"
571-
" // Minimal matmul descriptor: 8×8 tile with dynamic K dimension. \n"
572607
" matmul2d< \n"
573608
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
574609
" execution_thread> mm; \n"
575610
" \n"
576-
" // Obtain a cooperative destination tensor of bfloat type. \n"
577611
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), bfloat>(); \n"
578612
" \n"
579-
" // Silence “unused variable” warnings. \n"
580613
" (void)cT; \n"
581614
"}";
582615

0 commit comments

Comments (0)