Skip to content

Commit

Permalink
only bind input/output once for IOBinding when graphCaptureEnabled = true
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Jan 9, 2024
1 parent 56dc850 commit 02091c3
Showing 1 changed file with 57 additions and 32 deletions.
89 changes: 57 additions & 32 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ type IOBindingState = {
*/
type SessionMetadata = [
inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[],
bindingState: IOBindingState|null
bindingState: IOBindingState|null, graphCaptureEnabled: boolean, inputOutputBounded: boolean
];

const activeSessions = new Map<number, SessionMetadata>();
Expand Down Expand Up @@ -219,6 +219,19 @@ export const createSession =
checkLastError('Can\'t create a session.');
}

let graphCaptureEnabled = false;
if (!BUILD_DEFS.DISABLE_WEBGPU) {
const executionProviders = options?.executionProviders;
for (const ep of executionProviders!) {
const epName = typeof ep === 'string' ? ep : ep.name;
if (epName === 'webgpu') {
const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption;
graphCaptureEnabled =
webgpuOptions.graphCaptureEnabled === undefined ? false : webgpuOptions.graphCaptureEnabled;
}
}
}

const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);

const inputNames = [];
Expand All @@ -242,6 +255,10 @@ export const createSession =
outputNames.push(nameString);

if (!BUILD_DEFS.DISABLE_WEBGPU) {
if (graphCaptureEnabled) {
outputPreferredLocations.push('gpu-buffer');
continue;
}
const location = typeof options?.preferredOutputLocation === 'string' ?
options.preferredOutputLocation :
options?.preferredOutputLocation?.[nameString] ?? 'cpu';
Expand All @@ -267,7 +284,9 @@ export const createSession =
};
}

activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]);
activeSessions.set(
sessionHandle,
[sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, graphCaptureEnabled, false]);
return [sessionHandle, inputNames, outputNames];
} catch (e) {
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
Expand Down Expand Up @@ -388,7 +407,8 @@ export const run = async(
if (!session) {
throw new Error(`cannot run inference. invalid session id: ${sessionId}`);
}
const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session;
const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, graphCaptureEnabled, inputOutputBounded] =
session;

const inputCount = inputIndices.length;
const outputCount = outputIndices.length;
Expand Down Expand Up @@ -434,41 +454,46 @@ export const run = async(
}

if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState;

if (inputNamesUTF8Encoded.length !== inputCount) {
throw new Error(`input count from feeds (${
inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`);
}
if (!graphCaptureEnabled || (graphCaptureEnabled && !inputOutputBounded)) {
const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState;

// process inputs
for (let i = 0; i < inputCount; i++) {
const index = inputIndices[i];
const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]);
if (errorCode !== 0) {
checkLastError(`Can't bind input[${i}] for session=${sessionId}.`);
if (inputNamesUTF8Encoded.length !== inputCount) {
throw new Error(`input count from feeds (${
inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`);
}
}

// process pre-allocated outputs
for (let i = 0; i < outputCount; i++) {
const index = outputIndices[i];
const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated.

if (location) {
// output is pre-allocated. bind the tensor.
const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0);
// process inputs
for (let i = 0; i < inputCount; i++) {
const index = inputIndices[i];
const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]);
if (errorCode !== 0) {
checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`);
checkLastError(`Can't bind input[${i}] for session=${sessionId}.`);
}
} else {
// output is not pre-allocated. reset preferred location.
const errorCode =
wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]);
if (errorCode !== 0) {
checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`);
}

// process pre-allocated outputs
for (let i = 0; i < outputCount; i++) {
const index = outputIndices[i];
const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated.

if (location) {
// output is pre-allocated. bind the tensor.
const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0);
if (errorCode !== 0) {
checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`);
}
} else {
// output is not pre-allocated. reset preferred location.
const errorCode =
wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]);
if (errorCode !== 0) {
checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`);
}
}
}
activeSessions.set(
sessionId,
[sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, graphCaptureEnabled, true]);
}
}

Expand Down Expand Up @@ -579,7 +604,7 @@ export const run = async(
}
}

if (ioBindingState) {
if (ioBindingState && !graphCaptureEnabled) {
wasm._OrtClearBoundOutputs(ioBindingState.handle);
}
return output;
Expand Down

0 comments on commit 02091c3

Please sign in to comment.