@@ -10,6 +10,7 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType
10
10
import { getInstance } from './wasm-factory' ;
11
11
import { allocWasmString , checkLastError } from './wasm-utils' ;
12
12
13
+ let currentEpName : string ;
13
14
// #region Initializations
14
15
15
16
/**
@@ -105,6 +106,7 @@ export const initEp = async(env: Env, epName: string): Promise<void> => {
105
106
const initJsep = require ( './jsep/init' ) . init ;
106
107
await initJsep ( getInstance ( ) , env , adapter ) ;
107
108
}
109
+ currentEpName = epName ;
108
110
} ;
109
111
110
112
// #endregion Initializations
@@ -220,7 +222,7 @@ export const createSession =
220
222
}
221
223
222
224
let graphCaptureEnabled = false ;
223
- if ( ! BUILD_DEFS . DISABLE_WEBGPU ) {
225
+ if ( currentEpName === 'webgpu' ) {
224
226
const executionProviders = options ?. executionProviders ;
225
227
for ( const ep of executionProviders ! ) {
226
228
const epName = typeof ep === 'string' ? ep : ep . name ;
@@ -331,70 +333,75 @@ export const releaseSession = (sessionId: number): void => {
331
333
} ;
332
334
333
335
export const prepareInputOutputTensor =
334
- ( tensor : TensorMetadata | null , tensorHandles : number [ ] , allocs : number [ ] , sessionId : number , index : number ) :
335
- void => {
336
- if ( ! tensor ) {
337
- tensorHandles . push ( 0 ) ;
338
- return ;
339
- }
336
+ ( tensor : TensorMetadata | null , tensorHandles : number [ ] , allocs : number [ ] , sessionId : number , index : number ,
337
+ graphCaptureEnabled = false ) : void => {
338
+ if ( ! tensor ) {
339
+ tensorHandles . push ( 0 ) ;
340
+ return ;
341
+ }
340
342
341
- const wasm = getInstance ( ) ;
343
+ const wasm = getInstance ( ) ;
342
344
343
- const dataType = tensor [ 0 ] ;
344
- const dims = tensor [ 1 ] ;
345
- const location = tensor [ 3 ] ;
345
+ const dataType = tensor [ 0 ] ;
346
+ const dims = tensor [ 1 ] ;
347
+ const location = tensor [ 3 ] ;
346
348
347
- let rawData : number ;
348
- let dataByteLength : number ;
349
+ let rawData : number ;
350
+ let dataByteLength : number ;
349
351
350
- if ( dataType === 'string' && location === 'gpu-buffer' ) {
351
- throw new Error ( 'String tensor is not supported on GPU.' ) ;
352
- }
352
+ if ( dataType === 'string' && location === 'gpu-buffer' ) {
353
+ throw new Error ( 'String tensor is not supported on GPU.' ) ;
354
+ }
353
355
354
- if ( location === 'gpu-buffer' ) {
355
- const gpuBuffer = tensor [ 2 ] . gpuBuffer as GPUBuffer ;
356
- const elementSizeInBytes = getTensorElementSize ( tensorDataTypeStringToEnum ( dataType ) ) ! ;
357
- dataByteLength = dims . reduce ( ( a , b ) => a * b , 1 ) * elementSizeInBytes ;
358
- rawData = wasm . jsepRegisterBuffer ( sessionId , index , gpuBuffer , dataByteLength ) ;
359
- } else {
360
- const data = tensor [ 2 ] ;
361
-
362
- if ( Array . isArray ( data ) ) {
363
- // string tensor
364
- dataByteLength = 4 * data . length ;
365
- rawData = wasm . _malloc ( dataByteLength ) ;
366
- allocs . push ( rawData ) ;
367
- let dataIndex = rawData / 4 ;
368
- for ( let i = 0 ; i < data . length ; i ++ ) {
369
- if ( typeof data [ i ] !== 'string' ) {
370
- throw new TypeError ( `tensor data at index ${ i } is not a string` ) ;
371
- }
372
- wasm . HEAPU32 [ dataIndex ++ ] = allocWasmString ( data [ i ] , allocs ) ;
373
- }
374
- } else {
375
- dataByteLength = data . byteLength ;
376
- rawData = wasm . _malloc ( dataByteLength ) ;
377
- allocs . push ( rawData ) ;
378
- wasm . HEAPU8 . set ( new Uint8Array ( data . buffer , data . byteOffset , dataByteLength ) , rawData ) ;
379
- }
380
- }
356
+ if ( graphCaptureEnabled && location !== 'gpu-buffer' ) {
357
+ throw new Error (
358
+ `External buffer must be provided for input/output index ${ index } when graphCaptureEnabled is true.` ) ;
359
+ }
381
360
382
- const stack = wasm . stackSave ( ) ;
383
- const dimsOffset = wasm . stackAlloc ( 4 * dims . length ) ;
384
- try {
385
- let dimIndex = dimsOffset / 4 ;
386
- dims . forEach ( d => wasm . HEAP32 [ dimIndex ++ ] = d ) ;
387
- const tensor = wasm . _OrtCreateTensor (
388
- tensorDataTypeStringToEnum ( dataType ) , rawData , dataByteLength , dimsOffset , dims . length ,
389
- dataLocationStringToEnum ( location ) ) ;
390
- if ( tensor === 0 ) {
391
- checkLastError ( `Can't create tensor for input/output. session=${ sessionId } , index=${ index } .` ) ;
361
+ if ( location === 'gpu-buffer' ) {
362
+ const gpuBuffer = tensor [ 2 ] . gpuBuffer as GPUBuffer ;
363
+ const elementSizeInBytes = getTensorElementSize ( tensorDataTypeStringToEnum ( dataType ) ) ! ;
364
+ dataByteLength = dims . reduce ( ( a , b ) => a * b , 1 ) * elementSizeInBytes ;
365
+ rawData = wasm . jsepRegisterBuffer ( sessionId , index , gpuBuffer , dataByteLength ) ;
366
+ } else {
367
+ const data = tensor [ 2 ] ;
368
+
369
+ if ( Array . isArray ( data ) ) {
370
+ // string tensor
371
+ dataByteLength = 4 * data . length ;
372
+ rawData = wasm . _malloc ( dataByteLength ) ;
373
+ allocs . push ( rawData ) ;
374
+ let dataIndex = rawData / 4 ;
375
+ for ( let i = 0 ; i < data . length ; i ++ ) {
376
+ if ( typeof data [ i ] !== 'string' ) {
377
+ throw new TypeError ( `tensor data at index ${ i } is not a string` ) ;
392
378
}
393
- tensorHandles . push ( tensor ) ;
394
- } finally {
395
- wasm . stackRestore ( stack ) ;
379
+ wasm . HEAPU32 [ dataIndex ++ ] = allocWasmString ( data [ i ] , allocs ) ;
396
380
}
397
- } ;
381
+ } else {
382
+ dataByteLength = data . byteLength ;
383
+ rawData = wasm . _malloc ( dataByteLength ) ;
384
+ allocs . push ( rawData ) ;
385
+ wasm . HEAPU8 . set ( new Uint8Array ( data . buffer , data . byteOffset , dataByteLength ) , rawData ) ;
386
+ }
387
+ }
388
+
389
+ const stack = wasm . stackSave ( ) ;
390
+ const dimsOffset = wasm . stackAlloc ( 4 * dims . length ) ;
391
+ try {
392
+ let dimIndex = dimsOffset / 4 ;
393
+ dims . forEach ( d => wasm . HEAP32 [ dimIndex ++ ] = d ) ;
394
+ const tensor = wasm . _OrtCreateTensor (
395
+ tensorDataTypeStringToEnum ( dataType ) , rawData , dataByteLength , dimsOffset , dims . length ,
396
+ dataLocationStringToEnum ( location ) ) ;
397
+ if ( tensor === 0 ) {
398
+ checkLastError ( `Can't create tensor for input/output. session=${ sessionId } , index=${ index } .` ) ;
399
+ }
400
+ tensorHandles . push ( tensor ) ;
401
+ } finally {
402
+ wasm . stackRestore ( stack ) ;
403
+ }
404
+ } ;
398
405
399
406
/**
400
407
* perform inference run
@@ -431,13 +438,15 @@ export const run = async(
431
438
432
439
// create input tensors
433
440
for ( let i = 0 ; i < inputCount ; i ++ ) {
434
- prepareInputOutputTensor ( inputTensors [ i ] , inputTensorHandles , inputOutputAllocs , sessionId , inputIndices [ i ] ) ;
441
+ prepareInputOutputTensor (
442
+ inputTensors [ i ] , inputTensorHandles , inputOutputAllocs , sessionId , inputIndices [ i ] , graphCaptureEnabled ) ;
435
443
}
436
444
437
445
// create output tensors
438
446
for ( let i = 0 ; i < outputCount ; i ++ ) {
439
447
prepareInputOutputTensor (
440
- outputTensors [ i ] , outputTensorHandles , inputOutputAllocs , sessionId , inputCount + outputIndices [ i ] ) ;
448
+ outputTensors [ i ] , outputTensorHandles , inputOutputAllocs , sessionId , inputCount + outputIndices [ i ] ,
449
+ graphCaptureEnabled ) ;
441
450
}
442
451
443
452
let inputValuesIndex = inputValuesOffset / 4 ;
0 commit comments