Commit c4cfde0

fix format issues
1 parent 030d347

1 file changed: +68 -59 lines
js/web/lib/wasm/wasm-core-impl.ts

@@ -10,6 +10,7 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType
 import {getInstance} from './wasm-factory';
 import {allocWasmString, checkLastError} from './wasm-utils';

+let currentEpName: string;
 // #region Initializations

 /**
@@ -105,6 +106,7 @@ export const initEp = async(env: Env, epName: string): Promise<void> => {
     const initJsep = require('./jsep/init').init;
     await initJsep(getInstance(), env, adapter);
   }
+  currentEpName = epName;
 };

 // #endregion Initializations
@@ -220,7 +222,7 @@ export const createSession =
           }

           let graphCaptureEnabled = false;
-          if (!BUILD_DEFS.DISABLE_WEBGPU) {
+          if (currentEpName === 'webgpu') {
             const executionProviders = options?.executionProviders;
             for (const ep of executionProviders!) {
               const epName = typeof ep === 'string' ? ep : ep.name;
@@ -331,70 +333,75 @@ export const releaseSession = (sessionId: number): void => {
 };

 export const prepareInputOutputTensor =
-    (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number):
-        void => {
-          if (!tensor) {
-            tensorHandles.push(0);
-            return;
-          }
+    (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number,
+     graphCaptureEnabled = false): void => {
+      if (!tensor) {
+        tensorHandles.push(0);
+        return;
+      }

-          const wasm = getInstance();
+      const wasm = getInstance();

-          const dataType = tensor[0];
-          const dims = tensor[1];
-          const location = tensor[3];
+      const dataType = tensor[0];
+      const dims = tensor[1];
+      const location = tensor[3];

-          let rawData: number;
-          let dataByteLength: number;
+      let rawData: number;
+      let dataByteLength: number;

-          if (dataType === 'string' && location === 'gpu-buffer') {
-            throw new Error('String tensor is not supported on GPU.');
-          }
+      if (dataType === 'string' && location === 'gpu-buffer') {
+        throw new Error('String tensor is not supported on GPU.');
+      }

-          if (location === 'gpu-buffer') {
-            const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
-            const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
-            dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
-            rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength);
-          } else {
-            const data = tensor[2];
+      if (graphCaptureEnabled && location !== 'gpu-buffer') {
+        throw new Error(
+            `External buffer must be provided for input/output index ${index} when graphCaptureEnabled is true.`);
+      }
+
+      if (location === 'gpu-buffer') {
+        const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
+        const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
+        dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
+        rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength);
+      } else {
+        const data = tensor[2];

-            if (Array.isArray(data)) {
-              // string tensor
-              dataByteLength = 4 * data.length;
-              rawData = wasm._malloc(dataByteLength);
-              allocs.push(rawData);
-              let dataIndex = rawData / 4;
-              for (let i = 0; i < data.length; i++) {
-                if (typeof data[i] !== 'string') {
-                  throw new TypeError(`tensor data at index ${i} is not a string`);
-                }
-                wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);
-              }
-            } else {
-              dataByteLength = data.byteLength;
-              rawData = wasm._malloc(dataByteLength);
-              allocs.push(rawData);
-              wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
-            }
-          }
+        if (Array.isArray(data)) {
+          // string tensor
+          dataByteLength = 4 * data.length;
+          rawData = wasm._malloc(dataByteLength);
+          allocs.push(rawData);
+          let dataIndex = rawData / 4;
+          for (let i = 0; i < data.length; i++) {
+            if (typeof data[i] !== 'string') {
+              throw new TypeError(`tensor data at index ${i} is not a string`);
+            }
+            wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);
+          }
+        } else {
+          dataByteLength = data.byteLength;
+          rawData = wasm._malloc(dataByteLength);
+          allocs.push(rawData);
+          wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
+        }
+      }

-          const stack = wasm.stackSave();
-          const dimsOffset = wasm.stackAlloc(4 * dims.length);
-          try {
-            let dimIndex = dimsOffset / 4;
-            dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
-            const tensor = wasm._OrtCreateTensor(
-                tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length,
-                dataLocationStringToEnum(location));
-            if (tensor === 0) {
-              checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
-            }
-            tensorHandles.push(tensor);
-          } finally {
-            wasm.stackRestore(stack);
-          }
-        };
+      const stack = wasm.stackSave();
+      const dimsOffset = wasm.stackAlloc(4 * dims.length);
+      try {
+        let dimIndex = dimsOffset / 4;
+        dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
+        const tensor = wasm._OrtCreateTensor(
+            tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length,
+            dataLocationStringToEnum(location));
+        if (tensor === 0) {
+          checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
+        }
+        tensorHandles.push(tensor);
+      } finally {
+        wasm.stackRestore(stack);
+      }
+    };

 /**
  * perform inference run
@@ -431,13 +438,15 @@ export const run = async(

   // create input tensors
   for (let i = 0; i < inputCount; i++) {
-    prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]);
+    prepareInputOutputTensor(
+        inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], graphCaptureEnabled);
   }

   // create output tensors
   for (let i = 0; i < outputCount; i++) {
     prepareInputOutputTensor(
-        outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]);
+        outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i],
+        graphCaptureEnabled);
   }

   let inputValuesIndex = inputValuesOffset / 4;
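
Note on usage: with graphCaptureEnabled now threaded through to prepareInputOutputTensor, enabling graph capture requires every model input and output to be backed by a caller-provided GPU buffer (location 'gpu-buffer'); CPU-resident tensors hit the new "External buffer must be provided" error. The sketch below illustrates one way a caller could satisfy that constraint. It is a minimal sketch, not part of this commit: the enableGraphCapture option name, the model path, the I/O names 'input'/'output', and the shapes are assumptions.

// Caller-side sketch (TypeScript). Assumes WebGPU is available, the model has a single
// input 'input' of shape [1, 3, 224, 224] and a single output 'output' of shape [1, 1000],
// and that the session options accept `enableGraphCapture`; all of these are placeholders.
import * as ort from 'onnxruntime-web/webgpu';

const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu'],
  enableGraphCapture: true,  // assumed option name; requires all I/O to live in GPU buffers
});

// Reuse the device initialized by the WebGPU EP so the buffers are visible to ORT.
const device = ort.env.webgpu.device as GPUDevice;

const inputBuffer = device.createBuffer({
  size: 1 * 3 * 224 * 224 * 4,  // float32 byte size for dims [1, 3, 224, 224]
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});
const outputBuffer = device.createBuffer({
  size: 1 * 1000 * 4,  // float32 byte size for dims [1, 1000]
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});

// Both feeds and fetches are 'gpu-buffer' tensors, so the new check in
// prepareInputOutputTensor passes and jsepRegisterBuffer is used for each of them.
const feeds = {input: ort.Tensor.fromGpuBuffer(inputBuffer, {dataType: 'float32', dims: [1, 3, 224, 224]})};
const fetches = {output: ort.Tensor.fromGpuBuffer(outputBuffer, {dataType: 'float32', dims: [1, 1000]})};

// results.output stays on the GPU (backed by outputBuffer); call results.output.getData()
// to download it when a CPU copy is needed.
const results = await session.run(feeds, fetches);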
