Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions js/web/lib/wasm/jsep/backend-webnn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ export class WebNNBackend {
* Temporary tensors for the current session.
*/
private temporarySessionTensorIds: Map<number, TensorId[]> = new Map();
/**
* Maps from session id to MLOpSupportLimits.
*/
private mlOpSupportLimitsBySessionId = new Map<number, MLOpSupportLimits>();

constructor(env: Env) {
configureLogger(env.logLevel!, !!env.debug);
Expand Down Expand Up @@ -172,6 +176,10 @@ export class WebNNBackend {
}
sessionIds.add(sessionId);

if (!this.mlOpSupportLimitsBySessionId.has(sessionId)) {
this.mlOpSupportLimitsBySessionId.set(sessionId, mlContext.opSupportLimits());
}

if (this.temporaryGraphInputs.length > 0) {
this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs);
this.temporaryGraphInputs = [];
Expand All @@ -192,6 +200,7 @@ export class WebNNBackend {
}
this.tensorManager.releaseTensorsForSession(sessionId);
this.mlContextBySessionId.delete(sessionId);
this.mlOpSupportLimitsBySessionId.delete(sessionId);
const sessionIds = this.sessionIdsByMLContext.get(mlContext)!;
sessionIds.delete(sessionId);
if (sessionIds.size === 0) {
Expand All @@ -207,6 +216,10 @@ export class WebNNBackend {
return this.mlContextBySessionId.get(sessionId);
}

/**
 * Returns the MLOpSupportLimits cached for the given session at
 * registration time, or `undefined` if the session is unknown.
 */
public getMLOpSupportLimits(sessionId: number): MLOpSupportLimits | undefined {
  const limits = this.mlOpSupportLimitsBySessionId.get(sessionId);
  return limits;
}

/**
 * Reserves a fresh tensor id. Delegates to the tensor manager, which
 * owns id allocation and tracker bookkeeping.
 */
public reserveTensorId(): TensorId {
  const reservedId = this.tensorManager.reserveTensorId();
  return reservedId;
}
Expand Down Expand Up @@ -399,17 +412,17 @@ export class WebNNBackend {
}

/**
 * Checks whether the WebNN context bound to `sessionId` supports the given
 * ONNX tensor type as a graph input (default) or output.
 *
 * Returns `false` when the type has no WebNN data-type mapping or when no
 * support limits are cached for the session.
 */
public isGraphInputOutputTypeSupported(sessionId: number, type: Tensor.Type, isInput = true): boolean {
  const dataType = onnxDataTypeToWebnnDataType.get(tensorDataTypeStringToEnum(type));
  if (dataType === undefined) {
    // No ONNX -> WebNN data-type mapping exists for this tensor type.
    return false;
  }

  const opLimits = this.mlOpSupportLimitsBySessionId.get(sessionId);
  // Pick the relevant limits (input vs. output); missing limits -> unsupported.
  const limits = isInput ? opLimits?.input : opLimits?.output;
  return !!limits?.dataTypes.includes(dataType);
}

Expand Down
9 changes: 7 additions & 2 deletions js/web/lib/wasm/jsep/webnn/tensor-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,12 @@ class TensorIdTracker {
copyOld: boolean,
): Promise<MLTensor> {
const context = this.tensorManager.getMLContext(sessionId);
const opLimits = this.tensorManager.getMLOpSupportLimits(sessionId);
let fallbackDataType: MLOperandDataType | undefined;
// Check if the context supports the data type. If not, try to use the fallback data type.
if (!context.opSupportLimits().input.dataTypes.includes(dataType)) {
if (!opLimits?.input.dataTypes.includes(dataType)) {
fallbackDataType = webnnDataTypeToFallback.get(dataType);
if (!fallbackDataType || !context.opSupportLimits().input.dataTypes.includes(fallbackDataType)) {
if (!fallbackDataType || opLimits?.input.dataTypes.includes(fallbackDataType)) {
throw new Error(`WebNN backend does not support data type: ${dataType}`);
}
LOG_DEBUG(
Expand Down Expand Up @@ -460,6 +461,10 @@ class TensorManagerImpl implements TensorManager {
return context;
}

/**
 * Fetches the MLOpSupportLimits for a session from the backend, or
 * `undefined` when the backend has none cached for that session.
 */
public getMLOpSupportLimits(sessionId: number): MLOpSupportLimits | undefined {
  const backendLimits = this.backend.getMLOpSupportLimits(sessionId);
  return backendLimits;
}

public reserveTensorId(): TensorId {
const tensorId = createNewTensorId();
this.tensorTrackersById.set(tensorId, new TensorIdTracker(this));
Expand Down
10 changes: 4 additions & 6 deletions onnxruntime/core/providers/webnn/data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,17 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {

if (dst_device.Type() == OrtDevice::GPU) {
EM_ASM({ Module.webnnUploadTensor($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast<intptr_t>(src_data), bytes);
if (trace) {
console.call<void>("timeEnd", emscripten::val("ORT::DataTransfer::webnnUploadTensor"));
}
} else {
auto webnnDownloadTensor = emscripten::val::module_property("webnnDownloadTensor");
auto subarray = emscripten::typed_memory_view(bytes, static_cast<char*>(dst_data));
webnnDownloadTensor(reinterpret_cast<intptr_t>(src_data), subarray).await();
if (trace) {
console.call<void>("timeEnd", emscripten::val("ORT::DataTransfer::webnnDownloadTensor"));
}
}
}

if (trace) {
console.call<void>("timeEnd", emscripten::val("ORT::DataTransfer::CopyTensor"));
}

return Status::OK();
}

Expand Down
Loading