Skip to content

Commit

Permalink
std.rand: Refactor Random interface
Browse files Browse the repository at this point in the history
These changes have been made to resolve issue ziglang#10037. The `Random`
interface was implemented in such a way that causes significant slowdown
when calling the `fill` function of the rng used.

The `Random` interface is no longer stored in a field of the rng, and is
instead returned by the child function `random()` of the rng. This
avoids the performance issues caused by the interface.
  • Loading branch information
ominitay committed Oct 27, 2021
1 parent 3af9731 commit 1a54a15
Show file tree
Hide file tree
Showing 18 changed files with 291 additions and 244 deletions.
5 changes: 3 additions & 2 deletions lib/std/atomic/queue.zig
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,11 @@ test "std.atomic.Queue" {

fn startPuts(ctx: *Context) u8 {
var put_count: usize = puts_per_thread;
var r = std.rand.DefaultPrng.init(0xdeadbeef);
var prng = std.rand.DefaultPrng.init(0xdeadbeef);
const random = prng.random();
while (put_count != 0) : (put_count -= 1) {
std.time.sleep(1); // let the os scheduler be our fuzz
const x = @bitCast(i32, r.random.int(u32));
const x = @bitCast(i32, random.int(u32));
const node = ctx.allocator.create(Queue(i32).Node) catch unreachable;
node.* = .{
.prev = undefined,
Expand Down
5 changes: 3 additions & 2 deletions lib/std/atomic/stack.zig
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,11 @@ test "std.atomic.stack" {

fn startPuts(ctx: *Context) u8 {
var put_count: usize = puts_per_thread;
var r = std.rand.DefaultPrng.init(0xdeadbeef);
var prng = std.rand.DefaultPrng.init(0xdeadbeef);
const random = prng.random();
while (put_count != 0) : (put_count -= 1) {
std.time.sleep(1); // let the os scheduler be our fuzz
const x = @bitCast(i32, r.random.int(u32));
const x = @bitCast(i32, random.int(u32));
const node = ctx.allocator.create(Stack(i32).Node) catch unreachable;
node.* = Stack(i32).Node{
.next = undefined,
Expand Down
21 changes: 11 additions & 10 deletions lib/std/crypto/benchmark.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const KiB = 1024;
const MiB = 1024 * KiB;

var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();

const Crypto = struct {
ty: type,
Expand All @@ -34,7 +35,7 @@ pub fn benchmarkHash(comptime Hash: anytype, comptime bytes: comptime_int) !u64
var h = Hash.init(.{});

var block: [Hash.digest_length]u8 = undefined;
prng.random.bytes(block[0..]);
random.bytes(block[0..]);

var offset: usize = 0;
var timer = try Timer.start();
Expand Down Expand Up @@ -66,11 +67,11 @@ const macs = [_]Crypto{

pub fn benchmarkMac(comptime Mac: anytype, comptime bytes: comptime_int) !u64 {
var in: [512 * KiB]u8 = undefined;
prng.random.bytes(in[0..]);
random.bytes(in[0..]);

const key_length = if (Mac.key_length == 0) 32 else Mac.key_length;
var key: [key_length]u8 = undefined;
prng.random.bytes(key[0..]);
random.bytes(key[0..]);

var mac: [Mac.mac_length]u8 = undefined;
var offset: usize = 0;
Expand All @@ -94,10 +95,10 @@ pub fn benchmarkKeyExchange(comptime DhKeyExchange: anytype, comptime exchange_c
std.debug.assert(DhKeyExchange.shared_length >= DhKeyExchange.secret_length);

var secret: [DhKeyExchange.shared_length]u8 = undefined;
prng.random.bytes(secret[0..]);
random.bytes(secret[0..]);

var public: [DhKeyExchange.shared_length]u8 = undefined;
prng.random.bytes(public[0..]);
random.bytes(public[0..]);

var timer = try Timer.start();
const start = timer.lap();
Expand Down Expand Up @@ -211,15 +212,15 @@ const aeads = [_]Crypto{

pub fn benchmarkAead(comptime Aead: anytype, comptime bytes: comptime_int) !u64 {
var in: [512 * KiB]u8 = undefined;
prng.random.bytes(in[0..]);
random.bytes(in[0..]);

var tag: [Aead.tag_length]u8 = undefined;

var key: [Aead.key_length]u8 = undefined;
prng.random.bytes(key[0..]);
random.bytes(key[0..]);

var nonce: [Aead.nonce_length]u8 = undefined;
prng.random.bytes(nonce[0..]);
random.bytes(nonce[0..]);

var offset: usize = 0;
var timer = try Timer.start();
Expand All @@ -244,7 +245,7 @@ const aes = [_]Crypto{

pub fn benchmarkAes(comptime Aes: anytype, comptime count: comptime_int) !u64 {
var key: [Aes.key_bits / 8]u8 = undefined;
prng.random.bytes(key[0..]);
random.bytes(key[0..]);
const ctx = Aes.initEnc(key);

var in = [_]u8{0} ** 16;
Expand Down Expand Up @@ -273,7 +274,7 @@ const aes8 = [_]Crypto{

pub fn benchmarkAes8(comptime Aes: anytype, comptime count: comptime_int) !u64 {
var key: [Aes.key_bits / 8]u8 = undefined;
prng.random.bytes(key[0..]);
random.bytes(key[0..]);
const ctx = Aes.initEnc(key);

var in = [_]u8{0} ** (8 * 16);
Expand Down
7 changes: 5 additions & 2 deletions lib/std/crypto/tlcsprng.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ const os = std.os;

/// We use this as a layer of indirection because global const pointers cannot
/// point to thread-local variables.
pub var interface = std.rand.Random{ .fillFn = tlsCsprngFill };
pub const interface = std.rand.Random{
.ptr = undefined,
.fillFn = tlsCsprngFill,
};

const os_has_fork = switch (builtin.os.tag) {
.dragonfly,
Expand Down Expand Up @@ -55,7 +58,7 @@ var install_atfork_handler = std.once(struct {

threadlocal var wipe_mem: []align(mem.page_size) u8 = &[_]u8{};

fn tlsCsprngFill(_: *const std.rand.Random, buffer: []u8) void {
fn tlsCsprngFill(_: *c_void, buffer: []u8) void {
if (builtin.link_libc and @hasDecl(std.c, "arc4random_buf")) {
// arc4random is already a thread-local CSPRNG.
return std.c.arc4random_buf(buffer.ptr, buffer.len);
Expand Down
5 changes: 3 additions & 2 deletions lib/std/hash/benchmark.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const MiB = 1024 * KiB;
const GiB = 1024 * MiB;

var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();

const Hash = struct {
ty: type,
Expand Down Expand Up @@ -88,7 +89,7 @@ pub fn benchmarkHash(comptime H: anytype, bytes: usize) !Result {
};

var block: [block_size]u8 = undefined;
prng.random.bytes(block[0..]);
random.bytes(block[0..]);

var offset: usize = 0;
var timer = try Timer.start();
Expand All @@ -110,7 +111,7 @@ pub fn benchmarkHash(comptime H: anytype, bytes: usize) !Result {
pub fn benchmarkHashSmallKeys(comptime H: anytype, key_size: usize, bytes: usize) !Result {
const key_count = bytes / key_size;
var block: [block_size]u8 = undefined;
prng.random.bytes(block[0..]);
random.bytes(block[0..]);

var i: usize = 0;
var timer = try Timer.start();
Expand Down
12 changes: 7 additions & 5 deletions lib/std/hash_map.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1795,10 +1795,11 @@ test "std.hash_map put and remove loop in random order" {
while (i < size) : (i += 1) {
try keys.append(i);
}
var rng = std.rand.DefaultPrng.init(0);
var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();

while (i < iterations) : (i += 1) {
std.rand.Random.shuffle(&rng.random, u32, keys.items);
random.shuffle(u32, keys.items);

for (keys.items) |key| {
try map.put(key, key);
Expand Down Expand Up @@ -1826,14 +1827,15 @@ test "std.hash_map remove one million elements in random order" {
keys.append(i) catch unreachable;
}

var rng = std.rand.DefaultPrng.init(0);
std.rand.Random.shuffle(&rng.random, u32, keys.items);
var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();
random.shuffle(u32, keys.items);

for (keys.items) |key| {
map.put(key, key) catch unreachable;
}

std.rand.Random.shuffle(&rng.random, u32, keys.items);
random.shuffle(u32, keys.items);
i = 0;
while (i < n) : (i += 1) {
const key = keys.items[i];
Expand Down
3 changes: 2 additions & 1 deletion lib/std/io/test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ test "write a file, read it, then delete it" {

var data: [1024]u8 = undefined;
var prng = DefaultPrng.init(1234);
prng.random.bytes(data[0..]);
const random = prng.random();
random.bytes(data[0..]);
const tmp_file_name = "temp_test_file.txt";
{
var file = try tmp.dir.createFile(tmp_file_name, .{});
Expand Down
3 changes: 2 additions & 1 deletion lib/std/math/big/rational.zig
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,10 @@ test "big.rational set/to Float round-trip" {
var a = try Rational.init(testing.allocator);
defer a.deinit();
var prng = std.rand.DefaultPrng.init(0x5EED);
const random = prng.random();
var i: usize = 0;
while (i < 512) : (i += 1) {
const r = prng.random.float(f64);
const r = random.float(f64);
try a.setFloat(f64, r);
try testing.expect((try a.toFloat(f64)) == r);
}
Expand Down
17 changes: 10 additions & 7 deletions lib/std/priority_dequeue.zig
Original file line number Diff line number Diff line change
Expand Up @@ -850,17 +850,18 @@ test "std.PriorityDequeue: shrinkAndFree" {

test "std.PriorityDequeue: fuzz testing min" {
var prng = std.rand.DefaultPrng.init(0x12345678);
const random = prng.random();

const test_case_count = 100;
const queue_size = 1_000;

var i: usize = 0;
while (i < test_case_count) : (i += 1) {
try fuzzTestMin(&prng.random, queue_size);
try fuzzTestMin(random, queue_size);
}
}

fn fuzzTestMin(rng: *std.rand.Random, comptime queue_size: usize) !void {
fn fuzzTestMin(rng: std.rand.Random, comptime queue_size: usize) !void {
const allocator = testing.allocator;
const items = try generateRandomSlice(allocator, rng, queue_size);

Expand All @@ -878,17 +879,18 @@ fn fuzzTestMin(rng: *std.rand.Random, comptime queue_size: usize) !void {

test "std.PriorityDequeue: fuzz testing max" {
var prng = std.rand.DefaultPrng.init(0x87654321);
const random = prng.random();

const test_case_count = 100;
const queue_size = 1_000;

var i: usize = 0;
while (i < test_case_count) : (i += 1) {
try fuzzTestMax(&prng.random, queue_size);
try fuzzTestMax(random, queue_size);
}
}

fn fuzzTestMax(rng: *std.rand.Random, queue_size: usize) !void {
fn fuzzTestMax(rng: std.rand.Random, queue_size: usize) !void {
const allocator = testing.allocator;
const items = try generateRandomSlice(allocator, rng, queue_size);

Expand All @@ -906,17 +908,18 @@ fn fuzzTestMax(rng: *std.rand.Random, queue_size: usize) !void {

test "std.PriorityDequeue: fuzz testing min and max" {
var prng = std.rand.DefaultPrng.init(0x87654321);
const random = prng.random();

const test_case_count = 100;
const queue_size = 1_000;

var i: usize = 0;
while (i < test_case_count) : (i += 1) {
try fuzzTestMinMax(&prng.random, queue_size);
try fuzzTestMinMax(random, queue_size);
}
}

fn fuzzTestMinMax(rng: *std.rand.Random, queue_size: usize) !void {
fn fuzzTestMinMax(rng: std.rand.Random, queue_size: usize) !void {
const allocator = testing.allocator;
const items = try generateRandomSlice(allocator, rng, queue_size);

Expand All @@ -943,7 +946,7 @@ fn fuzzTestMinMax(rng: *std.rand.Random, queue_size: usize) !void {
}
}

fn generateRandomSlice(allocator: *std.mem.Allocator, rng: *std.rand.Random, size: usize) ![]u32 {
fn generateRandomSlice(allocator: *std.mem.Allocator, rng: std.rand.Random, size: usize) ![]u32 {
var array = std.ArrayList(u32).init(allocator);
try array.ensureTotalCapacity(size);

Expand Down
Loading

0 comments on commit 1a54a15

Please sign in to comment.