From 208dcc94f4433d6195d3c95d7eec81c6f0e30c6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20D=C3=B6ring?= Date: Thu, 30 Jan 2025 14:14:19 +0100 Subject: [PATCH] Fixed OptiX kernel hashing bug --- src/eval.cpp | 7 +++---- src/internal.h | 8 ++++++-- src/record_ts.cpp | 6 ++---- src/record_ts.h | 1 - 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/eval.cpp b/src/eval.cpp index 408873518..65414226f 100644 --- a/src/eval.cpp +++ b/src/eval.cpp @@ -500,10 +500,9 @@ Task *jitc_run(ThreadState *ts, ScheduledGroup group) { } #endif - KernelKey kernel_key((char *) buffer.get(), ts->device, flags); - auto it = state.kernel_cache.find( - kernel_key, - KernelHash::compute_hash(kernel_hash.high64, ts->device, flags)); + KernelKey kernel_key((char *) buffer.get(), kernel_hash.high64, ts->device, + flags); + auto it = state.kernel_cache.find(kernel_key); Kernel kernel; memset(&kernel, 0, sizeof(Kernel)); // quench uninitialized variable warning on MSVC diff --git a/src/internal.h b/src/internal.h index dc8188b08..1afdcc67d 100644 --- a/src/internal.h +++ b/src/internal.h @@ -746,8 +746,11 @@ struct KernelKey { char *str = nullptr; int device = 0; uint64_t flags = 0; + // The precomputed hash, used to look up the kernel in the kernel_cache. 
+ uint64_t high64 = 0; - KernelKey(char *str, int device, uint64_t flags) : str(str), device(device), flags(flags) { } + KernelKey(char *str, uint64_t high64, int device, uint64_t flags) + : str(str), device(device), flags(flags), high64(high64) {} bool operator==(const KernelKey &k) const { return strcmp(k.str, str) == 0 && device == k.device && flags == k.flags; @@ -757,9 +760,10 @@ struct KernelKey { /// Helper class to hash KernelKey instances struct KernelHash { size_t operator()(const KernelKey &k) const { - return compute_hash(hash_kernel(k.str).high64, k.device, k.flags); + return compute_hash(k.high64, k.device, k.flags); } +private: static size_t compute_hash(size_t kernel_hash, int device, uint64_t flags) { size_t hash = kernel_hash; hash_combine(hash, (size_t) flags + size_t(device + 1)); diff --git a/src/record_ts.cpp b/src/record_ts.cpp index a40979749..1eb7b9df6 100644 --- a/src/record_ts.cpp +++ b/src/record_ts.cpp @@ -709,14 +709,13 @@ void RecordThreadState::record_launch( op.dependency_range = std::pair(start, end); op.kernel.kernel = kernel; - op.kernel.precomputed_hash = - KernelHash::compute_hash(hash.high64, key->device, key->flags); op.kernel.key = (KernelKey *) std::malloc(sizeof(KernelKey)); size_t str_size = buffer.size() + 1; op.kernel.key->str = (char *) malloc_check(str_size); std::memcpy(op.kernel.key->str, key->str, str_size); op.kernel.key->device = key->device; op.kernel.key->flags = key->flags; + op.kernel.key->high64 = key->high64; op.size = size; @@ -1856,8 +1855,7 @@ bool Recording::check_kernel_cache() { Operation &op = operations[i]; if (op.type == OpType::KernelLaunch) { // Test if this kernel is still in the cache - auto it = state.kernel_cache.find(*op.kernel.key, - op.kernel.precomputed_hash); + auto it = state.kernel_cache.find(*op.kernel.key); if (it == state.kernel_cache.end()) return false; } diff --git a/src/record_ts.h b/src/record_ts.h index b5fd4b967..a73ac9b9d 100644 --- a/src/record_ts.h +++ b/src/record_ts.h @@ 
-59,7 +59,6 @@ struct Operation { /// Additional information of a kernel launch struct { KernelKey *key; - size_t precomputed_hash; Kernel kernel; XXH128_hash_t hash; } kernel;