Skip to content

Commit 733b3f5

Browse files
authored
Fix cong implementation to be properly random and not just cycling. (JuliaLang#55509)
This was found by @IanButterworth. It unfortunately has a small performance regression due to actually using all the rng bits
1 parent 6477530 commit 733b3f5

File tree

4 files changed

+28
-12
lines changed

4 files changed

+28
-12
lines changed

src/gc-stock.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ STATIC_INLINE int gc_is_concurrent_collector_thread(int tid) JL_NOTSAFEPOINT
446446
STATIC_INLINE int gc_random_parallel_collector_thread_id(jl_ptls_t ptls) JL_NOTSAFEPOINT
447447
{
448448
assert(jl_n_markthreads > 0);
449-
int v = gc_first_tid + (int)cong(jl_n_markthreads - 1, &ptls->rngseed);
449+
int v = gc_first_tid + (int)cong(jl_n_markthreads, &ptls->rngseed); // cong is [0, n)
450450
assert(v >= gc_first_tid && v <= gc_last_parallel_collector_thread_id());
451451
return v;
452452
}

src/julia_internal.h

+25-9
Original file line numberDiff line numberDiff line change
@@ -1307,20 +1307,36 @@ JL_DLLEXPORT size_t jl_maxrss(void);
13071307
// congruential random number generator
13081308
// for a small amount of thread-local randomness
13091309

1310-
STATIC_INLINE uint64_t cong(uint64_t max, uint64_t *seed) JL_NOTSAFEPOINT
1310+
//TODO: utilize https://github.com/openssl/openssl/blob/master/crypto/rand/rand_uniform.c#L13-L99
1311+
// for better performance, it does however require making users expect a 32bit random number.
1312+
1313+
STATIC_INLINE uint64_t cong(uint64_t max, uint64_t *seed) JL_NOTSAFEPOINT // Open interval [0, max)
13111314
{
1312-
if (max == 0)
1315+
if (max < 2)
13131316
return 0;
13141317
uint64_t mask = ~(uint64_t)0;
1315-
--max;
1316-
mask >>= __builtin_clzll(max|1);
1317-
uint64_t x;
1318+
int zeros = __builtin_clzll(max);
1319+
int bits = CHAR_BIT * sizeof(uint64_t) - zeros;
1320+
mask = mask >> zeros;
13181321
do {
1319-
*seed = 69069 * (*seed) + 362437;
1320-
x = *seed & mask;
1321-
} while (x > max);
1322-
return x;
1322+
uint64_t value = 69069 * (*seed) + 362437;
1323+
*seed = value;
1324+
uint64_t x = value & mask;
1325+
if (x < max) {
1326+
return x;
1327+
}
1328+
int bits_left = zeros;
1329+
while (bits_left >= bits) {
1330+
value >>= bits;
1331+
x = value & mask;
1332+
if (x < max) {
1333+
return x;
1334+
}
1335+
bits_left -= bits;
1336+
}
1337+
} while (1);
13231338
}
1339+
13241340
JL_DLLEXPORT uint64_t jl_rand(void) JL_NOTSAFEPOINT;
13251341
JL_DLLEXPORT void jl_srand(uint64_t) JL_NOTSAFEPOINT;
13261342
JL_DLLEXPORT void jl_init_rand(void);

src/scheduler.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ extern int jl_gc_mark_queue_obj_explicit(jl_gc_mark_cache_t *gc_cache,
8787
// parallel task runtime
8888
// ---
8989

90-
JL_DLLEXPORT uint32_t jl_rand_ptls(uint32_t max)
90+
JL_DLLEXPORT uint32_t jl_rand_ptls(uint32_t max) // [0, n)
9191
{
9292
jl_ptls_t ptls = jl_current_task->ptls;
9393
return cong(max, &ptls->rngseed);

src/signal-handling.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ static void jl_shuffle_int_array_inplace(int *carray, int size, uint64_t *seed)
155155
// The "modern Fisher–Yates shuffle" - O(n) algorithm
156156
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
157157
for (int i = size; i-- > 1; ) {
158-
size_t j = cong(i, seed);
158+
size_t j = cong(i + 1, seed); // cong is an open interval so we add 1
159159
uint64_t tmp = carray[j];
160160
carray[j] = carray[i];
161161
carray[i] = tmp;

0 commit comments

Comments
 (0)