Update compatibility for webgpu EP

Requires dynamic imports

xenova committed Nov 7, 2023
1 parent 704d95d commit 66da130
Showing 4 changed files with 121 additions and 63 deletions.
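
This commit gates the new WebGPU execution provider behind an experimental flag. A minimal usage sketch (not part of the diff; the task and model name are illustrative) of how a caller could opt in after this change:

import { env, pipeline } from '@xenova/transformers';

// Opt in to the experimental WebGPU execution provider added in this commit.
// If WebGPU is unavailable, or an operator is unsupported, session creation
// falls back to the `wasm` execution provider.
env.experimental.useWebGPU = true;

const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');
const output = await extractor('Hello world');
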
107 changes: 95 additions & 12 deletions src/backends/onnx.js
@@ -16,30 +16,32 @@
* @module backends/onnx
*/

import path from 'path';
import { env, RUNNING_LOCALLY } from '../env.js';

// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
// In either case, we select the default export if it exists, otherwise we use the named export.
import * as ONNX_NODE from 'onnxruntime-node';
import * as ONNX_WEB from 'onnxruntime-web';

/** @type {module} The ONNX runtime module. */
export let ONNX;

export const executionProviders = [
// 'webgpu',
'wasm'
];
let ONNX;

if (typeof process !== 'undefined' && process?.release?.name === 'node') {
// Running in a node-like environment.
ONNX = ONNX_NODE.default ?? ONNX_NODE;
const WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
const USE_ONNXRUNTIME_NODE = typeof process !== 'undefined' && process?.release?.name === 'node'

// Add `cpu` execution provider, with higher precedence that `wasm`.
executionProviders.unshift('cpu');
const ONNX_MODULES = new Map();

if (USE_ONNXRUNTIME_NODE) {
ONNX = ONNX_NODE.default ?? ONNX_NODE;
ONNX_MODULES.set('node', ONNX);
} else {
// Running in a browser-environment
// @ts-ignore
ONNX = ONNX_WEB.default ?? ONNX_WEB;
ONNX_MODULES.set('web', ONNX);

// Running in a browser-environment
// TODO: Check if 1.16.1 fixes this issue.
// SIMD for WebAssembly does not operate correctly in some recent versions of iOS (16.4.x).
// As a temporary fix, we disable it for now.
// For more information, see: https://github.com/microsoft/onnxruntime/issues/15644
@@ -48,3 +50,84 @@ if (typeof process !== 'undefined' && process?.release?.name === 'node') {
ONNX.env.wasm.simd = false;
}
}

/**
* Create an ONNX inference session, with fallback support if an operation is not supported.
* @param {Uint8Array} buffer
* @returns {Promise<Object>} The ONNX inference session.
*/
export async function createInferenceSession(buffer) {
let executionProviders;
let InferenceSession;
if (USE_ONNXRUNTIME_NODE) {
const ONNX_NODE = ONNX_MODULES.get('node');
InferenceSession = ONNX_NODE.InferenceSession;
executionProviders = ['cpu'];

} else if (WEBGPU_AVAILABLE && env.experimental.useWebGPU) {
// Only import the WebGPU version if the user enables the experimental flag.
let ONNX_WEBGPU = ONNX_MODULES.get('webgpu');
if (ONNX_WEBGPU === undefined) {
ONNX_WEBGPU = await import('onnxruntime-web/webgpu');
ONNX_MODULES.set('webgpu', ONNX_WEBGPU)
}

InferenceSession = ONNX_WEBGPU.InferenceSession;

// If WebGPU is available and the user enables the experimental flag, try to use the WebGPU execution provider.
executionProviders = ['webgpu', 'wasm'];

ONNX_WEBGPU.env = env.backends.onnx;

} else {
const ONNX_WEB = ONNX_MODULES.get('web');
InferenceSession = ONNX_WEB.InferenceSession;
executionProviders = ['wasm'];
env.backends.onnx = ONNX_MODULES.get('web').env
}

try {
return await InferenceSession.create(buffer, {
executionProviders,
});
} catch (err) {
// If the execution provider was only wasm, throw the error
if (executionProviders.length === 1 && executionProviders[0] === 'wasm') {
throw err;
}

console.warn(err);
console.warn(
'Something went wrong during model construction (most likely a missing operation). ' +
'Using `wasm` as a fallback. '
)
return await InferenceSession.create(buffer, {
executionProviders: ['wasm']
});
}
}

/**
* Check if an object is an ONNX tensor.
* @param {any} x The object to check
* @returns {boolean} Whether the object is an ONNX tensor.
*/
export function isONNXTensor(x) {
for (const module of ONNX_MODULES.values()) {
if (x instanceof module.Tensor) {
return true;
}
}
return false;
}

// Set path to wasm files. This is needed when running in a web worker.
// https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths
// We use remote wasm files by default to make it easier for newer users.
// In practice, users should probably self-host the necessary .wasm files.
ONNX.env.wasm.wasmPaths = RUNNING_LOCALLY
? path.join(env.__dirname, '/dist/')
: `https://cdn.jsdelivr.net/npm/@xenova/transformers@${env.version}/dist/`;

// Expose ONNX environment variables to `env.backends.onnx`
env.backends.onnx = ONNX.env;
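
For reference, a minimal sketch of how the new helper is consumed (the fetch URL is hypothetical; internally the buffer comes from `getModelFile` in `src/models.js`):

import { createInferenceSession } from './backends/onnx.js';

// Load the model bytes from anywhere that yields a Uint8Array.
const response = await fetch('https://example.com/model.onnx'); // hypothetical URL
const buffer = new Uint8Array(await response.arrayBuffer());

// With `env.experimental.useWebGPU` enabled (and WebGPU available), this tries
// ['webgpu', 'wasm']; otherwise ['wasm'] in browsers, or ['cpu'] under Node.
const session = await createInferenceSession(buffer);
console.log(session.inputNames, session.outputNames);
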
23 changes: 10 additions & 13 deletions src/env.js
@@ -26,17 +26,14 @@ import fs from 'fs';
import path from 'path';
import url from 'url';

import { ONNX } from './backends/onnx.js';
const { env: onnx_env } = ONNX;

const VERSION = '3.0.0-alpha.0';

// Check if various APIs are available (depends on environment)
const WEB_CACHE_AVAILABLE = typeof self !== 'undefined' && 'caches' in self;
const FS_AVAILABLE = !isEmpty(fs); // check if file system is available
const PATH_AVAILABLE = !isEmpty(path); // check if path is available

const RUNNING_LOCALLY = FS_AVAILABLE && PATH_AVAILABLE;
export const RUNNING_LOCALLY = FS_AVAILABLE && PATH_AVAILABLE;

const __dirname = RUNNING_LOCALLY
? path.dirname(path.dirname(url.fileURLToPath(import.meta.url)))
@@ -53,14 +50,6 @@ const localModelPath = RUNNING_LOCALLY
? path.join(__dirname, DEFAULT_LOCAL_MODEL_PATH)
: DEFAULT_LOCAL_MODEL_PATH;

// Set path to wasm files. This is needed when running in a web worker.
// https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths
// We use remote wasm files by default to make it easier for newer users.
// In practice, users should probably self-host the necessary .wasm files.
onnx_env.wasm.wasmPaths = RUNNING_LOCALLY
? path.join(__dirname, '/dist/')
: `https://cdn.jsdelivr.net/npm/@xenova/transformers@${VERSION}/dist/`;


/**
* Global variable used to control execution. This provides users a simple way to configure Transformers.js.
@@ -83,16 +72,24 @@ onnx_env.wasm.wasmPaths = RUNNING_LOCALLY
* @property {Object} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which
* implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache
*/

export const env = {
/////////////////// Backends settings ///////////////////
// NOTE: These will be populated later by the backends themselves.
backends: {
// onnxruntime-web/onnxruntime-node
onnx: onnx_env,
onnx: {},

// TensorFlow.js
tfjs: {},
},

/////////////////// Experimental settings ///////////////////
experimental: {
// Whether to use the experimental WebGPU execution provider for ONNX Runtime.
useWebGPU: false,
},

__dirname,
version: VERSION,

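Because `env.backends.onnx` now starts as an empty object and is only populated when `src/backends/onnx.js` loads, backend settings should be adjusted after importing the library. A sketch, assuming self-hosted .wasm files (the paths are illustrative):

import { env } from '@xenova/transformers';

// `env.backends.onnx` points at the ONNX Runtime environment object once the
// backend module has loaded, so overrides like these must come afterwards.
env.backends.onnx.wasm.wasmPaths = '/static/ort-wasm/';
env.backends.onnx.wasm.numThreads = 1;
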
35 changes: 8 additions & 27 deletions src/models.js
@@ -80,9 +80,8 @@ import {
Tensor,
} from './utils/tensor.js';

import { executionProviders, ONNX } from './backends/onnx.js';
import { createInferenceSession, isONNXTensor } from './backends/onnx.js';
import { medianFilter } from './transformers.js';
const { InferenceSession, Tensor: ONNXTensor } = ONNX;

//////////////////////////////////////////////////
// Model types: used internally
@@ -111,38 +110,19 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {string} fileName The name of the model file.
* @param {import('./utils/hub.js').PretrainedOptions} options Additional options for loading the model.
* @returns {Promise<InferenceSession>} A Promise that resolves to an InferenceSession object.
* @returns {Promise<Object>} A Promise that resolves to an InferenceSession object.
* @private
*/
async function constructSession(pretrained_model_name_or_path, fileName, options) {
// TODO add option for user to force specify their desired execution provider
let modelFileName = `onnx/${fileName}${options.quantized ? '_quantized' : ''}.onnx`;
let buffer = await getModelFile(pretrained_model_name_or_path, modelFileName, true, options);

try {
return await InferenceSession.create(buffer, {
executionProviders,
});
} catch (err) {
// If the execution provided was only wasm, throw the error
if (executionProviders.length === 1 && executionProviders[0] === 'wasm') {
throw err;
}

console.warn(err);
console.warn(
'Something went wrong during model construction (most likely a missing operation). ' +
'Using `wasm` as a fallback. '
)
return await InferenceSession.create(buffer, {
executionProviders: ['wasm']
});
}
return await createInferenceSession(buffer);
}

/**
* Validate model inputs
* @param {InferenceSession} session The InferenceSession object that will be run.
* @param {Object} session The InferenceSession object that will be run.
* @param {Object} inputs The inputs to check.
* @returns {Promise<Object>} A Promise that resolves to the checked inputs.
* @throws {Error} If any inputs are missing.
@@ -182,7 +162,7 @@ async function validateInputs(session, inputs) {
* - If additional inputs are passed, they will be ignored.
* - If inputs are missing, an error will be thrown.
*
* @param {InferenceSession} session The InferenceSession object to run.
* @param {Object} session The InferenceSession object to run.
* @param {Object} inputs An object that maps input names to input tensors.
* @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors.
* @private
@@ -209,7 +189,7 @@
*/
function replaceTensors(obj) {
for (let prop in obj) {
if (obj[prop] instanceof ONNXTensor) {
if (isONNXTensor(obj[prop])) {
obj[prop] = new Tensor(obj[prop]);
} else if (typeof obj[prop] === 'object') {
replaceTensors(obj[prop]);
@@ -639,7 +619,8 @@ export class PreTrainedModel extends Callable {
let promises = [];
for (let key of Object.keys(this)) {
let item = this[key];
if (item instanceof InferenceSession) {
// TODO improve check for ONNX session
if (item?.handler?.dispose !== undefined) {
promises.push(item.handler.dispose())
}
}
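Since sessions may now come from any of the dynamically imported ONNX Runtime builds, `dispose()` duck-types them via `handler.dispose` instead of `instanceof InferenceSession`. A hypothetical usage sketch (the model id is illustrative):

import { AutoModel } from '@xenova/transformers';

const model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
// ... run inference ...
// Walks the model's properties and disposes anything exposing handler.dispose().
await model.dispose();
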
19 changes: 8 additions & 11 deletions src/utils/tensor.js
@@ -7,8 +7,6 @@
* @module utils/tensor
*/

import { ONNX } from '../backends/onnx.js';

import {
interpolate_data,
transpose_data
@@ -19,22 +17,21 @@ import {
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray
*/

/** @type {Object} */
const ONNXTensor = ONNX.Tensor;

export class Tensor extends ONNXTensor {
export class Tensor {
/**
* Create a new Tensor or copy an existing Tensor.
* @param {[string, DataArray, number[]]|[ONNXTensor]} args
* @param {[string, DataArray, number[]]|Object} args
*/
constructor(...args) {
if (args[0] instanceof ONNX.Tensor) {
if (args.length === 1) {
// Create shallow copy
super(args[0].type, args[0].data, args[0].dims);
Object.assign(this, args[0]);

} else {
// Create new
super(...args);
// Create new tensor
this.type = args[0];
this.data = args[1];
this.dims = args[2];
}

return new Proxy(this, {
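
After this change, `Tensor` no longer extends the ONNX runtime's tensor class, so it can wrap tensors from whichever runtime module produced them. A sketch of the two constructor forms (values are illustrative):

import { Tensor } from './utils/tensor.js';

// 1. Build a tensor from (type, data, dims).
const t = new Tensor('float32', new Float32Array([1, 2, 3, 4]), [2, 2]);

// 2. Shallow-copy an existing ONNX runtime tensor (e.g. a session output).
const wrapped = new Tensor(onnxOutput); // `onnxOutput` is a hypothetical ort.Tensor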
