Skip to content

Commit

Permalink
Revert 'auto' as default device for custom ort
Browse files Browse the repository at this point in the history
  • Loading branch information
kallebysantos committed Dec 3, 2024
1 parent 9994c75 commit 42ae1fe
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
1 change: 0 additions & 1 deletion src/backends/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ let ONNX;
if (apis.IS_EXPOSED_RUNTIME_ENV) {
// If the JS runtime exposes their own ONNX runtime, use it
ONNX = globalThis[apis.EXPOSED_RUNTIME_SYMBOL];
defaultDevices = ['auto'];

} else if (apis.IS_NODE_ENV) {
ONNX = ONNX_NODE.default ?? ONNX_NODE;
Expand Down
12 changes: 6 additions & 6 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
}

// If the device is not specified, we use the default (supported) execution providers.
const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
device ?? (
apis.IS_EXPOSED_RUNTIME_ENV ? 'auto' : (
apis.IS_NODE_ENV ? 'cpu' : 'wasm'
))
let selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */ (
// Do not asign default device if 'IS_EXPOSED_RUNTIME_ENV'
device ?? (apis.IS_EXPOSED_RUNTIME_ENV ? undefined
: (apis.IS_NODE_ENV ? 'cpu' : 'wasm'))
);

const executionProviders = deviceToExecutionProviders(selectedDevice);

// If options.dtype is specified, we use it to choose the suffix for the model file.
Expand Down Expand Up @@ -238,7 +238,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
const free_dimension_overrides = custom_config.free_dimension_overrides;
if (free_dimension_overrides) {
session_options.freeDimensionOverrides ??= free_dimension_overrides;
} else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
} else if (selectedDevice?.startsWith('webnn') && !session_options.freeDimensionOverrides) {
console.warn(
'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' +
'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
Expand Down

0 comments on commit 42ae1fe

Please sign in to comment.