Skip to content

Commit

Permalink
Fixed OptiX kernel hashing bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
DoeringChristian committed Jan 30, 2025
1 parent 0025e68 commit 208dcc9
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
7 changes: 3 additions & 4 deletions src/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,10 +500,9 @@ Task *jitc_run(ThreadState *ts, ScheduledGroup group) {
}
#endif

KernelKey kernel_key((char *) buffer.get(), ts->device, flags);
auto it = state.kernel_cache.find(
kernel_key,
KernelHash::compute_hash(kernel_hash.high64, ts->device, flags));
KernelKey kernel_key((char *) buffer.get(), kernel_hash.high64, ts->device,
flags);
auto it = state.kernel_cache.find(kernel_key);
Kernel kernel;
memset(&kernel, 0, sizeof(Kernel)); // quench uninitialized variable warning on MSVC

Expand Down
8 changes: 6 additions & 2 deletions src/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,11 @@ struct KernelKey {
char *str = nullptr;
int device = 0;
uint64_t flags = 0;
// Precomputed kernel hash value; KernelHash combines it with device and flags to look up the kernel in the kernel_cache.
uint64_t high64 = 0;

KernelKey(char *str, int device, uint64_t flags) : str(str), device(device), flags(flags) { }
KernelKey(char *str, uint64_t high64, int device, uint64_t flags)
: str(str), device(device), flags(flags), high64(high64) {}

bool operator==(const KernelKey &k) const {
return strcmp(k.str, str) == 0 && device == k.device && flags == k.flags;
Expand All @@ -757,9 +760,10 @@ struct KernelKey {
/// Helper class to hash KernelKey instances
struct KernelHash {
size_t operator()(const KernelKey &k) const {
return compute_hash(hash_kernel(k.str).high64, k.device, k.flags);
return compute_hash(k.high64, k.device, k.flags);
}

private:
static size_t compute_hash(size_t kernel_hash, int device, uint64_t flags) {
size_t hash = kernel_hash;
hash_combine(hash, (size_t) flags + size_t(device + 1));
Expand Down
6 changes: 2 additions & 4 deletions src/record_ts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,14 +709,13 @@ void RecordThreadState::record_launch(
op.dependency_range = std::pair(start, end);

op.kernel.kernel = kernel;
op.kernel.precomputed_hash =
KernelHash::compute_hash(hash.high64, key->device, key->flags);
op.kernel.key = (KernelKey *) std::malloc(sizeof(KernelKey));
size_t str_size = buffer.size() + 1;
op.kernel.key->str = (char *) malloc_check(str_size);
std::memcpy(op.kernel.key->str, key->str, str_size);
op.kernel.key->device = key->device;
op.kernel.key->flags = key->flags;
op.kernel.key->high64 = key->high64;

op.size = size;

Expand Down Expand Up @@ -1856,8 +1855,7 @@ bool Recording::check_kernel_cache() {
Operation &op = operations[i];
if (op.type == OpType::KernelLaunch) {
// Test if this kernel is still in the cache
auto it = state.kernel_cache.find(*op.kernel.key,
op.kernel.precomputed_hash);
auto it = state.kernel_cache.find(*op.kernel.key);
if (it == state.kernel_cache.end())
return false;
}
Expand Down
1 change: 0 additions & 1 deletion src/record_ts.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ struct Operation {
/// Additional information of a kernel launch
struct {
KernelKey *key;
size_t precomputed_hash;
Kernel kernel;
XXH128_hash_t hash;
} kernel;
Expand Down

0 comments on commit 208dcc9

Please sign in to comment.