Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ typedef struct ggml_metal * ggml_metal_t;
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev);
void ggml_metal_free(ggml_metal_t ctx);

// returns the name stored in the context (filled from the device props name during ggml_metal_init)
// the returned pointer refers to storage inside ctx
const char * ggml_metal_get_name(ggml_metal_t ctx);

void ggml_metal_synchronize(ggml_metal_t ctx);

void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
Expand Down
10 changes: 9 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-context.m
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
};

struct ggml_metal {
char name[128];

ggml_metal_device_t dev;
ggml_metal_library_t lib;

Expand Down Expand Up @@ -117,7 +119,9 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
}
}

//const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);

snprintf(res->name, sizeof(res->name), "%s", props_dev->name);

res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

Expand Down Expand Up @@ -209,6 +213,10 @@ void ggml_metal_free(ggml_metal_t ctx) {
free(ctx);
}

// Accessor for the context name.
// The result points at the name buffer embedded in ctx, so no copy is made.
const char * ggml_metal_get_name(ggml_metal_t ctx) {
    const char * ctx_name = ctx->name;

    return ctx_name;
}

void ggml_metal_synchronize(ggml_metal_t ctx) {
// wait for any backend operations to finish
if (ctx->cmd_buf_last) {
Expand Down
8 changes: 5 additions & 3 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ struct ggml_metal_device_deleter {

typedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr;

// Return the device object for the given device index.
//
// Devices are initialized on first request and kept in a function-local static
// list of ggml_metal_device_ptr, so they are released by the unique_ptr deleter
// when the program exits. Requesting an index that was already initialized
// returns the existing object instead of initializing a duplicate device.
//
// NOTE(review): the list is not guarded by a lock here - first-time requests are
// assumed to be serialized by the caller; confirm against the call sites.
ggml_metal_device_t ggml_metal_device_get(int device) {
    static std::vector<ggml_metal_device_ptr> devs;

    // reuse a previously initialized device with the same index
    // (ggml_metal_device_init records the index in props.device)
    for (const auto & cur : devs) {
        if (cur && ggml_metal_device_get_props(cur.get())->device == device) {
            return cur.get();
        }
    }

    devs.emplace_back(ggml_metal_device_init(device));

    return devs.back().get();
}

struct ggml_metal_pipelines {
Expand Down
7 changes: 4 additions & 3 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets);
//

struct ggml_metal_device_props {
int device;
char name[128];
char desc[128];

size_t max_buffer_size;
size_t max_working_set_size;
Expand All @@ -224,11 +226,10 @@ struct ggml_metal_device_props {
int op_offload_min_batch_size;
};

ggml_metal_device_t ggml_metal_device_init(void);
ggml_metal_device_t ggml_metal_device_init(int device);
void ggml_metal_device_free(ggml_metal_device_t dev);

// return a device object initialized for the given device index
// the object is owned by an internal static container and is automatically destroyed when the program exits
ggml_metal_device_t ggml_metal_device_get(void);
ggml_metal_device_t ggml_metal_device_get(int device);

void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id<MTLDevice>
void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id<MTLCommandQueue>
Expand Down
18 changes: 11 additions & 7 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;

// virtual address for GPU memory allocations
static atomic_uintptr_t g_addr_device = 0x000000400ULL;

#if !GGML_METAL_EMBED_LIBRARY
// Here to assist with NSBundle Path Hack
@interface GGMLMetalClass : NSObject
Expand Down Expand Up @@ -523,6 +520,9 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) {
ggml_metal_library_t library;

struct ggml_metal_device_props props;

// virtual address for GPU memory allocations
atomic_uintptr_t addr_virt;
};

//
Expand Down Expand Up @@ -618,7 +618,7 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) {
free(rsets);
}

ggml_metal_device_t ggml_metal_device_init(void) {
ggml_metal_device_t ggml_metal_device_init(int device) {
ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device));

assert(dev != NULL);
Expand All @@ -632,6 +632,9 @@ ggml_metal_device_t ggml_metal_device_init(void) {
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
}

dev->addr_virt = 0x000000400ULL;

dev->props.device = device;
dev->props.has_simdgroup_reduction = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];

Expand Down Expand Up @@ -788,7 +791,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize;
dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength;

strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1);
snprintf(dev->props.name, sizeof(dev->props.name), "%s%d", "MTL", device);
snprintf(dev->props.desc, sizeof(dev->props.desc), "%s", [[dev->mtl_device name] UTF8String]);

dev->library = ggml_metal_library_init(dev);
if (!dev->library) {
Expand Down Expand Up @@ -1340,8 +1344,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
res->all_data = ggml_metal_host_malloc(size_aligned);
res->is_shared = true;
} else {
// use virtual address from g_addr_device counter
res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
// use virtual address
res->all_data = (void *) atomic_fetch_add_explicit(&dev->addr_virt, size_aligned, memory_order_relaxed);
res->is_shared = false;
}
res->all_size = size_aligned;
Expand Down
Loading
Loading