Skip to content

Commit

Permalink
Fix tensor inheritance (#451)
Browse files Browse the repository at this point in the history
* Do not extend from ONNX tensor (fix #437)

* Fix typing issues

* Typing improvements

* Apply suggestions

* Update tensor import type
  • Loading branch information
xenova authored Dec 12, 2023
1 parent 2cd2997 commit 8c465a9
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 52 deletions.
7 changes: 5 additions & 2 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ function validateInputs(session, inputs) {
async function sessionRun(session, inputs) {
const checkedInputs = validateInputs(session, inputs);
try {
// @ts-ignore
let output = await session.run(checkedInputs);
output = replaceTensors(output);
return output;
Expand Down Expand Up @@ -292,6 +293,7 @@ function prepareAttentionMask(self, tokens) {
if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) {
let data = BigInt64Array.from(
// Note: != so that int matches bigint
// @ts-ignore
tokens.data.map(x => x != pad_token_id)
)
return new Tensor('int64', data, tokens.dims)
Expand Down Expand Up @@ -704,9 +706,10 @@ export class PreTrainedModel extends Callable {
* @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
*/
async dispose() {
let promises = [];
const promises = [];
for (let key of Object.keys(this)) {
let item = this[key];
const item = this[key];
// @ts-ignore
if (item instanceof InferenceSession) {
promises.push(item.handler.dispose())
}
Expand Down
16 changes: 9 additions & 7 deletions src/utils/generation.js
Original file line number Diff line number Diff line change
Expand Up @@ -261,32 +261,34 @@ export class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
return logits;
}

const logitsData = /** @type {Float32Array} */(logits.data);

// timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
const seq = input_ids.slice(this.begin_index);
const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;

if (last_was_timestamp) {
if (penultimate_was_timestamp) { // has to be non-timestamp
logits.data.subarray(this.timestamp_begin).fill(-Infinity);
logitsData.subarray(this.timestamp_begin).fill(-Infinity);
} else { // cannot be normal text tokens
logits.data.subarray(0, this.eos_token_id).fill(-Infinity);
logitsData.subarray(0, this.eos_token_id).fill(-Infinity);
}
}

// apply the `max_initial_timestamp` option
if (input_ids.length === this.begin_index && this.max_initial_timestamp_index !== null) {
const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
logits.data.subarray(last_allowed + 1).fill(-Infinity);
logitsData.subarray(last_allowed + 1).fill(-Infinity);
}

// if sum of probability over timestamps is above any other token, sample timestamp
const logprobs = log_softmax(logits.data);
const logprobs = log_softmax(logitsData);
const timestamp_logprob = Math.log(logprobs.subarray(this.timestamp_begin).map(Math.exp).reduce((a, b) => a + b));
const max_text_token_logprob = max(logprobs.subarray(0, this.timestamp_begin))[0];

if (timestamp_logprob > max_text_token_logprob) {
logits.data.subarray(0, this.timestamp_begin).fill(-Infinity);
logitsData.subarray(0, this.timestamp_begin).fill(-Infinity);
}

return logits;
Expand Down Expand Up @@ -697,12 +699,12 @@ export class Sampler extends Callable {
* Returns the specified logits as an array, with temperature applied.
* @param {Tensor} logits
* @param {number} index
* @returns {Array}
* @returns {Float32Array}
*/
getLogits(logits, index) {
let vocabSize = logits.dims.at(-1);

let logs = logits.data;
let logs = /** @type {Float32Array} */(logits.data);

if (index === -1) {
logs = logs.slice(-vocabSize);
Expand Down
15 changes: 13 additions & 2 deletions src/utils/image.js
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export class RawImage {

/**
* Create a new `RawImage` object.
* @param {Uint8ClampedArray} data The pixel data.
* @param {Uint8ClampedArray|Uint8Array} data The pixel data.
* @param {number} width The width of the image.
* @param {number} height The height of the image.
* @param {1|2|3|4} channels The number of channels.
Expand Down Expand Up @@ -173,7 +173,18 @@ export class RawImage {
} else {
throw new Error(`Unsupported channel format: ${channel_format}`);
}
return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]);
if (!(tensor.data instanceof Uint8ClampedArray || tensor.data instanceof Uint8Array)) {
throw new Error(`Unsupported tensor type: ${tensor.type}`);
}
switch (tensor.dims[2]) {
case 1:
case 2:
case 3:
case 4:
return new RawImage(tensor.data, tensor.dims[1], tensor.dims[0], tensor.dims[2]);
default:
throw new Error(`Unsupported number of channels: ${tensor.dims[2]}`);
}
}

/**
Expand Down
25 changes: 13 additions & 12 deletions src/utils/maths.js
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ export function transpose_data(array, dims, axes) {

/**
* Compute the softmax of an array of numbers.
*
* @param {number[]} arr The array of numbers to compute the softmax of.
* @returns {number[]} The softmax array.
* @template {TypedArray|number[]} T
* @param {T} arr The array of numbers to compute the softmax of.
* @returns {T} The softmax array.
*/
export function softmax(arr) {
// Compute the maximum value in the array
Expand All @@ -142,18 +142,20 @@ export function softmax(arr) {
const exps = arr.map(x => Math.exp(x - maxVal));

// Compute the sum of the exponentials
// @ts-ignore
const sumExps = exps.reduce((acc, val) => acc + val, 0);

// Compute the softmax values
const softmaxArr = exps.map(x => x / sumExps);

return softmaxArr;
return /** @type {T} */(softmaxArr);
}

/**
* Calculates the logarithm of the softmax function for the input array.
* @param {number[]} arr The input array to calculate the log_softmax function for.
* @returns {any} The resulting log_softmax array.
* @template {TypedArray|number[]} T
* @param {T} arr The input array to calculate the log_softmax function for.
* @returns {T} The resulting log_softmax array.
*/
export function log_softmax(arr) {
// Compute the softmax values
Expand All @@ -162,7 +164,7 @@ export function log_softmax(arr) {
// Apply log formula to each element
const logSoftmaxArr = softmaxArr.map(x => Math.log(x));

return logSoftmaxArr;
return /** @type {T} */(logSoftmaxArr);
}

/**
Expand All @@ -178,8 +180,7 @@ export function dot(arr1, arr2) {

/**
* Get the top k items from an iterable, sorted by descending order
*
* @param {Array} items The items to be sorted
* @param {any[]|TypedArray} items The items to be sorted
* @param {number} [top_k=0] The number of top items to return (default: 0 = return all)
* @returns {Array} The top k items, sorted by descending order
*/
Expand Down Expand Up @@ -252,8 +253,8 @@ export function min(arr) {

/**
* Returns the value and index of the maximum element in an array.
* @param {number[]|TypedArray} arr array of numbers.
* @returns {number[]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
* @param {number[]|AnyTypedArray} arr array of numbers.
* @returns {[number, number]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
* @throws {Error} If array is empty.
*/
export function max(arr) {
Expand All @@ -266,7 +267,7 @@ export function max(arr) {
indexOfMax = i;
}
}
return [max, indexOfMax];
return [Number(max), indexOfMax];
}

function isPowerOfTwo(number) {
Expand Down
Loading

0 comments on commit 8c465a9

Please sign in to comment.