Skip to content

Commit

Permalink
fix format issues
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Jan 10, 2024
1 parent 030d347 commit c4cfde0
Showing 1 changed file with 68 additions and 59 deletions.
127 changes: 68 additions & 59 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError} from './wasm-utils';

let currentEpName: string;
// #region Initializations

/**
Expand Down Expand Up @@ -105,6 +106,7 @@ export const initEp = async(env: Env, epName: string): Promise<void> => {
const initJsep = require('./jsep/init').init;
await initJsep(getInstance(), env, adapter);
}
currentEpName = epName;
};

// #endregion Initializations
Expand Down Expand Up @@ -220,7 +222,7 @@ export const createSession =
}

let graphCaptureEnabled = false;
if (!BUILD_DEFS.DISABLE_WEBGPU) {
if (currentEpName === 'webgpu') {
const executionProviders = options?.executionProviders;
for (const ep of executionProviders!) {
const epName = typeof ep === 'string' ? ep : ep.name;
Expand Down Expand Up @@ -331,70 +333,75 @@ export const releaseSession = (sessionId: number): void => {
};

export const prepareInputOutputTensor =
(tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number):
void => {
if (!tensor) {
tensorHandles.push(0);
return;
}
(tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number,
graphCaptureEnabled = false): void => {
if (!tensor) {
tensorHandles.push(0);
return;
}

const wasm = getInstance();
const wasm = getInstance();

const dataType = tensor[0];
const dims = tensor[1];
const location = tensor[3];
const dataType = tensor[0];
const dims = tensor[1];
const location = tensor[3];

let rawData: number;
let dataByteLength: number;
let rawData: number;
let dataByteLength: number;

if (dataType === 'string' && location === 'gpu-buffer') {
throw new Error('String tensor is not supported on GPU.');
}
if (dataType === 'string' && location === 'gpu-buffer') {
throw new Error('String tensor is not supported on GPU.');
}

if (location === 'gpu-buffer') {
const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength);
} else {
const data = tensor[2];

if (Array.isArray(data)) {
// string tensor
dataByteLength = 4 * data.length;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
let dataIndex = rawData / 4;
for (let i = 0; i < data.length; i++) {
if (typeof data[i] !== 'string') {
throw new TypeError(`tensor data at index ${i} is not a string`);
}
wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);
}
} else {
dataByteLength = data.byteLength;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
}
}
if (graphCaptureEnabled && location !== 'gpu-buffer') {
throw new Error(
`External buffer must be provided for input/output index ${index} when graphCaptureEnabled is true.`);
}

const stack = wasm.stackSave();
const dimsOffset = wasm.stackAlloc(4 * dims.length);
try {
let dimIndex = dimsOffset / 4;
dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
const tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length,
dataLocationStringToEnum(location));
if (tensor === 0) {
checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
if (location === 'gpu-buffer') {
const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength);
} else {
const data = tensor[2];

if (Array.isArray(data)) {
// string tensor
dataByteLength = 4 * data.length;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
let dataIndex = rawData / 4;
for (let i = 0; i < data.length; i++) {
if (typeof data[i] !== 'string') {
throw new TypeError(`tensor data at index ${i} is not a string`);
}
tensorHandles.push(tensor);
} finally {
wasm.stackRestore(stack);
wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);
}
};
} else {
dataByteLength = data.byteLength;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
}
}

const stack = wasm.stackSave();
const dimsOffset = wasm.stackAlloc(4 * dims.length);
try {
let dimIndex = dimsOffset / 4;
dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
const tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length,
dataLocationStringToEnum(location));
if (tensor === 0) {
checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
}
tensorHandles.push(tensor);
} finally {
wasm.stackRestore(stack);
}
};

/**
* perform inference run
Expand Down Expand Up @@ -431,13 +438,15 @@ export const run = async(

// create input tensors
for (let i = 0; i < inputCount; i++) {
prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]);
prepareInputOutputTensor(
inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], graphCaptureEnabled);
}

// create output tensors
for (let i = 0; i < outputCount; i++) {
prepareInputOutputTensor(
outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]);
outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i],
graphCaptureEnabled);
}

let inputValuesIndex = inputValuesOffset / 4;
Expand Down

0 comments on commit c4cfde0

Please sign in to comment.