Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion web/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export {
} from "./runtime";
export { Disposable, LibraryProvider } from "./types";
export { RPCServer } from "./rpc_server";
export { wasmPath } from "./support";
export { wasmPath, LinearCongruentialGenerator } from "./support";
export { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu";
export { assert } from "./support";
export { createPolyfillWASI } from "./compact";
17 changes: 13 additions & 4 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes";
import { Disposable } from "./types";
import { Memory, CachedCallStack } from "./memory";
import { assert, StringToUint8Array } from "./support";
import { assert, StringToUint8Array, LinearCongruentialGenerator } from "./support";
import { Environment } from "./environment";
import { AsyncifyHandler } from "./asyncify";
import { FunctionInfo, WebGPUContext } from "./webgpu";
Expand Down Expand Up @@ -1079,6 +1079,7 @@ export class Instance implements Disposable {
private ctx: RuntimeContext;
private asyncifyHandler: AsyncifyHandler;
private initProgressCallback: Array<InitProgressCallback> = [];
private rng: LinearCongruentialGenerator;

/**
* Internal function(registered by the runtime)
Expand Down Expand Up @@ -1131,6 +1132,7 @@ export class Instance implements Disposable {
);
this.registerEnvGlobalPackedFuncs();
this.registerObjectFactoryFuncs();
this.rng = new LinearCongruentialGenerator();
}

/**
Expand Down Expand Up @@ -1811,11 +1813,18 @@ export class Instance implements Disposable {
const scale = high - low;
const input = new Float32Array(size);
for (let i = 0; i < input.length; ++i) {
input[i] = low + Math.random() * scale;
input[i] = low + this.rng.randomFloat() * scale;
}
return ret.copyFrom(input);
}

/**
* Set the seed of the internal LinearCongruentialGenerator.
*/
setSeed(seed: number): void {
this.rng.setSeed(seed);
}

/**
* Sample index via top-p sampling.
*
Expand All @@ -1825,7 +1834,7 @@ export class Instance implements Disposable {
* @returns The sampled index.
*/
sampleTopPFromLogits(logits: NDArray, temperature: number, top_p: number): number {
return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random());
return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, this.rng.randomFloat());
}

/**
Expand All @@ -1836,7 +1845,7 @@ export class Instance implements Disposable {
* @returns The sampled index.
*/
sampleTopPFromProb(prob: NDArray, top_p: number): number {
return this.ctx.sampleTopPFromProb(prob, top_p, Math.random());
return this.ctx.sampleTopPFromProb(prob, top_p, this.rng.randomFloat());
}

/**
Expand Down
76 changes: 76 additions & 0 deletions web/src/support.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,79 @@ export function assert(condition: boolean, msg?: string): asserts condition {
export function wasmPath(): string {
return __dirname + "/wasm";
}

/**
* Linear congruential generator for random number generating that can be seeded.
*
* Follows the implementation of `include/tvm/support/random_engine.h`, which follows the
* sepcification in https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine.
*
* Note `Number.MAX_SAFE_INTEGER = 2^53 - 1`, and our intermediates are strictly less than 2^48.
*/

export class LinearCongruentialGenerator {
readonly modulus: number;
readonly multiplier: number;
readonly increment: number;
// Always within the range (0, 2^32 - 1) non-inclusive; if 0, will forever generate 0.
private rand_state: number;

/**
* Set modulus, multiplier, and increment. Initialize `rand_state` according to `Date.now()`.
*/
constructor() {
this.modulus = 2147483647; // 2^32 - 1
this.multiplier = 48271; // between 2^15 and 2^16
this.increment = 0;
this.setSeed(Date.now());
}

/**
* Sets `rand_state` after normalized with `modulus` to ensure that it is within range.
* @param seed Any integer. Used to set `rand_state` after normalized with `modulus`.
*
* Postcondition: pass `checkRandState()`, i.e. rand_state > 0 and is an integer.
*/
setSeed(seed: number) {
if (!Number.isInteger(seed)) {
throw new Error("Seed should be an integer.");
}
this.rand_state = seed % this.modulus;
if (this.rand_state == 0) {
this.rand_state = 1;
}
this.checkRandState();
}

/**
* Generate the next integer in the range (0, this.modulus) non-inclusive, updating `rand_state`.
*
* Postcondition: pass `checkRandState()`, i.e. rand_state > 0 and is an integer.
*/
nextInt(): number {
// `intermediate` is always < 2^48, hence less than `Number.MAX_SAFE_INTEGER` due to the
// invariants as commented in the constructor.
const intermediate = this.multiplier * this.rand_state + this.increment;
this.rand_state = intermediate % this.modulus;
this.checkRandState();
return this.rand_state;
}

/**
* Generates random float between (0, 1) non-inclusive, updating `rand_state`.
*
* Postcondition: pass `checkRandState()`, i.e. rand_state > 0 and is an integer.
*/
randomFloat(): number {
return this.nextInt() / this.modulus;
}

private checkRandState(): void {
if (this.rand_state <= 0) {
throw new Error("Random state is unexpectedly not strictly positive.");
}
if (!Number.isInteger(this.rand_state)) {
throw new Error("Random state is unexpectedly not an integer.");
}
}
}
71 changes: 71 additions & 0 deletions web/tests/node/test_random_generator.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/* eslint-disable no-undef */

const tvmjs = require("../../dist");

test("Test coverage of [0,100] inclusive", () => {
const covered = Array(100);
const rng = new tvmjs.LinearCongruentialGenerator();
for (let i = 0; i < 100000; i++) {
covered[rng.nextInt() % 100] = true;
}
const notCovered = [];
for (let i = 0; i < 100; i++) {
if (!covered[i]) {
notCovered.push(i);
}
}
expect(notCovered).toEqual([]);
});

test("Test whether the same seed make two RNGs generate same results", () => {
const rng1 = new tvmjs.LinearCongruentialGenerator();
const rng2 = new tvmjs.LinearCongruentialGenerator();
rng1.setSeed(42);
rng2.setSeed(42);

for (let i = 0; i < 100; i++) {
expect(rng1.randomFloat()).toBeCloseTo(rng2.randomFloat());
}
});

test("Test two RNGs with different seeds generate different results", () => {
const rng1 = new tvmjs.LinearCongruentialGenerator();
const rng2 = new tvmjs.LinearCongruentialGenerator();
rng1.setSeed(41);
rng2.setSeed(42);
let numSame = 0;
const numTest = 100;

// Generate `numTest` random numbers, make sure not all are the same.
for (let i = 0; i < numTest; i++) {
if (rng1.nextInt() === rng2.nextInt()) {
numSame += 1;
}
}
expect(numSame < numTest).toBe(true);
});

test('Illegal argument to `setSeed()`', () => {
expect(() => {
const rng1 = new tvmjs.LinearCongruentialGenerator();
rng1.setSeed(42.5);
}).toThrow("Seed should be an integer.");
});