[js/webgpu] support proxy for webgpu (microsoft#15851)
### Description
Support the wasm proxy worker for WebGPU builds. The environment is now a plain object with enumerable accessors so it can be transferred to the worker via `postMessage()`, ORT/JSEP initialization is consolidated into `initRuntime(env)`, and the JSEP logger receives its flags explicitly instead of reading the global `env`. Fixes microsoft#15832.
fs-eire authored May 15, 2023
1 parent 7b5ecf6 commit 5b43e6b
Showing 12 changed files with 84 additions and 66 deletions.
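For context, here is a minimal usage sketch of the scenario this change targets: the wasm proxy worker enabled alongside the WebGPU execution provider. This is not part of the commit; it assumes a WebGPU-enabled onnxruntime-web bundle, and the model path and input name are placeholders.

```ts
import * as ort from 'onnxruntime-web';  // assumption: a WebGPU-enabled onnxruntime-web bundle

// Offload wasm execution to a proxy worker; the combination addressed by microsoft#15832.
ort.env.wasm.proxy = true;
// These flags now reach the worker because `env` is a plain, cloneable object.
ort.env.logLevel = 'verbose';
ort.env.debug = true;

async function main() {
  const session = await ort.InferenceSession.create('./model.onnx', {
    executionProviders: ['webgpu'],
  });
  // 'input' and the tensor shape are placeholders for a real model signature.
  const feeds = {input: new ort.Tensor('float32', new Float32Array(3 * 224 * 224), [1, 3, 224, 224])};
  const results = await session.run(feeds);
  console.log(Object.keys(results));
}

main();
```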
38 changes: 15 additions & 23 deletions common/lib/env-impl.ts
@@ -4,35 +4,27 @@
import {Env} from './env';

type LogLevelType = Env['logLevel'];
export class EnvImpl implements Env {
constructor() {
this.wasm = {};
this.webgl = {};
this.webgpu = {};
this.logLevelInternal = 'warning';
}

// TODO standadize the getter and setter convention in env for other fields.
let logLevelValue: Required<LogLevelType> = 'warning';

export const env: Env = {
wasm: {} as Env.WebAssemblyFlags,
webgl: {} as Env.WebGLFlags,
webgpu: {} as Env.WebGpuFlags,

set logLevel(value: LogLevelType) {
if (value === undefined) {
return;
}
if (typeof value !== 'string' || ['verbose', 'info', 'warning', 'error', 'fatal'].indexOf(value) === -1) {
throw new Error(`Unsupported logging level: ${value}`);
}
this.logLevelInternal = value;
}
get logLevel(): LogLevelType {
return this.logLevelInternal;
}

debug?: boolean;

wasm: Env.WebAssemblyFlags;
webgl: Env.WebGLFlags;
webgpu: Env.WebGpuFlags;

[name: string]: unknown;
logLevelValue = value;
},
get logLevel(): Required<LogLevelType> {
return logLevelValue;
},
};

private logLevelInternal: Required<LogLevelType>;
}
// set property 'logLevel' so that they can be correctly transferred to worker by `postMessage()`.
Object.defineProperty(env, 'logLevel', {enumerable: true});
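The move from a class to an object literal is what makes the `postMessage()` transfer mentioned above work: structured clone copies own enumerable properties (reading accessor values), while a class getter lives on the prototype and is silently dropped. A small illustrative sketch of that difference, not part of the commit:

```ts
// Hypothetical comparison, assuming a structuredClone-capable environment
// (postMessage uses the same structured clone algorithm).
class ClassEnv {
  get logLevel() { return 'warning'; }  // defined on ClassEnv.prototype, not on the instance
}
const objectEnv = {
  get logLevel() { return 'warning'; },  // own, enumerable accessor on the object itself
};

console.log(structuredClone(new ClassEnv()));  // {}                       -- logLevel is lost
console.log(structuredClone(objectEnv));       // { logLevel: 'warning' }  -- value survives the clone
```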
4 changes: 2 additions & 2 deletions common/lib/env.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {EnvImpl} from './env-impl';
import {env as envImpl} from './env-impl';

export declare namespace Env {
export type WasmPrefixOrFilePaths = string|{
@@ -127,4 +127,4 @@ export interface Env {
/**
* Represent a set of flags as a global singleton.
*/
export const env: Env = new EnvImpl();
export const env: Env = envImpl;
4 changes: 2 additions & 2 deletions web/lib/backend-wasm.ts
@@ -4,7 +4,7 @@
import {Backend, env, InferenceSession, SessionHandler} from 'onnxruntime-common';
import {cpus} from 'os';

import {initWasm} from './wasm/proxy-wrapper';
import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper';
import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler';

/**
@@ -38,7 +38,7 @@ class OnnxruntimeWebAssemblyBackend implements Backend {
initializeFlags();

// init wasm
await initWasm();
await initializeWebAssemblyInstance();
}
createSessionHandler(path: string, options?: InferenceSession.SessionOptions): Promise<SessionHandler>;
createSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): Promise<SessionHandler>;
10 changes: 7 additions & 3 deletions web/lib/wasm/jsep/backend-webgpu.ts
@@ -1,9 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env} from 'onnxruntime-common';
import {Env} from 'onnxruntime-common';

import {LOG_DEBUG} from './log';
import {configureLogger, LOG_DEBUG} from './log';
import {TensorView} from './tensor';
import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
@@ -94,7 +94,7 @@ export class WebGpuBackend {
profilingQuerySet: GPUQuerySet;
profilingTimeBase?: bigint;

async initialize(): Promise<void> {
async initialize(env: Env): Promise<void> {
if (!navigator.gpu) {
// WebGPU is not available.
throw new Error('WebGpuBackend: WebGPU is not available.');
@@ -126,6 +126,10 @@ export class WebGpuBackend {
this.kernels = new Map();
this.kernelPersistentData = new Map();
this.kernelCustomData = new Map();

// set up flags for logger
configureLogger(env.logLevel!, !!env.debug);

// TODO: set up flags

this.device.onuncapturederror = ev => {
6 changes: 4 additions & 2 deletions web/lib/wasm/jsep/init.ts
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Env} from 'onnxruntime-common';

import {OrtWasmModule} from '../binding/ort-wasm';
import {getTensorElementSize} from '../wasm-common';

@@ -93,11 +95,11 @@ class ComputeContextImpl implements ComputeContext {
}
}

export const init = async(module: OrtWasmModule): Promise<void> => {
export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
const init = module.jsepInit;
if (init && navigator.gpu) {
const backend = new WebGpuBackend();
await backend.initialize();
await backend.initialize(env);

init(
// backend
16 changes: 12 additions & 4 deletions web/lib/wasm/jsep/log.ts
@@ -1,11 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env} from 'onnxruntime-common';
import {Env} from 'onnxruntime-common';

import {logLevelStringToEnum} from '../wasm-common';

type LogLevel = NonNullable<typeof env.logLevel>;
type LogLevel = NonNullable<Env['logLevel']>;
type MessageString = string;
type MessageFunction = () => string;
type Message = MessageString|MessageFunction;
@@ -17,12 +17,20 @@ const doLog = (level: number, message: string): void => {
console.log(`[${logLevelPrefix[level]},${new Date().toISOString()}]${message}`);
};

let configLogLevel: LogLevel|undefined;
let debug: boolean|undefined;

export const configureLogger = ($configLogLevel: LogLevel, $debug: boolean): void => {
configLogLevel = $configLogLevel;
debug = $debug;
};

/**
* A simple logging utility to log messages to the console.
*/
export const LOG = (logLevel: LogLevel, msg: Message): void => {
const messageLevel = logLevelStringToEnum(logLevel);
const configLevel = logLevelStringToEnum(env.logLevel!);
const configLevel = logLevelStringToEnum(configLogLevel);
if (messageLevel >= configLevel) {
doLog(messageLevel, typeof msg === 'function' ? msg() : msg);
}
@@ -32,7 +40,7 @@ export const LOG = (logLevel: LogLevel, msg: Message): void => {
* A simple logging utility to log messages to the console. Only logs when debug is enabled.
*/
export const LOG_DEBUG: typeof LOG = (...args: Parameters<typeof LOG>) => {
if (env.debug) {
if (debug) {
LOG(...args);
}
};
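Because a worker gets its own copy of the module-level `env` singleton, the logger can no longer read it directly; the WebGPU backend now pushes the relevant flags in once via `configureLogger`. A brief sketch of the resulting usage, built only from functions that appear in this file (the surrounding wiring is an assumption):

```ts
import {Env} from 'onnxruntime-common';
import {configureLogger, LOG, LOG_DEBUG} from './log';

// Mirrors what WebGpuBackend.initialize(env) does above: capture the caller's flags once.
const setUpJsepLogging = (env: Env): void => {
  configureLogger(env.logLevel!, !!env.debug);

  LOG('warning', 'printed at the default level');                    // 'warning' >= configured level
  LOG('info', () => 'printed only when logLevel is verbose/info');   // message built lazily
  LOG_DEBUG('verbose', () => 'printed only when env.debug is true and logLevel is verbose');
};
```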
2 changes: 1 addition & 1 deletion web/lib/wasm/proxy-messages.ts
@@ -29,7 +29,7 @@ interface MessageInitWasm extends MessageError {

interface MessageInitOrt extends MessageError {
type: 'init-ort';
in ?: {numThreads: number; loggingLevel: number};
in ?: Env;
}

interface MessageCreateSessionAllocate extends MessageError {
20 changes: 13 additions & 7 deletions web/lib/wasm/proxy-worker/main.ts
@@ -4,21 +4,27 @@
/// <reference lib="webworker" />

import {OrtWasmMessage} from '../proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initOrt, releaseSession, run} from '../wasm-core-impl';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl';
import {initializeWebAssembly} from '../wasm-factory';

self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
switch (ev.data.type) {
case 'init-wasm':
initializeWebAssembly(ev.data.in)
.then(
() => postMessage({type: 'init-wasm'} as OrtWasmMessage),
err => postMessage({type: 'init-wasm', err} as OrtWasmMessage));
try {
initializeWebAssembly(ev.data.in)
.then(
() => postMessage({type: 'init-wasm'} as OrtWasmMessage),
err => postMessage({type: 'init-wasm', err} as OrtWasmMessage));
} catch (err) {
postMessage({type: 'init-wasm', err} as OrtWasmMessage);
}
break;
case 'init-ort':
try {
const {numThreads, loggingLevel} = ev.data.in!;
initOrt(numThreads, loggingLevel);
initRuntime(ev.data.in).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({
type: 'init-ort',
err
} as OrtWasmMessage));
postMessage({type: 'init-ort'} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'init-ort', err} as OrtWasmMessage);
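The try/catch added around both initializers ensures a synchronous throw is reported back through the same message channel as an async rejection, instead of escaping `onmessage` as an unhandled worker error. A hypothetical helper (not in the commit) capturing that reply-exactly-once pattern:

```ts
import {OrtWasmMessage} from '../proxy-messages';

// Hypothetical sketch: run a possibly-throwing async initializer and always post one reply.
const runAndReply = (type: 'init-wasm'|'init-ort', action: () => Promise<void>): void => {
  try {
    action().then(
        () => postMessage({type} as OrtWasmMessage),
        err => postMessage({type, err} as OrtWasmMessage));
  } catch (err) {
    // e.g. the initializer throwing before it can return a promise
    postMessage({type, err} as OrtWasmMessage);
  }
};
```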
18 changes: 6 additions & 12 deletions web/lib/wasm/proxy-wrapper.ts
@@ -1,12 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env, InferenceSession} from 'onnxruntime-common';
import {Env, env, InferenceSession} from 'onnxruntime-common';

import {init as initJsep} from './jsep/init';
import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages';
import * as core from './wasm-core-impl';
import {getInstance, initializeWebAssembly} from './wasm-factory';
import {initializeWebAssembly} from './wasm-factory';

const isProxy = (): boolean => !!env.wasm.proxy && typeof document !== 'undefined';
let proxyWorker: Worker|undefined;
@@ -99,7 +98,7 @@ const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {

const scriptSrc = typeof document !== 'undefined' ? (document?.currentScript as HTMLScriptElement)?.src : undefined;

export const initWasm = async(): Promise<void> => {
export const initializeWebAssemblyInstance = async(): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
if (initialized) {
return;
@@ -135,21 +134,16 @@ export const initWasm = async(): Promise<void> => {
}
};

export const initOrt = async(numThreads: number, loggingLevel: number): Promise<void> => {
export const initializeRuntime = async(env: Env): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
ensureWorker();
return new Promise<void>((resolve, reject) => {
initOrtCallbacks = [resolve, reject];
const message: OrtWasmMessage = {type: 'init-ort', in : {numThreads, loggingLevel}};
const message: OrtWasmMessage = {type: 'init-ort', in : env};
proxyWorker!.postMessage(message);

// TODO: support JSEP in worker
});
} else {
core.initOrt(numThreads, loggingLevel);

// init JSEP if available
await initJsep(getInstance());
await core.initRuntime(env);
}
};

11 changes: 5 additions & 6 deletions web/lib/wasm/session-handler.ts
@@ -6,10 +6,9 @@ import {env, InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'
import {promisify} from 'util';

import {SerializableModeldata} from './proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initOrt, releaseSession, run} from './proxy-wrapper';
import {logLevelStringToEnum} from './wasm-common';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper';

let ortInit: boolean;
let runtimeInitialized: boolean;

export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
private sessionId: number;
@@ -26,9 +25,9 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
}

async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
if (!ortInit) {
await initOrt(env.wasm.numThreads!, logLevelStringToEnum(env.logLevel!));
ortInit = true;
if (!runtimeInitialized) {
await initializeRuntime(env);
runtimeInitialized = true;
}

if (typeof pathOrBuffer === 'string') {
2 changes: 1 addition & 1 deletion web/lib/wasm/wasm-common.ts
@@ -140,7 +140,7 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr
/**
* Map string log level to integer value
*/
export const logLevelStringToEnum = (logLevel: 'verbose'|'info'|'warning'|'error'|'fatal'): number => {
export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal'): number => {
switch (logLevel) {
case 'verbose':
return 0;
19 changes: 16 additions & 3 deletions web/lib/wasm/wasm-core-impl.ts
@@ -1,27 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {InferenceSession, Tensor} from 'onnxruntime-common';
import {Env, InferenceSession, Tensor} from 'onnxruntime-common';

import {init as initJsep} from './jsep/init';
import {SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
import {allocWasmString} from './string-utils';
import {tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {getInstance} from './wasm-factory';

/**
* initialize ORT environment.
* @param numThreads SetGlobalIntraOpNumThreads(numThreads)
* @param loggingLevel CreateEnv(static_cast<OrtLoggingLevel>(logging_level))
*/
export const initOrt = (numThreads: number, loggingLevel: number): void => {
const initOrt = async(numThreads: number, loggingLevel: number): Promise<void> => {
const errorCode = getInstance()._OrtInit(numThreads, loggingLevel);
if (errorCode !== 0) {
throw new Error(`Can't initialize onnxruntime. error code = ${errorCode}`);
}
};

/**
* intialize runtime environment.
* @param env passed in the environment config object.
*/
export const initRuntime = async(env: Env): Promise<void> => {
// init ORT
await initOrt(env.wasm.numThreads!, logLevelStringToEnum(env.logLevel));

// init JSEP if available
await initJsep(getInstance(), env);
};

/**
* tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded
*/
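Taken together, a single `Env` object now drives initialization on both paths. A condensed, hypothetical sketch of the dispatch that `proxy-wrapper.ts` performs after this commit (simplified; the real worker bookkeeping is more involved):

```ts
import {Env} from 'onnxruntime-common';
import {initRuntime} from './wasm-core-impl';

export const sketchInitializeRuntime = async (env: Env, proxyWorker?: Worker): Promise<void> => {
  if (proxyWorker) {
    // The whole env is structured-cloned into the worker as the 'init-ort' payload;
    // numThreads, logLevel and debug all arrive intact because env is a plain enumerable object.
    await new Promise<void>((resolve, reject) => {
      proxyWorker.onmessage = ev => ev.data.err ? reject(ev.data.err) : resolve();
      proxyWorker.postMessage({type: 'init-ort', in: env});
    });
  } else {
    // Same object, same entry point, without the worker hop.
    await initRuntime(env);
  }
};
```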
