Skip to content

Commit

Permalink
[js/webgpu] Enable GroupedConvVectorize path (#19791)
Browse files Browse the repository at this point in the history
Vectorize met 2 failed cases in a CI bot with NVIDIA GPU, but we
couldn't repro with all the GPUs at hand, including NVIDIA GPUs. This PR
introduces GPUAdapterInfo and enables this opt on non-NVIDIA GPUs to
make the bots happy.
No obivous perf gain can be seen if we enable vectorize on NVIDIA.
However, it shows big perf improvement on Intel. On my Gen12 Intel GPU,
mobilenetv2-12 perf was improved from 11.14ms to 7.1ms.
  • Loading branch information
gyagp authored Mar 13, 2024
1 parent 4538d31 commit 53de2d8
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 5 deletions.
24 changes: 23 additions & 1 deletion js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {createView, TensorView} from './tensor-view';
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
import {ProgramManager} from './webgpu/program-manager';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';
import {AdapterInfo, ComputeContext, GpuArchitecture, GpuData, GpuVendor, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';

interface CommandInfo {
readonly kernelId: number;
Expand Down Expand Up @@ -94,11 +94,32 @@ const getProgramInfoUniqueKey =
return key;
};

class AdapterInfoImpl implements AdapterInfo {
readonly architecture?: string;
readonly vendor?: string;

constructor(adapterInfo: GPUAdapterInfo) {
if (adapterInfo) {
this.architecture = adapterInfo.architecture;
this.vendor = adapterInfo.vendor;
}
}

isArchitecture(architecture: GpuArchitecture): boolean {
return this.architecture === architecture;
}

isVendor(vendor: GpuVendor): boolean {
return this.vendor === vendor;
}
}

/**
* this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as
* the first parameter so that it is stored for future use.
*/
export class WebGpuBackend {
adapterInfo: AdapterInfoImpl;
device: GPUDevice;
/**
* an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping
Expand Down Expand Up @@ -212,6 +233,7 @@ export class WebGpuBackend {
}

this.device = await adapter.requestDevice(deviceDescriptor);
this.adapterInfo = new AdapterInfoImpl(await adapter.requestAdapterInfo());
this.gpuDataManager = createGpuDataManager(this);
this.programManager = new ProgramManager(this);
this.kernels = new Map();
Expand Down
4 changes: 3 additions & 1 deletion js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {WebGpuBackend} from './backend-webgpu';
import {LOG_DEBUG} from './log';
import {TensorView} from './tensor-view';
import {ShapeUtil} from './util';
import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types';
import {AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types';

/* eslint-disable no-bitwise */

Expand Down Expand Up @@ -54,6 +54,7 @@ class TensorViewImpl implements TensorView {
}

class ComputeContextImpl implements ComputeContext {
readonly adapterInfo: AdapterInfo;
readonly opKernelContext: number;
readonly inputs: readonly TensorView[];
readonly outputCount: number;
Expand All @@ -66,6 +67,7 @@ class ComputeContextImpl implements ComputeContext {
private customDataOffset = 0;
private customDataSize = 0;
constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) {
this.adapterInfo = backend.adapterInfo;
const heapU32 = module.HEAPU32;

// extract context data
Expand Down
7 changes: 4 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
// const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
const isChannelsLast = attributes.format === 'NHWC';
if (attributes.group !== 1) {
// Temporarily disable createGroupedConvVectorizeProgramInfo path due to bots failures with below two cases:
// NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other
// GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs.
// [webgpu]Conv - conv - vectorize group - B
// [webgpu]Conv - conv - vectorize group - D
const disableGroupedConvVectorize = true;
if (!disableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group &&
const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere');
if (enableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group &&
inputs[1].dims[1] === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1) {
const outputShape = calculateOutputShape(
inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides,
Expand Down
12 changes: 12 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ export enum GpuDataType {
}
export type GpuDataId = number;

export type GpuArchitecture = 'ampere';
export type GpuVendor = 'amd'|'intel'|'nvidia';
export interface AdapterInfo {
isArchitecture: (architecture: GpuArchitecture) => boolean;
isVendor: (vendor: GpuVendor) => boolean;
}

export interface GpuData {
type: GpuDataType;
id: GpuDataId;
Expand Down Expand Up @@ -146,6 +153,11 @@ export interface ComputeContextInputsOutputsMapping {
* A ComputeContext instance carries the states that representing the current running of a kernel.
*/
export interface ComputeContext {
/**
* gpu adapter info
*/
readonly adapterInfo: AdapterInfo;

/**
* stores the pointer to OpKernelContext
*/
Expand Down

0 comments on commit 53de2d8

Please sign in to comment.