Skip to content

Commit

Permalink
check webgpu backend in execution loop
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent b0d7dfa commit 7c5e446
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions js/web/lib/onnxjs/execution-plan.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env} from 'onnxruntime-common';

import {SessionHandler} from './backend';
import {WebGpuBackend} from './backends/backend-webgpu';
import {Graph} from './graph';
import {Logger, Profiler} from './instrument';
import {Operator} from './operators';
Expand Down Expand Up @@ -57,6 +60,7 @@ export class ExecutionPlan {

// create inference handler
const inferenceHandler = sessionHandler.createInferenceHandler();
const IS_WEBGPU = sessionHandler.backend instanceof WebGpuBackend;

// populate inputs value
const graphInputs = this.graph.getInputIndices();
Expand Down Expand Up @@ -103,13 +107,28 @@ export class ExecutionPlan {
throw new Error('the size of output does not match model definition.');
}

if (env.debug) {
for (let i = 0; i < outputList.length; i++) {
if (IS_WEBGPU) {
await outputList[i].getData();
} else {
// eslint-disable-next-line no-unused-expressions
outputList[i].data;
}
}
}

// fill value
outputList.forEach((output, i) => {
const j = thisOp.node.outputs[i];
if (this._values[j]) {
throw new Error(`output [${j}] already has value: op:${thisOp.node.name}`);
}
this._values[j] = output;

if (env.debug) {
Logger.verbose('ExecPlanDataDump', `output${i}[${output.dims}]:${output.data}`);
}
});

// resolve downstream nodes
Expand Down Expand Up @@ -141,13 +160,8 @@ export class ExecutionPlan {
throw new Error(`required output [${outputIndex}] does not have value`);
}

// TODO: use env to check
const IS_WEBGPU = true;

if (IS_WEBGPU) {
await outputTensor.getData();
} else if (outputIndex === 0) {
await outputTensor.getData();
} else {
// eslint-disable-next-line no-unused-expressions
outputTensor.data;
Expand Down

0 comments on commit 7c5e446

Please sign in to comment.