diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts
index 860424d31ffab..0bb1b09687018 100644
--- a/js/web/lib/wasm/jsep/backend-webnn.ts
+++ b/js/web/lib/wasm/jsep/backend-webnn.ts
@@ -97,6 +97,10 @@ export class WebNNBackend {
    * Temporary tensors for the current session.
    */
   private temporarySessionTensorIds: Map<number, TensorId[]> = new Map();
+  /**
+   * Maps from session id to MLOpSupportLimits.
+   */
+  private mlOpSupportLimitsBySessionId = new Map<number, MLOpSupportLimits>();
 
   constructor(env: Env) {
     configureLogger(env.logLevel!, !!env.debug);
@@ -172,6 +176,10 @@ export class WebNNBackend {
     }
     sessionIds.add(sessionId);
 
+    if (!this.mlOpSupportLimitsBySessionId.has(sessionId)) {
+      this.mlOpSupportLimitsBySessionId.set(sessionId, mlContext.opSupportLimits());
+    }
+
     if (this.temporaryGraphInputs.length > 0) {
       this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs);
       this.temporaryGraphInputs = [];
@@ -192,6 +200,7 @@ export class WebNNBackend {
     }
     this.tensorManager.releaseTensorsForSession(sessionId);
     this.mlContextBySessionId.delete(sessionId);
+    this.mlOpSupportLimitsBySessionId.delete(sessionId);
     const sessionIds = this.sessionIdsByMLContext.get(mlContext)!;
     sessionIds.delete(sessionId);
     if (sessionIds.size === 0) {
@@ -207,6 +216,10 @@ export class WebNNBackend {
     return this.mlContextBySessionId.get(sessionId);
   }
 
+  public getMLOpSupportLimits(sessionId: number): MLOpSupportLimits | undefined {
+    return this.mlOpSupportLimitsBySessionId.get(sessionId);
+  }
+
   public reserveTensorId(): TensorId {
     return this.tensorManager.reserveTensorId();
   }
@@ -399,17 +412,17 @@ export class WebNNBackend {
   }
 
   public isGraphInputOutputTypeSupported(sessionId: number, type: Tensor.Type, isInput = true): boolean {
-    const context = this.mlContextBySessionId.get(sessionId);
     const dataType = onnxDataTypeToWebnnDataType.get(tensorDataTypeStringToEnum(type));
+    const opLimits = this.mlOpSupportLimitsBySessionId.get(sessionId);
 
     if (typeof dataType === 'undefined') {
       return false;
     }
 
     if (isInput) {
-      return !!context?.opSupportLimits().input.dataTypes.includes(dataType);
+      return !!opLimits?.input.dataTypes.includes(dataType);
     } else {
-      return !!context?.opSupportLimits().output.dataTypes.includes(dataType);
+      return !!opLimits?.output.dataTypes.includes(dataType);
     }
   }
 
diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
index 784e5ab53c1fa..70a2a9a892470 100644
--- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
+++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
@@ -336,11 +336,12 @@ class TensorIdTracker {
     copyOld: boolean,
   ): Promise<MLTensor> {
     const context = this.tensorManager.getMLContext(sessionId);
+    const opLimits = this.tensorManager.getMLOpSupportLimits(sessionId);
     let fallbackDataType: MLOperandDataType | undefined;
     // Check if the context supports the data type. If not, try to use the fallback data type.
-    if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
+    if (!opLimits?.input.dataTypes.includes(dataType)) {
       fallbackDataType = webnnDataTypeToFallback.get(dataType);
-      if (!fallbackDataType || !context.opSupportLimits().input.dataTypes.includes(fallbackDataType)) {
+      if (!fallbackDataType || !opLimits?.input.dataTypes.includes(fallbackDataType)) {
         throw new Error(`WebNN backend does not support data type: ${dataType}`);
       }
       LOG_DEBUG(
@@ -460,6 +461,10 @@ class TensorManagerImpl implements TensorManager {
     return context;
   }
 
+  public getMLOpSupportLimits(sessionId: number): MLOpSupportLimits | undefined {
+    return this.backend.getMLOpSupportLimits(sessionId);
+  }
+
   public reserveTensorId(): TensorId {
     const tensorId = createNewTensorId();
     this.tensorTrackersById.set(tensorId, new TensorIdTracker(this));
diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc
index 17369e6fbc75d..6b5bed4ecdeae 100644
--- a/onnxruntime/core/providers/webnn/data_transfer.cc
+++ b/onnxruntime/core/providers/webnn/data_transfer.cc
@@ -35,19 +35,17 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
     if (dst_device.Type() == OrtDevice::GPU) {
       EM_ASM({ Module.webnnUploadTensor($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast<intptr_t>(src_data), bytes);
-      if (trace) {
-        console.call<void>("timeEnd", emscripten::val("ORT::DataTransfer::webnnUploadTensor"));
-      }
     } else {
       auto webnnDownloadTensor = emscripten::val::module_property("webnnDownloadTensor");
       auto subarray = emscripten::typed_memory_view(bytes, static_cast<char*>(dst_data));
       webnnDownloadTensor(reinterpret_cast<intptr_t>(src_data), subarray).await();
-      if (trace) {
-        console.call<void>("timeEnd", emscripten::val("ORT::DataTransfer::webnnDownloadTensor"));
-      }
     }
   }
 
+  if (trace) {
+    console.call<void>("timeEnd", emscripten::val("ORT::DataTransfer::CopyTensor"));
+  }
+
   return Status::OK();
 }
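
Aside (illustration only, not part of the patch): the TypeScript changes above all serve one pattern — call `MLContext.opSupportLimits()` once when a session is registered, cache the result keyed by session id, drop the entry on session release, and make hot-path type-support checks plain map lookups. A minimal self-contained sketch of that pattern, using hypothetical stand-in types (`OpSupportLimits`, `Context`, `SupportLimitsCache`) rather than the real WebNN interfaces:

    // Hypothetical, trimmed-down stand-ins for the WebNN types used in the diff.
    interface OpSupportLimits {
      input: { dataTypes: string[] };
      output: { dataTypes: string[] };
    }
    interface Context {
      opSupportLimits(): OpSupportLimits;
    }

    class SupportLimitsCache {
      private limitsBySessionId = new Map<number, OpSupportLimits>();

      // Cache the (potentially expensive) opSupportLimits() result once per session.
      register(sessionId: number, context: Context): void {
        if (!this.limitsBySessionId.has(sessionId)) {
          this.limitsBySessionId.set(sessionId, context.opSupportLimits());
        }
      }

      // Analogous to isGraphInputOutputTypeSupported: read from the cache, not the context.
      isTypeSupported(sessionId: number, dataType: string, isInput = true): boolean {
        const limits = this.limitsBySessionId.get(sessionId);
        return !!(isInput ? limits?.input : limits?.output)?.dataTypes.includes(dataType);
      }

      // Must be called on session release so the cache does not leak.
      release(sessionId: number): void {
        this.limitsBySessionId.delete(sessionId);
      }
    }

The `release()` call mirrors the `onReleaseSession` cleanup in the diff; without it the cache would retain limits objects for sessions that no longer exist.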