 #include <cstring>
 #include <iostream>
 #include <memory>
+#include <mutex>
 #include <stdexcept>
 #include <string>
 #include <unordered_map>
@@ -1323,17 +1324,7 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
     ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
 }
 
-static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
-    switch (op->type) {
-        case GGML_TYPE_F16:
-        case GGML_TYPE_F32:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-            break;
-        default:
-            return false;
-    }
-
+static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -1410,6 +1401,8 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
             ;
     }
     return false;
+
+    GGML_UNUSED(dev);
 }
 
 static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
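
The op whitelist now lives behind the device vtable instead of a free function. A minimal usage sketch (not part of the diff), assuming ggml_backend_kompute_reg() is declared in ggml-kompute.h and that ggml-backend.h provides a generic ggml_backend_dev_supports_op wrapper dispatching to the .supports_op entry registered below:

#include "ggml-backend.h"
#include "ggml-kompute.h"

// Sketch: ask the first Kompute device whether it can execute a node,
// instead of calling the removed ggml_vk_supports_op directly.
static bool node_runs_on_kompute(const struct ggml_tensor * node) {
    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), 0);
    return ggml_backend_dev_supports_op(dev, node); // assumed generic wrapper
}
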
@@ -1458,10 +1451,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 
         any_commands_recorded = true;
 
+/* Do we still need this?
         if (!ggml_vk_supports_op(dst)) {
             fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
             GGML_ABORT("unsupported op");
         }
+*/
 
         const int32_t ne00 = src0 ? src0->ne[0] : 0;
         const int32_t ne01 = src0 ? src0->ne[1] : 0;
@@ -1921,7 +1916,7 @@ ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
         for (const auto & dev : devices) {
             vec.push_back({
                 /* .iface   = */ ggml_backend_kompute_buffer_type_interface,
-                /* .device  = */ nullptr,
+                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), 0),
                 /* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
             });
         }
@@ -1964,16 +1959,6 @@ static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, st
     return GGML_STATUS_SUCCESS;
 }
 
-static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    GGML_UNUSED(backend);
-    return ggml_vk_supports_op(op);
-}
-
-static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(backend);
-    return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
-}
-
 static struct ggml_backend_i kompute_backend_i = {
     /* .get_name                = */ ggml_backend_kompute_name,
     /* .free                    = */ ggml_backend_kompute_free,
@@ -1987,8 +1972,8 @@ static struct ggml_backend_i kompute_backend_i = {
19871972 /* .graph_plan_update = */ NULL ,
19881973 /* .graph_plan_compute = */ NULL ,
19891974 /* .graph_compute = */ ggml_backend_kompute_graph_compute,
1990- /* .supports_op = */ ggml_backend_kompute_supports_op ,
1991- /* .supports_buft = */ ggml_backend_kompute_supports_buft ,
1975+ /* .supports_op = */ NULL ,
1976+ /* .supports_buft = */ NULL ,
19921977 /* .offload_op = */ NULL ,
19931978 /* .event_record = */ NULL ,
19941979 /* .event_wait = */ NULL ,
@@ -2006,7 +1991,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
     ggml_backend_t kompute_backend = new ggml_backend {
         /* .guid      = */ ggml_backend_kompute_guid(),
         /* .interface = */ kompute_backend_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), 0),
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), 0),
         /* .context   = */ s_kompute_context,
     };
@@ -2016,3 +2001,203 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
 bool ggml_backend_is_kompute(ggml_backend_t backend) {
     return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
 }
+
+int ggml_backend_kompute_get_device_count() {
+    auto devices = ggml_vk_available_devices_internal(0);
+    return devices.size();
+}
+
+void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
+    std::vector<vk::PhysicalDevice> physical_devices;
+    try {
+        physical_devices = komputeManager()->listDevices();
+    } catch (vk::SystemError & err) {
+        std::cerr << __func__ << ": Vulkan exception: " << err.what() << "\n";
+        GGML_ABORT("");
+    }
+
+    GGML_ASSERT(device < physical_devices.size());
+
+    const auto & physical_device = physical_devices[device];
+    VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
+
+    auto devices = ggml_vk_available_devices_internal(0);
+    snprintf(description, description_size, "%s", dev_props.deviceName);
+}
+
+void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
+    std::vector<vk::PhysicalDevice> physical_devices;
+    try {
+        physical_devices = komputeManager()->listDevices();
+    } catch (vk::SystemError & err) {
+        std::cerr << __func__ << ": Vulkan exception: " << err.what() << "\n";
+        GGML_ABORT("");
+    }
+
+    GGML_ASSERT(device < physical_devices.size());
+
+    const auto & physical_device = physical_devices[device];
+
+    vk::PhysicalDeviceMemoryProperties memprops = physical_device.getMemoryProperties();
+
+    for (const vk::MemoryHeap & heap : memprops.memoryHeaps) {
+        if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+            *total = heap.size;
+            *free = heap.size;
+            break;
+        }
+    }
+}
+
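Core Vulkan has no query for free device memory, so the loop above reports the first device-local heap's size as both *total and *free, which makes *free an upper bound rather than a live figure. Where a real budget matters, VK_EXT_memory_budget can supply one. A sketch (not part of the diff), assuming the device supports that extension and the chained query is available via Vulkan 1.1 or VK_KHR_get_physical_device_properties2:

// Hypothetical helper, not in this patch: read a memory budget via
// VK_EXT_memory_budget. heapBudget - heapUsage approximates the memory
// still available to this process.
static void ggml_vk_get_device_budget(const vk::PhysicalDevice & pd, size_t * free, size_t * total) {
    auto chain = pd.getMemoryProperties2<vk::PhysicalDeviceMemoryProperties2,
                                         vk::PhysicalDeviceMemoryBudgetPropertiesEXT>();
    const auto & props  = chain.get<vk::PhysicalDeviceMemoryProperties2>().memoryProperties;
    const auto & budget = chain.get<vk::PhysicalDeviceMemoryBudgetPropertiesEXT>();
    for (uint32_t i = 0; i < props.memoryHeapCount; i++) {
        if (props.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
            *total = props.memoryHeaps[i].size;
            *free  = budget.heapBudget[i] - budget.heapUsage[i];
            break;
        }
    }
}
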
+//////////////////////////
+
+struct ggml_backend_kompute_device_context {
+    int device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    ggml_backend_kompute_get_device_memory(ctx->device, free, total);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    return ggml_backend_kompute_buffer_type(ctx->device);
+}
+
+static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
+        return false;
+    }
+
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
+
+    return buft_ctx->device == ctx->device;
+}
+
+// TODO
+/**
+static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    GGML_ABORT("Unimplemented");
+    return ggml_backend_kompute_host_buffer_type();
+}
+*/
+
+static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+}
+
+static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_kompute_device_get_name(dev);
+    props->description = ggml_backend_kompute_device_get_description(dev);
+    props->type        = ggml_backend_kompute_device_get_type(dev);
+    ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* async       = */ false,
+        /* host_buffer = */ false,
+        /* events      = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
+    GGML_UNUSED(params);
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    return ggml_backend_kompute_init(ctx->device);
+}
+
+static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    const int min_batch_size = 32;
+
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
+    /* .get_name             = */ ggml_backend_kompute_device_get_name,
+    /* .get_description      = */ ggml_backend_kompute_device_get_description,
+    /* .get_memory           = */ ggml_backend_kompute_device_get_memory,
+    /* .get_type             = */ ggml_backend_kompute_device_get_type,
+    /* .get_props            = */ ggml_backend_kompute_device_get_props,
+    /* .init_backend         = */ ggml_backend_kompute_device_init,
+    /* .get_buffer_type      = */ ggml_backend_kompute_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ NULL,
+    /* .supports_op          = */ ggml_backend_kompute_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_kompute_device_supports_buft,
+    /* .offload_op           = */ ggml_backend_kompute_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
+    GGML_UNUSED(reg);
+    return "Kompute";
+}
+
+static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
+    GGML_UNUSED(reg);
+    return ggml_backend_kompute_get_device_count();
+}
+
+static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+    static std::vector<ggml_backend_dev_t> devices;
+
+    static bool initialized = false;
+
+    {
+        static std::mutex mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
+                ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
+                char desc[256];
+                ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
+                ctx->device = i;
+                ctx->name = "Kompute" + std::to_string(i);
+                ctx->description = desc;
+                devices.push_back(new ggml_backend_device {
+                    /* .iface   = */ ggml_backend_kompute_device_i,
+                    /* .reg     = */ reg,
+                    /* .context = */ ctx,
+                });
+            }
+            initialized = true;
+        }
+    }
+
+    GGML_ASSERT(device < devices.size());
+    return devices[device];
+}
+
+static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
+    /* .get_name         = */ ggml_backend_kompute_reg_get_name,
+    /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_kompute_reg_get_device,
+    /* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_kompute_reg() {
+    static ggml_backend_reg reg = {
+        /* .iface   = */ ggml_backend_kompute_reg_i,
+        /* .context = */ nullptr,
+    };
+
+    return &reg;
+}
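
Taken together, the new registration path can be exercised end to end. A sketch (not part of the diff), assuming ggml-backend.h provides the generic registry/device wrappers ggml_backend_reg_dev_count, ggml_backend_dev_name/description/memory, and ggml_backend_dev_init that forward to the reg/device vtables defined above:

#include <cstdio>
#include "ggml-backend.h"
#include "ggml-kompute.h"

int main() {
    ggml_backend_reg_t reg = ggml_backend_kompute_reg();
    for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
        size_t free = 0, total = 0;
        ggml_backend_dev_memory(dev, &free, &total);
        // prints e.g. "Kompute0: <device name>, 8192/8192 MiB free"
        printf("%s: %s, %zu/%zu MiB free\n",
               ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
               free >> 20, total >> 20);
    }
    // create a backend on the first device; the params string is unused here
    ggml_backend_t backend = ggml_backend_dev_init(ggml_backend_reg_dev_get(reg, 0), nullptr);
    // ... allocate buffers and run graphs with the usual ggml-backend API ...
    ggml_backend_free(backend);
}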