Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <mutex>

#if defined(GGML_USE_HIP)
#define GGML_COMMON_DECL_HIP
Expand Down Expand Up @@ -1549,6 +1550,77 @@ struct ggml_cuda_pdl_config {
ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete;

};

// Reports whether `kernel` may be launched with PDL on the current device.
//
// The decision is based on the PTX version that cudaFuncGetAttributes reports for the
// kernel and is memoized per (device, kernel) pair, so the attribute query is issued
// only on the first call for a given pair.
//
// kernel:  address of the kernel function, as passed to cudaFuncGetAttributes.
// returns: true if the kernel's reported PTX version is >= 90, false otherwise.
static bool ggml_cuda_kernel_can_use_pdl(const void * kernel) {
    const int device = ggml_cuda_get_device();

    // Cache key: results are stored separately for every (device, kernel) combination.
    struct cache_key {
        int          device;
        const void * kernel;

        bool operator==(const cache_key & other) const { return device == other.device && kernel == other.kernel; }
    };

    struct cache_key_hash {
        // Xor-shift/multiply bit mixer used to spread the bits of the combined hash.
        // NOTE(review): the constants and shift amounts are not the ones of MurmurHash3's
        // fmix32/fmix64 finalizers; they look like those of Boost.ContainerHash's hash_mix
        // - confirm the provenance before citing it.
        static std::size_t hash_mix(std::size_t x) {
            // This is not a template, so both branches are compiled on every target;
            // the explicit casts keep the unused branch free of implicit narrowing.
            if constexpr (sizeof(std::size_t) >= 8) {
                // 64-bit path
                std::uint64_t       y = static_cast<std::uint64_t>(x);
                const std::uint64_t m = 0xe9846af9b1a615d;

                y ^= y >> 32;
                y *= m;
                y ^= y >> 32;
                y *= m;
                y ^= y >> 28;

                return static_cast<std::size_t>(y);
            } else {
                // 32-bit path
                std::uint32_t       y  = static_cast<std::uint32_t>(x);
                const std::uint32_t m1 = 0x21f0aaad;
                const std::uint32_t m2 = 0x735a2d97;

                y ^= y >> 16;
                y *= m1;
                y ^= y >> 15;
                y *= m2;
                y ^= y >> 15;

                return static_cast<std::size_t>(y);
            }
        }

        // Folds the device id and then the kernel address into one hash value,
        // mixing after each field so that both contribute to all bits.
        std::size_t operator()(const cache_key & key) const {
            std::size_t h = 0;
            h = hash_mix(h + std::hash<int>{}(key.device));
            h = hash_mix(h + std::hash<const void *>{}(key.kernel));
            return h;
        }
    };

    static std::mutex cache_mutex;
    static std::unordered_map<cache_key, bool, cache_key_hash> cache;

    const cache_key key = { device, kernel };

    // The lock is held for the lookup, the attribute query and the insertion, so callers
    // that share this cache are serialized and a missing entry is filled in exactly once.
    std::lock_guard<std::mutex> lock(cache_mutex);
    const auto it = cache.find(key);
    if (it != cache.end()) {
        return it->second;
    }

    cudaFuncAttributes attr = {};
    CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel));

    // PDL device-side primitives are emitted only for PTX versions >= 90.
    // We have to guard on a loaded kernel's PTX version so a kernel forward-JIT'ed
    // from pre-Hopper PTX to a Hopper-or-newer GPU does not opt into PDL.
    const bool can_use_pdl = attr.ptxVersion >= 90;
    cache.emplace(key, can_use_pdl);
    return can_use_pdl;
}

#endif //defined(GGML_CUDA_USE_PDL)


Expand All @@ -1561,8 +1633,7 @@ static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_ke
return env == nullptr || std::atoi(env) != 0;
}();

const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (env_pdl_enabled && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_HOPPER) {
if (env_pdl_enabled && ggml_cuda_kernel_can_use_pdl(reinterpret_cast<const void *>(kernel))) {
auto pdl_cfg = ggml_cuda_pdl_config(launch_params);

CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)... ));
Expand Down
Loading