diff --git a/js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts b/js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts
index c9f767bdc191a..d7eb843a031db 100644
--- a/js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts
+++ b/js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts
@@ -42,6 +42,11 @@ interface DownloadCacheValue {
   data: Promise;
 }
 
+/**
+ * normalize the buffer size so that it fits the 128-bits (16 bytes) alignment.
+ */
+const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16;
+
 class GpuDataManagerImpl implements GpuDataManager {
   // GPU Data ID => GPU Data ( storage buffer )
   storageCache: Map;
@@ -62,10 +67,10 @@ class GpuDataManagerImpl implements GpuDataManager {
     const srcArrayBuffer = data.buffer;
     const srcOffset = data.byteOffset;
     const srcLength = data.byteLength;
+    const size = calcNormalizedBufferSize(srcLength);
 
     // create gpu buffer
-    const gpuBuffer =
-        this.device.createBuffer({mappedAtCreation: true, size: srcLength, usage: GPUBufferUsage.STORAGE});
+    const gpuBuffer = this.device.createBuffer({mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE});
 
     // copy (upload) data
     const arrayBuffer = gpuBuffer.getMappedRange();
@@ -89,11 +94,12 @@ class GpuDataManagerImpl implements GpuDataManager {
     const elemCount = ShapeUtil.size(dims);
     const bufferLength = sizeof(type) * elemCount;
+    const size = calcNormalizedBufferSize(bufferLength);
 
     // create gpu buffer
     const gpuBuffer =
         // eslint-disable-next-line no-bitwise
-        this.device.createBuffer({size: bufferLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC});
+        this.device.createBuffer({size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC});
 
     const gpuData = {id: Guid.create(), type: GpuDataType.default, buffer: gpuBuffer};
     this.storageCache.set(gpuData.id, {gpuData, size: bufferLength});
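
For reference, a minimal standalone sketch of what the new `calcNormalizedBufferSize` helper computes, outside the context of the diff. The helper itself is copied from the change above; the sample byte sizes are illustrative values chosen here, not taken from the original code.

```ts
// Round a byte size up to the next multiple of 16 (128-bit alignment),
// as required for the WebGPU buffer sizes created in gpu-data-manager.ts.
const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16;

// Illustrative values (not from the original change):
console.log(calcNormalizedBufferSize(100)); // 112 - padded up to the next 16-byte boundary
console.log(calcNormalizedBufferSize(64));  // 64  - already aligned, unchanged
console.log(calcNormalizedBufferSize(1));   // 16  - minimum aligned size for a non-empty buffer
```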