From 14cc02c65ccd946b2f231756e86afa76095bb1cd Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 24 Apr 2023 15:21:18 -0700 Subject: [PATCH] [js/web] WebGPU backend via JSEP (#14579) ### Description This change introduced the following new components into ONNX Runtime Web: - JavaScript Execution Provider (JSEP) - Asynchronized inferencing execution powered by Emscripten's Asyncify - WebGPU backend implemented in TypeScript - initial implementation of kernels: - elementwise operators (22) - binary operators (5) - tensor: Shape, Reshape, Transpose, Gemm - nn: Conv, {Global}Maxpool, {Global}AveragePool Code need to be polished. still working on it. ## Q&A What is JSEP? > JSEP, aka JavaScript Execution Provider, is a new ONNXRuntime execution provider that specifically works on Web environment (browsers). JSEP allows JavaScript code to kick in from various places when ONNX Runtime inferences a model. Why JSEP? > JSEP is a hybrid mode EP that contains both C/C++ and TypeScript/JavaScript implementation. There are 2 strong reasons why we introduces JSEP: > 1. the C/C++ part helps JSEP to leverage ONNX Runtime's capabilities as much as possible including graph transformer, optimizers and also the capabilities to fallback to CPU EP. TypeScript/JavaScript helps JSEP to develop and debug much easier in the browser for the kernel implementation. > 2. the requirement of asynchronized execution from JavaScript API (eg. `buffer.mapAsync()`) makes it impossible to run `OrtRun()` in a synchronized context (see "async problem" section below). This is done by using Emscripten's Asyncify. What is WebGPU? > WebGPU is the new GPU API that available in browser. It's one of the only 2 APIs that currently available to access the GPU from browser (the other is WebGL). > WebGPU is designed with more advanced and stronger features comparing to WebGL and is potentially solution that offer the best GPU performance for model inferencing that currently available. What is the async problem and why we have the problem? > The "async problem" is a problem that you cannot call an async function in a synchronous context. Think about the following C++ code: > ```c > // C-style declarations (API) > typedef void (*ON_COMPLETE)(PVOID state, DATA *data); > void read_data_from_file(FILEHANDLE file, ON_COMPLETE on_complete); > > // implementation > DATA * my_impl_read_data_from_file_sync(FILEHANDLE file) { > // how to implement? > } > ``` > The answer is, it's impossible to implement this function. Usually we try to find a sync version API, or launch a thread to call the async function and sync-wait on the main thread. Unfortunately, in browser environment, neither is possible. > > WebGPU does not offer any synchronized API for data downloading (GPU to CPU). This is the only operation that MUST be async. As `OrtRun()` will eventually call into DataTransfer for copy data from GPU to CPU, and `OrtRun()` is a synchronized function, this cannot be done in normal way. What is Emscripten? How is the Asyncify feature resolved the problem? > Emscripten is the C/C++ compiler for WebAssembly. It's what we use to compile ORT and generates the WebAssembly artifacts which runs on browsers. > > Asyncify is a [compiler feature](https://emscripten.org/docs/porting/asyncify.html) that allows calling async functions from a synchronized context. In short, it generates code to unwind and rewind call stack to emulate async execution. With this feature, we are able to call the async function inside `OrtRun()` call. ## Design Overview **Inter-op** JSEP is doing pretty much same thing to just another EP. It exposes an interface for inter-op with JavaScript, which is defined in onnxruntime/wasm/js_internal_api.js: ```js // init JSEP Module["jsepInit"] = function (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, run) { Module.jsepBackend = backend; Module.jsepAlloc = alloc; Module.jsepFree = free; Module.jsepCopy = copy; Module.jsepCopyAsync = copyAsync; Module.jsepCreateKernel = createKernel; Module.jsepReleaseKernel = releaseKernel; Module.jsepRun = run; }; ``` This simple JavaScript snippet defines all language barrier level functions that requires by JSEP to achieve implementing kernels and data transfers using JavaScript inside ONNX Runtime: - `jsepBackend`: assign the singleton object to webassembly module - `jsepAlloc` and `jsepFree`: implementation of data transfer's Alloc() and Free() - `jsepCopy`: synchronized copy ( GPU to GPU, CPU to GPU) - `jsepCopyAsync`: asynchronized copy ( GPU to CPU) - `jsepCreateKernel` and `jsepReleaseKernel`: a corresponding object that maintained in JS to match lifecycle of Kernel in ORT - `jsepRun`: OpKernel::Compute() should call into this The abstraction above allows to tie as little as possible connections and dependencies between C/C++ and TypeScript/JavaScript. **Resource Management** Lifecycle of tensor data and kernels are managed by ORT(C/C++) but the implementation are left to JavaScript. JavaScript code are responsible to implement the callbacks correctly. For WebGPU, the GPU data is managed by JavaScript using a singleton map (tensot_data_id => GPUBuffer). GPU pipeline is managed as singleton. Shaders are managed using a singletonmap (shader_key => gpu_program), while shader_key is generated by cache_key (OP specific, including attributes) and input shapes. **about data transfer** `js::DataTransfer::CopyTensor` implemented to call either synchronized or asynchronized copy callback, depending on the destination is GPU or not. Emscripten's macro `EM_ASYNC_JS` is used to wrap the async function to be called in the synchronized context. **run kernel in JS** Kernel class constructor calls once `jsepCreateKernel()` with an optional per-kernel specific serialization to pass attributes into JavaScript. `Compute()` are implemented in a way that a metadata serialization is performed in a base class and JavaScript code can access the data using the Emscripten specific builtin macro `EM_ASM_*`. **disabled features** memory pattern is force disabled, because the WebGPU data is not presented by a general memory model (a buffer can be represented by offset + size). concurrent run support is disabled. WebGPU is stateful and it also has async function call. To support concurrent run will significantly increase the complexity and we don't get any real benefit from it. **prefer channels last** JSEP prefers channels last and returns `DataLayout::NHWC` in method `GetPreferredLayout()`. This will let the graph transformers to preprocess the graph into a channels last form so that a more optimized WebGPU shader can be used. **Testing code** It's impossible to test JSEP directly because JSEP itself does not contain any kernel implementation. However, it has the kernel registration which need to work together with the corresponding JavaScript code. There are unit tests that run onnx models from JavaScript API. --------- Co-authored-by: Scott McKay --- ThirdPartyNotices.txt | 221 +++- cmake/CMakeLists.txt | 6 + cmake/onnxruntime_providers.cmake | 21 + cmake/onnxruntime_unittests.cmake | 15 + cmake/onnxruntime_webassembly.cmake | 76 +- include/onnxruntime/core/graph/constants.h | 1 + js/.eslintrc.js | 6 + js/common/lib/env-impl.ts | 3 +- js/common/lib/env.ts | 9 + js/web/karma.conf.js | 36 +- js/web/lib/build-def.d.ts | 4 + js/web/lib/index.ts | 4 + js/web/lib/onnxjs/backend.ts | 2 +- .../lib/onnxjs/backends/webgl/ops/reduce.ts | 3 +- js/web/lib/onnxjs/opset.ts | 2 - js/web/lib/wasm/binding/ort-wasm.d.ts | 22 + js/web/lib/wasm/jsep/backend-webgpu.ts | 345 ++++++ js/web/lib/wasm/jsep/init.ts | 149 +++ js/web/lib/wasm/jsep/log.ts | 38 + js/web/lib/wasm/jsep/tensor.ts | 105 ++ js/web/lib/wasm/jsep/util.ts | 774 ++++++++++++ .../jsep/webgpu/attribute-with-cache-key.ts | 27 + .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 231 ++++ .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 56 + .../webgpu/ops/3rd-party/activation_util.ts | 52 + .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 250 ++++ .../jsep/webgpu/ops/3rd-party/conv_util.ts | 31 + .../ops/3rd-party/matmul_packed_webgpu.ts | 327 +++++ js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 213 ++++ js/web/lib/wasm/jsep/webgpu/ops/common.ts | 137 +++ js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 158 +++ .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 130 ++ js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 252 ++++ js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts | 28 + js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 41 + js/web/lib/wasm/jsep/webgpu/ops/gemm.ts | 146 +++ js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 100 ++ js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 379 ++++++ js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 98 ++ js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 220 ++++ .../lib/wasm/jsep/webgpu/program-manager.ts | 143 +++ js/web/lib/wasm/jsep/webgpu/types.ts | 140 +++ js/web/lib/wasm/proxy-worker/main.ts | 10 +- js/web/lib/wasm/proxy-wrapper.ts | 8 +- js/web/lib/wasm/session-handler.ts | 22 +- js/web/lib/wasm/session-options.ts | 9 + js/web/lib/wasm/wasm-common.ts | 158 +++ js/web/lib/wasm/wasm-core-impl.ts | 420 +++---- js/web/package-lock.json | 13 + js/web/package.json | 1 + js/web/script/test-runner-cli-args.ts | 30 +- js/web/script/test-runner-cli.ts | 47 +- js/web/test/suite-test-list.jsonc | 1093 ++++++++++++++++- js/web/test/test-main.ts | 3 + js/web/test/test-runner.ts | 67 +- js/web/tsconfig.json | 1 + js/web/webpack.config.js | 1 + .../core/providers/get_execution_providers.cc | 8 + onnxruntime/core/providers/js/allocator.cc | 29 + onnxruntime/core/providers/js/allocator.h | 39 + .../core/providers/js/data_transfer.cc | 46 + onnxruntime/core/providers/js/data_transfer.h | 23 + .../providers/js/js_execution_provider.cc | 339 +++++ .../core/providers/js/js_execution_provider.h | 52 + onnxruntime/core/providers/js/js_export.cc | 26 + onnxruntime/core/providers/js/js_export.h | 14 + onnxruntime/core/providers/js/js_kernel.cc | 9 + onnxruntime/core/providers/js/js_kernel.h | 148 +++ .../core/providers/js/js_provider_factory.cc | 30 + .../js/js_provider_factory_creator.h | 17 + .../core/providers/js/operators/binary.cc | 54 + .../core/providers/js/operators/conv.cc | 41 + .../core/providers/js/operators/conv.h | 98 ++ .../core/providers/js/operators/gemm.cc | 40 + .../core/providers/js/operators/gemm.h | 35 + .../core/providers/js/operators/matmul.cc | 20 + .../core/providers/js/operators/pool.cc | 76 ++ .../core/providers/js/operators/pool.h | 75 ++ .../core/providers/js/operators/reshape.cc | 46 + .../core/providers/js/operators/reshape.h | 48 + .../core/providers/js/operators/shape_op.cc | 47 + .../core/providers/js/operators/transpose.cc | 28 + .../core/providers/js/operators/transpose.h | 35 + .../core/providers/js/operators/unary.cc | 120 ++ onnxruntime/core/providers/js/symbols.txt | 0 .../providers/provider_factory_creators.h | 4 + .../core/session/provider_registration.cc | 6 + onnxruntime/wasm/api.cc | 9 +- onnxruntime/wasm/js_internal_api.js | 16 + tools/ci_build/build.py | 2 + .../azure-pipelines/templates/web-ci.yml | 6 + .../azure-pipelines/templates/win-wasm-ci.yml | 15 + 92 files changed, 8091 insertions(+), 394 deletions(-) create mode 100644 js/web/lib/wasm/jsep/backend-webgpu.ts create mode 100644 js/web/lib/wasm/jsep/init.ts create mode 100644 js/web/lib/wasm/jsep/log.ts create mode 100644 js/web/lib/wasm/jsep/tensor.ts create mode 100644 js/web/lib/wasm/jsep/util.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/common.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/concat.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/conv.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/gemm.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/matmul.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/pool.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/transpose.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/program-manager.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/types.ts create mode 100644 js/web/lib/wasm/wasm-common.ts create mode 100644 onnxruntime/core/providers/js/allocator.cc create mode 100644 onnxruntime/core/providers/js/allocator.h create mode 100644 onnxruntime/core/providers/js/data_transfer.cc create mode 100644 onnxruntime/core/providers/js/data_transfer.h create mode 100644 onnxruntime/core/providers/js/js_execution_provider.cc create mode 100644 onnxruntime/core/providers/js/js_execution_provider.h create mode 100644 onnxruntime/core/providers/js/js_export.cc create mode 100644 onnxruntime/core/providers/js/js_export.h create mode 100644 onnxruntime/core/providers/js/js_kernel.cc create mode 100644 onnxruntime/core/providers/js/js_kernel.h create mode 100644 onnxruntime/core/providers/js/js_provider_factory.cc create mode 100644 onnxruntime/core/providers/js/js_provider_factory_creator.h create mode 100644 onnxruntime/core/providers/js/operators/binary.cc create mode 100644 onnxruntime/core/providers/js/operators/conv.cc create mode 100644 onnxruntime/core/providers/js/operators/conv.h create mode 100644 onnxruntime/core/providers/js/operators/gemm.cc create mode 100644 onnxruntime/core/providers/js/operators/gemm.h create mode 100644 onnxruntime/core/providers/js/operators/matmul.cc create mode 100644 onnxruntime/core/providers/js/operators/pool.cc create mode 100644 onnxruntime/core/providers/js/operators/pool.h create mode 100644 onnxruntime/core/providers/js/operators/reshape.cc create mode 100644 onnxruntime/core/providers/js/operators/reshape.h create mode 100644 onnxruntime/core/providers/js/operators/shape_op.cc create mode 100644 onnxruntime/core/providers/js/operators/transpose.cc create mode 100644 onnxruntime/core/providers/js/operators/transpose.h create mode 100644 onnxruntime/core/providers/js/operators/unary.cc create mode 100644 onnxruntime/core/providers/js/symbols.txt create mode 100644 onnxruntime/wasm/js_internal_api.js diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index e925f75090a46..b4d981d42dfb8 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -5422,8 +5422,8 @@ _____ Tencent/rapidjson, https://github.com/Tencent/rapidjson -Tencent is pleased to support the open source community by making RapidJSON available. - +Tencent is pleased to support the open source community by making RapidJSON available. + Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. All rights reserved. If you have downloaded a copy of the RapidJSON binary from Tencent, please note that the RapidJSON binary is licensed under the MIT License. @@ -5435,13 +5435,13 @@ Other dependencies and licenses: Open Source Software Licensed Under the BSD License: -------------------------------------------------------------------- -The msinttypes r29 -Copyright (c) 2006-2013 Alexander Chemeris +The msinttypes r29 +Copyright (c) 2006-2013 Alexander Chemeris All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. @@ -5450,7 +5450,7 @@ THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND ANY EXPR Open Source Software Licensed Under the JSON License: -------------------------------------------------------------------- -json.org +json.org Copyright (c) 2002 JSON.org All Rights Reserved. @@ -5784,3 +5784,212 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +_____ + +TensorFlow.js + +https://github.com/tensorflow/tfjs + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ce41c264356ff..37395c8d11545 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -68,6 +68,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) +option(onnxruntime_USE_JS "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) option(onnxruntime_BUILD_OBJC "Build Objective-C library" OFF) @@ -655,6 +656,11 @@ if (onnxruntime_USE_NNAPI_BUILTIN) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_NNAPI_BUILTIN=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES nnapi) endif() +if (onnxruntime_USE_JS) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_JS=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_JS=1) + list(APPEND ONNXRUNTIME_PROVIDER_NAMES js) +endif() if (onnxruntime_USE_QNN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_QNN=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_QNN=1) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 905f50f99b80a..c253b6b9c7197 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -114,6 +114,9 @@ endif() if(onnxruntime_USE_NNAPI_BUILTIN) set(PROVIDERS_NNAPI onnxruntime_providers_nnapi) endif() +if(onnxruntime_USE_JS) + set(PROVIDERS_JS onnxruntime_providers_js) +endif() if(onnxruntime_USE_QNN) set(PROVIDERS_QNN onnxruntime_providers_qnn) endif() @@ -1064,6 +1067,24 @@ if (onnxruntime_USE_NNAPI_BUILTIN) endif() endif() +if (onnxruntime_USE_JS) + add_compile_definitions(USE_JS=1) + + file(GLOB_RECURSE onnxruntime_providers_js_cc_srcs + "${ONNXRUNTIME_ROOT}/core/providers/js/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/js/*.cc" + ) + + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_js_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_js ${onnxruntime_providers_js_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_js + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11 + ) + + add_dependencies(onnxruntime_providers_js ${onnxruntime_EXTERNAL_DEPENDENCIES}) + +endif() + if (onnxruntime_USE_QNN) add_compile_definitions(USE_QNN=1) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 9c68d5e74e580..c0c50e0e11b08 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -504,6 +504,10 @@ if(onnxruntime_USE_NNAPI_BUILTIN) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nnapi) endif() +if(onnxruntime_USE_JS) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_js) +endif() + if(onnxruntime_USE_RKNPU) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_rknpu) endif() @@ -551,6 +555,7 @@ set(ONNXRUNTIME_TEST_LIBS ${onnxruntime_libs} # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime ${PROVIDERS_NNAPI} + ${PROVIDERS_JS} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} ${PROVIDERS_RKNPU} @@ -604,6 +609,13 @@ if(onnxruntime_USE_NNAPI_BUILTIN) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_nnapi) endif() +if(onnxruntime_USE_JS) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/js/*) + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_js) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_js) + list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_js) +endif() + if(onnxruntime_USE_QNN) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/*) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_qnn) @@ -839,6 +851,9 @@ if (onnxruntime_BUILD_WEBASSEMBLY) if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s USE_PTHREADS=1 -s PROXY_TO_PTHREAD=1") endif() + if (onnxruntime_USE_JS) + set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " --pre-js \"${ONNXRUNTIME_ROOT}/wasm/js_internal_api.js\"") + endif() endif() if (onnxruntime_ENABLE_ATEN) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 2188565a876d3..80a44ffb3fa63 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -108,6 +108,7 @@ if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB) onnxruntime_mlas onnxruntime_optimizer onnxruntime_providers + ${PROVIDERS_JS} ${PROVIDERS_XNNPACK} onnxruntime_session onnxruntime_util @@ -183,6 +184,7 @@ else() onnxruntime_mlas onnxruntime_optimizer onnxruntime_providers + ${PROVIDERS_JS} ${PROVIDERS_XNNPACK} onnxruntime_session onnxruntime_util @@ -197,49 +199,83 @@ else() endif() set(EXPORTED_RUNTIME_METHODS "['stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8']") + if (onnxruntime_USE_JS) + set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput") + else() + set(EXPORTED_FUNCTIONS "_malloc,_free") + endif() + + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s EXPORTED_RUNTIME_METHODS=${EXPORTED_RUNTIME_METHODS}" + "SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}" + "SHELL:-s MAXIMUM_MEMORY=4294967296" + "SHELL:-s EXIT_RUNTIME=0" + "SHELL:-s ALLOW_MEMORY_GROWTH=1" + "SHELL:-s MODULARIZE=1" + "SHELL:-s EXPORT_ALL=0" + "SHELL:-s VERBOSE=0" + "SHELL:-s FILESYSTEM=0" + ${WASM_API_EXCEPTION_CATCHING} + --no-entry + ) - set_target_properties(onnxruntime_webassembly PROPERTIES LINK_FLAGS " \ - -s \"EXPORTED_RUNTIME_METHODS=${EXPORTED_RUNTIME_METHODS}\" \ - -s \"EXPORTED_FUNCTIONS=_malloc,_free\" \ - -s MAXIMUM_MEMORY=4294967296 \ - -s EXIT_RUNTIME=0 \ - -s ALLOW_MEMORY_GROWTH=1 \ - -s MODULARIZE=1 \ - -s EXPORT_ALL=0 \ - -s VERBOSE=0 \ - -s FILESYSTEM=0 \ - ${WASM_API_EXCEPTION_CATCHING} \ - --no-entry") + if (onnxruntime_USE_JS) + # NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU + # This flag allows async functions to be called from sync functions, in the cost of binary size and + # build time. See https://emscripten.org/docs/porting/asyncify.html for more details. + + target_compile_definitions(onnxruntime_webassembly PRIVATE USE_JS=1) + target_link_options(onnxruntime_webassembly PRIVATE + --pre-js "${ONNXRUNTIME_ROOT}/wasm/js_internal_api.js" + "SHELL:-s ASYNCIFY=1" + "SHELL:-s ASYNCIFY_STACK_SIZE=65536" + ) + endif() if (onnxruntime_EMSCRIPTEN_SETTINGS) foreach(setting IN LISTS onnxruntime_EMSCRIPTEN_SETTINGS) - set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS - " -s ${setting}") + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ${setting}") endforeach() endif() if (CMAKE_BUILD_TYPE STREQUAL "Debug") - set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=2 -s SAFE_HEAP=1 -s STACK_OVERFLOW_CHECK=1 -s DEMANGLE_SUPPORT=1") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s ASSERTIONS=2" + "SHELL:-s SAFE_HEAP=1" + "SHELL:-s STACK_OVERFLOW_CHECK=1" + "SHELL:-s DEMANGLE_SUPPORT=1" + ) else() - set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=0 -s SAFE_HEAP=0 -s STACK_OVERFLOW_CHECK=0 -s DEMANGLE_SUPPORT=0 --closure 1") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s ASSERTIONS=0" + "SHELL:-s SAFE_HEAP=0" + "SHELL:-s STACK_OVERFLOW_CHECK=0" + "SHELL:-s DEMANGLE_SUPPORT=0" + --closure 1 + ) endif() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. - set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -s DISABLE_EXCEPTION_THROWING=0") + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING) - set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --profiling --profiling-funcs") + target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs) endif() if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -s EXPORT_NAME=ortWasmThreaded -s USE_PTHREADS=1") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s EXPORT_NAME=ortWasmThreaded" + "SHELL:-s USE_PTHREADS=1" + ) if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME "ort-wasm-simd-threaded") else() set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME "ort-wasm-threaded") endif() else() - set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -s EXPORT_NAME=ortWasm") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s EXPORT_NAME=ortWasm" + ) if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME "ort-wasm-simd") else() diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 928330293ce3d..6fc9ef6e1c8c3 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -44,6 +44,7 @@ constexpr const char* kAclExecutionProvider = "ACLExecutionProvider"; constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider"; constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider"; constexpr const char* kCoreMLExecutionProvider = "CoreMLExecutionProvider"; +constexpr const char* kJsExecutionProvider = "JsExecutionProvider"; constexpr const char* kSnpeExecutionProvider = "SNPEExecutionProvider"; constexpr const char* kTvmExecutionProvider = "TvmExecutionProvider"; constexpr const char* kXnnpackExecutionProvider = "XnnpackExecutionProvider"; diff --git a/js/.eslintrc.js b/js/.eslintrc.js index 24620a2791871..519284617f428 100644 --- a/js/.eslintrc.js +++ b/js/.eslintrc.js @@ -182,6 +182,12 @@ module.exports = { 'import/no-extraneous-dependencies': 'off', 'no-console': 'off' } + }, { + files: ['web/lib/**/3rd-party/**/*.ts'], rules: { + 'header/header': 'off', + 'unicorn/filename-case': 'off', + '@typescript-eslint/explicit-module-boundary-types': 'off', + } }], extends: [ 'eslint:recommended', diff --git a/js/common/lib/env-impl.ts b/js/common/lib/env-impl.ts index 9b9ea78e83364..f4f3f447b4c1a 100644 --- a/js/common/lib/env-impl.ts +++ b/js/common/lib/env-impl.ts @@ -8,6 +8,7 @@ export class EnvImpl implements Env { constructor() { this.wasm = {}; this.webgl = {}; + this.webgpu = {}; this.logLevelInternal = 'warning'; } @@ -28,8 +29,8 @@ export class EnvImpl implements Env { debug?: boolean; wasm: Env.WebAssemblyFlags; - webgl: Env.WebGLFlags; + webgpu: Env.WebGpuFlags; [name: string]: unknown; diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 59aa6426b7cd7..8d75a4c99a65f 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -86,6 +86,10 @@ export declare namespace Env { */ async?: boolean; } + + export interface WebGpuFlags { + profilingMode?: 'off'|'default'; + } } export interface Env { @@ -112,6 +116,11 @@ export interface Env { */ webgl: Env.WebGLFlags; + /** + * Represent a set of flags for WebGPU + */ + webgpu: Env.WebGpuFlags; + [name: string]: unknown; } diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 49553088bd2d8..2a4e71e064632 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -6,6 +6,7 @@ const bundleMode = require('minimist')(process.argv)['bundle-mode'] || 'dev'; // 'dev'|'perf'|undefined; const karmaPlugins = require('minimist')(process.argv)['karma-plugins'] || undefined; const timeoutMocha = require('minimist')(process.argv)['timeout-mocha'] || 60000; +const forceLocalHost = !!require('minimist')(process.argv)['force-localhost']; const commonFile = bundleMode === 'dev' ? '../common/dist/ort-common.js' : '../common/dist/ort-common.min.js' const mainFile = bundleMode === 'dev' ? 'test/ort.dev.js' : 'test/ort.perf.js'; @@ -16,18 +17,20 @@ const mainFile = bundleMode === 'dev' ? 'test/ort.dev.js' : 'test/ort.perf.js'; // https://stackoverflow.com/a/8440736 // function getMachineIpAddress() { - var os = require('os'); - var ifaces = os.networkInterfaces(); + if (!forceLocalHost) { + var os = require('os'); + var ifaces = os.networkInterfaces(); - for (const ifname in ifaces) { - for (const iface of ifaces[ifname]) { - if ('IPv4' !== iface.family || iface.internal !== false) { - // skip over internal (i.e. 127.0.0.1) and non-ipv4 addresses - continue; - } + for (const ifname in ifaces) { + for (const iface of ifaces[ifname]) { + if ('IPv4' !== iface.family || iface.internal !== false) { + // skip over internal (i.e. 127.0.0.1) and non-ipv4 addresses + continue; + } - // returns the first available IP address - return iface.address; + // returns the first available IP address + return iface.address; + } } } @@ -35,6 +38,11 @@ function getMachineIpAddress() { return 'localhost'; } +const hostname = getMachineIpAddress(); +// In Node.js v16 and below, 'localhost' is using IPv4, so need to listen to '0.0.0.0' +// In Node.js v17+, 'localhost' is using IPv6, so need to listen to '::' +const listenAddress = Number.parseInt(process.versions.node.split('.')[0]) >= 17 ? '::' : '0.0.0.0'; + module.exports = function (config) { config.set({ // global config of your BrowserStack account @@ -75,12 +83,16 @@ module.exports = function (config) { browserNoActivityTimeout: 300000, browserDisconnectTolerance: 0, browserSocketTimeout: 60000, - hostname: getMachineIpAddress(), + hostname, + listenAddress, customLaunchers: { ChromeTest: { base: 'ChromeHeadless', flags: ['--enable-features=SharedArrayBuffer'] }, ChromePerf: { base: 'Chrome', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer'] }, ChromeDebug: { debug: true, base: 'Chrome', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer'] }, - + ChromeCanaryTest: { base: 'ChromeCanary', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] }, + ChromeCanaryProfileTest: { base: 'ChromeCanary', flags: ['--window-size=1,1', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu', '--disable-dawn-features=disallow_unsafe_apis'] }, + ChromeCanaryDebug: { debug: true, base: 'ChromeCanary', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu'] }, + ChromeCanaryProfileDebug: { debug: true, base: 'ChromeCanary', flags: ['--remote-debugging-port=9333', '--enable-features=SharedArrayBuffer', '--enable-unsafe-webgpu', '--disable-dawn-features=disallow_unsafe_apis'] }, // // ==== BrowserStack browsers ==== // diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 687b5aefcfccf..2049b2663ead3 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -14,6 +14,10 @@ interface BuildDefinitions { * defines whether to disable the whole WebGL backend in the build. */ DISABLE_WEBGL: boolean; + /** + * defines whether to disable the whole WebGpu backend in the build. + */ + DISABLE_WEBGPU: boolean; /** * defines whether to disable the whole WebAssembly backend in the build. */ diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 0e4a3f6d575f8..749331058cc4a 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -13,8 +13,12 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { const onnxjsBackend = require('./backend-onnxjs').onnxjsBackend; registerBackend('webgl', onnxjsBackend, -10); } + if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = require('./backend-wasm').wasmBackend; + if (!BUILD_DEFS.DISABLE_WEBGPU) { + registerBackend('webgpu', wasmBackend, 5); + } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); registerBackend('xnnpack', wasmBackend, 9); diff --git a/js/web/lib/onnxjs/backend.ts b/js/web/lib/onnxjs/backend.ts index a363ec9f21368..f402b820e76e1 100644 --- a/js/web/lib/onnxjs/backend.ts +++ b/js/web/lib/onnxjs/backend.ts @@ -78,7 +78,7 @@ export interface Backend { const backendsCache: Map = new Map(); export const backend: {[name: string]: Backend} = { - webgl: new WebGLBackend(), + webgl: new WebGLBackend() }; /** diff --git a/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts b/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts index a61270163f879..1a2bc7422c833 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts @@ -98,6 +98,7 @@ const createReduceProgramInfo = }; const validateInputs = (inputs: Tensor[]): void => { + // TODO: support Reduce* operators with 2 inputs. if (!inputs || inputs.length !== 1) { throw new Error('Reduce op requires 1 input.'); } @@ -174,4 +175,4 @@ export const reduceLogSumSquare: OperatorImplementation = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes): Tensor[] => { const reduceOp: ReduceOp = (): string[] => ['float t; value = 0.0;', 't = _A(inputIdx); value += t * t;', '']; return reduce(inferenceHandler, inputs, attributes, 'ReduceLogSumSquare', reduceOp); - }; \ No newline at end of file + }; diff --git a/js/web/lib/onnxjs/opset.ts b/js/web/lib/onnxjs/opset.ts index e23a288b4e22b..e7eb3251babc5 100644 --- a/js/web/lib/onnxjs/opset.ts +++ b/js/web/lib/onnxjs/opset.ts @@ -8,13 +8,11 @@ export interface OpSet { domain: string; version: number; } - export declare namespace OpSet { /** * Domain of an opset, it can be an empty string(default value, represent for ai.onnx), or 'ai.onnx.ml' */ type Domain = ''|'ai.onnx.ml'|'com.microsoft'; - /** * A resolve rule consists of 4 or 5 items: opType, opSetDomain, versionSelector, operatorImplementation and * operatorInitialization (optional) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index e7bf279f0ed0e..2e51d3257ec9c 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -1,6 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +declare namespace JSEP { + type BackendType = unknown; + type AllocFunction = (size: number) => number; + type FreeFunction = (size: number) => number; + type UploadFunction = (dataOffset: number, gpuDataId: number, size: number) => void; + type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise; + type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void; + type ReleaseKernelFunction = (kernel: number) => void; + type RunFunction = (kernel: number, contextDataOffset: number) => number; +} + export interface OrtWasmModule extends EmscriptenModule { // #region emscripten functions stackSave(): number; @@ -51,6 +62,17 @@ export interface OrtWasmModule extends EmscriptenModule { // #region config mainScriptUrlOrBlob?: string|Blob; // #endregion + + // #region JSEP + jsepInit? + (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, + download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, + releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + + _JsepOutput(context: number, index: number, data: number): number; + + jsepRunPromise?: Promise; + // #endregion } declare const moduleFactory: EmscriptenModuleFactory; diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts new file mode 100644 index 0000000000000..332a9d86f6646 --- /dev/null +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -0,0 +1,345 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env} from 'onnxruntime-common'; + +import {LOG_DEBUG} from './log'; +import {TensorView} from './tensor'; +import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager'; +import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; +import {ProgramManager} from './webgpu/program-manager'; +import {ComputeContext, GpuData, GpuDataType, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; + +/** + * get a unique key representing the program from the program info,input shapes and types. + * + * @returns a unique key is a shorter string than the shader source, which contains all the information to identify a + * program. if the key is the same, the program shader source should be the same, so we can reuse the program. + * + */ +const getProgramInfoUniqueKey = + (programInfo: ProgramInfo|ProgramInfoLoader, inputTensorShapes: ReadonlyArray, + inputGpuDataTypes: readonly GpuDataType[]): string => { + const inputTensorShapesToString = inputTensorShapes.map(d => `${d.join(',')}`).join('_'); + const inputGpuDataTypesToString = inputGpuDataTypes.join('_'); + let key = programInfo.name; + if (programInfo.cacheHint) { + key += '[' + programInfo.cacheHint + ']'; + } + key += ':' + inputTensorShapesToString + ';' + inputGpuDataTypesToString; + return key; + }; + +/** + * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as + * the first parameter so that it is stored for future use. + */ +export class WebGpuBackend { + device: GPUDevice; + /** + * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping + */ + gpuDataManager: GpuDataManager; + /** + * an instance of ProgramManager to build and run WebGPU compute shader program, and manage a ProgramKey -> Program + * artifacts mapping + */ + programManager: ProgramManager; + + /** + * representing the kernel ID of which is currently being computed (CPU code perspective). + * `null` means no kernel is being computed. + * only one kernel can be computed at a moment. + */ + currentKernelId: number|null = null; + /** + * a list of temporary GPU data for the current kernel. should release when the kernel done computation. + */ + private temporaryData: GpuData[]; + /** + * a KernelID -> a GPU data list, which stores persistent GPU data owned by the specific kernel. + */ + private kernelPersistentData: Map; + /** + * a KernelID -> a custom data, which stores custom data owned by the specific kernel. + */ + private kernelCustomData: Map; + /** + * get the custom data of the current kernel + */ + get currentKernelCustomData(): {[key: string]: unknown} { + if (this.currentKernelId === null) { + throw new Error('currentKernelCustomData(): currentKernelId is null. (should not happen)'); + } + + let data = this.kernelCustomData.get(this.currentKernelId); + if (!data) { + data = {}; + this.kernelCustomData.set(this.currentKernelId, data); + } + + return data; + } + + /** + * a KernelID -> kernel info mapping. value is [ name, run function, [optional] preprocess_attribute_once function ] + */ + kernels: Map unknown) | undefined, unknown]]>; + + commandEncoder: GPUCommandEncoder|null = null; + computePassEncoder: GPUComputePassEncoder|null = null; + pendingDispatchNumber = 0; + + profilingEnabled = false; + profilingQuerySet: GPUQuerySet; + profilingTimeBase?: bigint; + + async initialize(): Promise { + if (!navigator.gpu) { + // WebGPU is not available. + throw new Error('WebGpuBackend: WebGPU is not available.'); + } + + const adapter = await navigator.gpu.requestAdapter(); + if (!adapter) { + throw new Error('WebGpuBackend: Failed to get GPU adapter.'); + } + + const deviceDescriptor: GPUDeviceDescriptor = { + requiredLimits: { + maxComputeWorkgroupStorageSize: adapter.limits.maxComputeWorkgroupStorageSize, + maxComputeWorkgroupsPerDimension: adapter.limits.maxComputeWorkgroupsPerDimension, + maxStorageBufferBindingSize: adapter.limits.maxStorageBufferBindingSize, + } + }; + // WebGPU Spec: Timestamp Queries Inside Passes + // https://github.com/gpuweb/gpuweb/blob/main/proposals/timestamp-query-inside-passes.md + if (adapter.features.has('timestamp-query-inside-passes') && env.webgpu.profilingMode === 'default') { + this.profilingEnabled = true; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + deviceDescriptor.requiredFeatures = ['timestamp-query-inside-passes' as any]; + } + + this.device = await adapter.requestDevice(deviceDescriptor); + this.gpuDataManager = createGpuDataManager(this); + this.programManager = new ProgramManager(this); + this.kernels = new Map(); + this.kernelPersistentData = new Map(); + this.kernelCustomData = new Map(); + // TODO: set up flags + + this.device.onuncapturederror = ev => { + if (ev.error instanceof GPUValidationError) { + // eslint-disable-next-line no-console + console.error(`An uncaught WebGPU validation error was raised: ${ev.error.message}`); + } + }; + + if (this.profilingEnabled) { + this.profilingQuerySet = this.device.createQuerySet({ + type: 'timestamp', + count: 2, + }); + } + } + + dispose(): void { + // TODO: uninitialization + // this.glContext.dispose(); + } + + getCommandEncoder(): GPUCommandEncoder { + if (!this.commandEncoder) { + this.commandEncoder = this.device.createCommandEncoder(); + } + return this.commandEncoder; + } + + getComputePassEncoder(): GPUComputePassEncoder { + if (!this.computePassEncoder) { + this.computePassEncoder = this.getCommandEncoder().beginComputePass(); + } + return this.computePassEncoder; + } + + endComputePass(): void { + if (this.computePassEncoder) { + this.computePassEncoder.end(); + this.computePassEncoder = null; + } + } + + flush(): void { + this.endComputePass(); + this.device.queue.submit([this.getCommandEncoder().finish()]); + this.gpuDataManager.refreshPendingBuffers(); + this.commandEncoder = null; + this.pendingDispatchNumber = 0; + } + + /** + * run a WebGPU program. + * @param program either a ProgramInfo instance containing metadata including the shader code, or a function that + * can be called and return a ProgramInfo instance + * @param inputs a TensorView array. each element represents a value already exists in GPU. + * @param outputIndices an indices array. each element can be either -1 (temporary data), -2 (persistent data) or an + * index to the kernel's output. + * @param createKernelOutput a callback function that create a value to kernel's output with the given index + * @param createIntermediateOutput a callback function that create a value as a intermediate value, either temporary + * or persistent (owned by the current kernel) + * @returns a TensorView array representing the result. + */ + run(program: ProgramInfoLoader|ProgramInfo, inputs: readonly TensorView[], outputIndices: readonly number[], + createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView, + createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView): TensorView[] { + if (inputs.length !== program.inputTypes.length) { + throw new Error(`Input size must be equal to ${program.inputTypes.length}.`); + } + + // create info for inputs + const inputDatas: GpuData[] = []; + for (let i = 0; i < inputs.length; ++i) { + const gpuData = this.gpuDataManager.get(inputs[i].data); + if (!gpuData) { + throw new Error(`no GPU data for input: ${inputs[i].data}`); + } + inputDatas[i] = gpuData; + } + + const key = getProgramInfoUniqueKey(program, inputs.map(i => i.dims), inputDatas.map(i => i.type)); + let artifact = this.programManager.getArtifact(key); + const programInfo = artifact ? + artifact.programInfo : + (typeof (program as ProgramInfoLoader).get === 'function' ? (program as ProgramInfoLoader).get() : + (program as ProgramInfo)); + + // check output indices + const validatedOutputIndices = outputIndices.length === 0 ? programInfo.outputs.map((_, i) => i) : outputIndices; + if (validatedOutputIndices.length !== programInfo.outputs.length) { + throw new Error(`Output size ${validatedOutputIndices.length} must be equal to ${programInfo.outputs.length}.`); + } + + // create info for outputs + const outputTensorViews: TensorView[] = []; + const outputDatas: GpuData[] = []; + for (let i = 0; i < programInfo.outputs.length; ++i) { + // value -1 and -2 are used for creating temporary and persistent outputs. so -2, -1 and 0, 1, 2, ... are valid + // output indices. see type definition of ComputeContextInputsOutputsMapping for more details. + if (!Number.isInteger(validatedOutputIndices[i]) || validatedOutputIndices[i] < -2 || + validatedOutputIndices[i] >= programInfo.outputs.length) { + throw new Error(`Invalid output index: ${validatedOutputIndices[i]}`); + } + const isTemporary = validatedOutputIndices[i] === -1; + const isPersistent = validatedOutputIndices[i] === -2; + const tensorView = (isTemporary || isPersistent) ? + createIntermediateOutput(programInfo.outputs[i].dataType, programInfo.outputs[i].dims) : + createKernelOutput(validatedOutputIndices[i], programInfo.outputs[i].dataType, programInfo.outputs[i].dims); + const gpuData = this.gpuDataManager.get(tensorView.data); + if (!gpuData) { + throw new Error(`no GPU data for output: ${tensorView.data}`); + } + if (isTemporary) { + this.temporaryData.push(gpuData); + } + if (isPersistent) { + let persistentData = this.kernelPersistentData.get(this.currentKernelId!); + if (!persistentData) { + persistentData = []; + this.kernelPersistentData.set(this.currentKernelId!, persistentData); + } + persistentData.push(gpuData); + } + outputTensorViews.push(tensorView); + outputDatas.push(gpuData); + } + + const normalizedDispatchGroup = this.programManager.normalizeDispatchGroupSize(programInfo.dispatchGroup(inputs)); + + if (!artifact) { + artifact = this.programManager.build(programInfo, normalizedDispatchGroup); + this.programManager.setArtifact(key, artifact); + } + + LOG_DEBUG( + 'info', + () => `[ProgramManager] run "${programInfo.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ + normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); + this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup); + + return outputTensorViews; + } + + upload(gpuDataId: number, data: Uint8Array): void { + this.gpuDataManager.upload(gpuDataId, data); + } + + memcpy(src: number, dst: number): void { + this.gpuDataManager.memcpy(src, dst); + } + + async download(gpuDataId: number, data: Uint8Array): Promise { + const arrayBuffer = await this.gpuDataManager.download(gpuDataId); + data.set(new Uint8Array(arrayBuffer)); + } + + alloc(size: number): number { + return this.gpuDataManager.create(size).id; + } + + free(ptr: number): number { + return this.gpuDataManager.release(ptr); + } + + createKernel(name: string, kernelId: number, attribute: unknown): void { + const op = WEBGPU_OP_RESOLVE_RULES.get(name); + if (!op) { + throw new Error(`kernel not implemented: ${name}`); + } + + this.kernels.set(kernelId, [name, op[0], [op[1], attribute]]); + } + + releaseKernel(kernelId: number): void { + const persistentData = this.kernelPersistentData.get(kernelId); + if (persistentData) { + for (const data of persistentData) { + this.gpuDataManager.release(data.id); + } + this.kernelPersistentData.delete(kernelId); + } + + this.kernelCustomData.delete(kernelId); + this.kernels.delete(kernelId); + } + + computeKernel(kernelId: number, context: ComputeContext): number { + const kernel = this.kernels.get(kernelId); + if (!kernel) { + throw new Error(`kernel not created: ${kernelId}`); + } + const [name, kernelEntry, attributes] = kernel; + if (this.currentKernelId !== null) { + throw new Error(`kernel "${name}" is not allowed to be called recursively`); + } + this.currentKernelId = kernelId; + + // parse attributes if necessary + if (attributes[0]) { + attributes[1] = attributes[0](attributes[1]); + attributes[0] = undefined; + } + + LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "${name}"...`); + + this.temporaryData = []; + try { + return kernelEntry(context, attributes[1]); + } finally { + for (const data of this.temporaryData) { + this.gpuDataManager.release(data.id); + } + this.temporaryData = []; + this.currentKernelId = null; + } + } +} diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts new file mode 100644 index 0000000000000..4226a0ef46f57 --- /dev/null +++ b/js/web/lib/wasm/jsep/init.ts @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {OrtWasmModule} from '../binding/ort-wasm'; +import {getTensorElementSize} from '../wasm-common'; + +import {WebGpuBackend} from './backend-webgpu'; +import {LOG_DEBUG} from './log'; +import {TensorView} from './tensor'; +import {ShapeUtil} from './util'; +import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; + +/* eslint-disable no-bitwise */ + +class TensorViewImpl implements TensorView { + constructor( + private module: OrtWasmModule, public readonly dataType: number, public readonly data: number, + public readonly dims: readonly number[]) {} + + getFloat32Array(): Float32Array { + return new Float32Array(this.module.HEAP8.buffer, this.data, ShapeUtil.size(this.dims)); + } + + reshape(newDims: readonly number[]): TensorView { + if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) { + throw new Error('Invalid new shape'); + } + return new TensorViewImpl(this.module, this.dataType, this.data, newDims); + } +} + +class OpKernelContext implements ComputeContext { + readonly opKernelContext: number; + readonly inputs: readonly TensorView[]; + get customData(): {[key: string]: unknown} { + return this.backend.currentKernelCustomData; + } + constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) { + const heapU32 = module.HEAPU32; + + // extract context data + let dataIndex = (contextDataOffset >> 2); + this.opKernelContext = heapU32[dataIndex++]; + const inputCount = heapU32[dataIndex++]; + + const inputs: TensorView[] = []; + for (let i = 0; i < inputCount; i++) { + const dataType = heapU32[dataIndex++]; + const data = heapU32[dataIndex++]; + const dim = heapU32[dataIndex++]; + const dims: number[] = []; + for (let d = 0; d < dim; d++) { + dims.push(heapU32[dataIndex++]); + } + inputs.push(new TensorViewImpl(module, dataType, data, dims)); + } + this.inputs = inputs; + } + + compute(program: ProgramInfoLoader|ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): + TensorView[] { + // prepare inputs. inputs should always be valid data. + const mappedInputs = + inputsOutputsMapping?.inputs?.map(i => typeof i === 'number' ? this.inputs[i] : i) ?? this.inputs; + // prepare outputs. + const outputIndices = inputsOutputsMapping?.outputs ?? []; + const createKernelOutput = (index: number, dataType: number, dims: readonly number[]): TensorView => + new TensorViewImpl(this.module, dataType, this.output(index, dims), dims); + const createTemporaryOutput = (dataType: number, dims: readonly number[]): TensorView => { + const elementSize = getTensorElementSize(dataType); + if (!elementSize) { + throw new Error(`Unsupported data type: ${dataType}`); + } + const bufferSize = elementSize * ShapeUtil.size(dims); + return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims); + }; + return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput); + } + + output(index: number, dims: readonly number[]): number { + const stack = this.module.stackSave(); + try { + const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */); + let offset = data >> 2; + this.module.HEAPU32[offset++] = dims.length; + for (let i = 0; i < dims.length; i++) { + this.module.HEAPU32[offset++] = dims[i]; + } + return this.module._JsepOutput(this.opKernelContext, index, data); + } finally { + this.module.stackRestore(stack); + } + } +} + +export const init = async(module: OrtWasmModule): Promise => { + const init = module.jsepInit; + if (init && navigator.gpu) { + const backend = new WebGpuBackend(); + await backend.initialize(); + + init( + // backend + {backend}, + + // jsepAlloc() + (size: number) => backend.alloc(size), + + // jsepFree() + (ptr: number) => backend.free(ptr), + + // jsepCopy(src, dst, size, isSourceGpu) + (src: number, dst: number, size: number, isSourceGpu = false) => { + if (isSourceGpu) { + LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`); + backend.memcpy(src, dst); + } else { + LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); + const data = module.HEAPU8.subarray(src, src + size); + backend.upload(dst, data); + } + }, + + // jsepCopyAsync(src, dst, size) + async(gpuDataId: number, dataOffset: number, size: number): + Promise => { + const data = module.HEAPU8.subarray(dataOffset, dataOffset + size); + + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); + + await backend.download(gpuDataId, data); + }, + + // jsepCreateKernel + (name: string, kernel: number, attribute: unknown) => backend.createKernel(name, kernel, attribute), + + // jsepReleaseKernel + (kernel: number) => backend.releaseKernel(kernel), + + // jsepRun + (kernel: number, contextDataOffset: number) => { + LOG_DEBUG('verbose', () => `[WebGPU] jsepRun: kernel=${kernel}, contextDataOffset=${contextDataOffset}`); + const context = new OpKernelContext(module, backend, contextDataOffset); + return backend.computeKernel(kernel, context); + }); + } +}; diff --git a/js/web/lib/wasm/jsep/log.ts b/js/web/lib/wasm/jsep/log.ts new file mode 100644 index 0000000000000..2e27e4905742e --- /dev/null +++ b/js/web/lib/wasm/jsep/log.ts @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env} from 'onnxruntime-common'; + +import {logLevelStringToEnum} from '../wasm-common'; + +type LogLevel = NonNullable; +type MessageString = string; +type MessageFunction = () => string; +type Message = MessageString|MessageFunction; + +const logLevelPrefix = ['V', 'I', 'W', 'E', 'F']; + +const doLog = (level: number, message: string): void => { + // eslint-disable-next-line no-console + console.log(`[${logLevelPrefix[level]},${new Date().toISOString()}]${message}`); +}; + +/** + * A simple logging utility to log messages to the console. + */ +export const LOG = (logLevel: LogLevel, msg: Message): void => { + const messageLevel = logLevelStringToEnum(logLevel); + const configLevel = logLevelStringToEnum(env.logLevel!); + if (messageLevel >= configLevel) { + doLog(messageLevel, typeof msg === 'function' ? msg() : msg); + } +}; + +/** + * A simple logging utility to log messages to the console. Only logs when debug is enabled. + */ +export const LOG_DEBUG: typeof LOG = (...args: Parameters) => { + if (env.debug) { + LOG(...args); + } +}; diff --git a/js/web/lib/wasm/jsep/tensor.ts b/js/web/lib/wasm/jsep/tensor.ts new file mode 100644 index 0000000000000..720b2357df1f2 --- /dev/null +++ b/js/web/lib/wasm/jsep/tensor.ts @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +export declare namespace Tensor { + export interface DataTypeMap { + bool: Uint8Array; + float32: Float32Array; + float64: Float64Array; + string: string[]; + int8: Int8Array; + uint8: Uint8Array; + int16: Int16Array; + uint16: Uint16Array; + int32: Int32Array; + uint32: Uint32Array; + int64: BigInt64Array; + uint64: BigUint64Array; + } + + export type DataType = keyof DataTypeMap; + + export type StringType = Tensor.DataTypeMap['string']; + export type BooleanType = Tensor.DataTypeMap['bool']; + export type IntegerType = Tensor.DataTypeMap['int8']|Tensor.DataTypeMap['uint8']|Tensor.DataTypeMap['int16']| + Tensor.DataTypeMap['uint16']|Tensor.DataTypeMap['int32']|Tensor.DataTypeMap['uint32']| + Tensor.DataTypeMap['int64']|Tensor.DataTypeMap['uint64']; + export type FloatType = Tensor.DataTypeMap['float32']|Tensor.DataTypeMap['float64']; + export type NumberType = BooleanType|IntegerType|FloatType; + + export type Id = number; +} + +export const sizeof = (type: Tensor.DataType): number => { + switch (type) { + case 'bool': + case 'int8': + case 'uint8': + return 1; + case 'int16': + case 'uint16': + return 2; + case 'int32': + case 'uint32': + case 'float32': + return 4; + case 'int64': + case 'uint64': + case 'float64': + return 8; + default: + throw new Error(`cannot calculate sizeof() on type ${type}`); + } +}; + +const dataviewConstructor = (type: Tensor.DataType) => { + switch (type) { + case 'bool': + case 'uint8': + return Uint8Array; + case 'int8': + return Int8Array; + case 'int16': + return Int16Array; + case 'uint16': + return Uint16Array; + case 'int32': + return Int32Array; + case 'uint32': + return Uint32Array; + case 'int64': + return BigInt64Array; + case 'uint64': + return BigUint64Array; + case 'float32': + return Float32Array; + case 'float64': + return Float64Array; + default: + // should never run to here + throw new Error('unspecified error'); + } +}; + +export const createView = (dataBuffer: ArrayBuffer, type: Tensor.DataType): Int32Array|Uint32Array|BigInt64Array| + BigUint64Array|Uint8Array|Float32Array|Float64Array|Int8Array|Int16Array|Uint16Array => + new (dataviewConstructor(type))(dataBuffer); + +/** + * a TensorView does not own the data. + */ +export interface TensorView { + readonly data: number; + readonly dataType: number; + readonly dims: readonly number[]; + + /** + * get a Float32Array data view of the tensor data. tensor data must be on CPU. + */ + getFloat32Array(): Float32Array; + + /** + * create a new tensor view with the same data but different dimensions. + */ + reshape(newDims: readonly number[]): TensorView; +} diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts new file mode 100644 index 0000000000000..cd128ad5e501d --- /dev/null +++ b/js/web/lib/wasm/jsep/util.ts @@ -0,0 +1,774 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/* eslint-disable no-param-reassign */ + +export class MatMulUtil { + /** + * Fix the input shapes for MatMul operation if they need fixing + * @param dimsA The shape of tensor A. Should be an array of positive integers + * @param dimsB The shape of tensor B. Should be an array of positive integers + * @returns A tuple containing the preprocessed input shapes as required by ONNX specifications + */ + static preprocessInputShapes(dimsA: readonly number[], dimsB: readonly number[]): + [readonly number[], readonly number[]] { + // If the first argument is 1-D, it is promoted to a matrix by prepending + // a 1 to its dimensions. After matrix multiplication the prepended 1 is + // removed. + const a = (dimsA.length === 1) ? [1, dimsA[0]] : dimsA; + + // If the second argument is 1-D, it is promoted to a matrix by appending + // a 1 to its dimensions. After matrix multiplication the appended 1 is + // removed. + const b = (dimsB.length === 1) ? [dimsB[0], 1] : dimsB; + + return [a, b]; + } + + /** + * Fix the output shape computed for MatMul operation if it needs fixing + * @param outputShape The computed outputShape. Should be an array (atleast of length 2) of positive integers. + * This will be mutated. + * @param aRank The rank of tensor A. + * @param bRank The rank of tensor B. + */ + static postprocessOutputShape(outputShape: number[], aRank: number, bRank: number): void { + // Remove prepended dimension if first input is 1d + if (aRank === 1) { + // outputShape = outputShape.slice(0, outputShape.length - 2).concat(outputShape.slice(outputShape.length - 1)); + outputShape.splice(outputShape.length - 2, 1); + } + // Remove appended dimension if second input is 1d + if (bRank === 1) { + outputShape.pop(); + } + } + + /** + * Calculate the expected shape when matrix multiplication + * @param a The shape of tensor A. Should be a tuple of 2 positive integers + * @param b The shape of tensor B. Should be a tuple of 2 positive integers + * @returns The expected shape of the result, or undefined if N/A + */ + static calcMatMulShape(a: [number, number], b: [number, number]): [number, number]|undefined { + return (a[1] !== b[0]) ? undefined : [a[0], b[1]]; + } +} + + +export class BroadcastUtil { + /** + * Calculate the expected shape when broadcasting 2 tensors + * @param a The shape of tensor A. Should be an array of positive integers + * @param b The shape of tensor B. Should be an array of positive integers + * @param isMatMul Whether the operation is MatMul + * @returns The expected shape of the result, or undefined if N/A + */ + static calcShape(adims: readonly number[], bdims: readonly number[], isMatMul = false): readonly number[]|undefined { + const arank = adims.length; + const brank = bdims.length; + if (arank === 0) { + return bdims; + } + if (brank === 0) { + return adims; + } + const crank = Math.max(adims.length, bdims.length); + const cdims = new Array(crank); + + // calculate the last 2 dimension if it is MatMul + if (isMatMul) { + if (arank < 2 || brank < 2) { + return undefined; + } + const cShapeMatMul = + MatMulUtil.calcMatMulShape([adims[arank - 2], adims[arank - 1]], [bdims[brank - 2], bdims[brank - 1]]); + if (cShapeMatMul === undefined) { + return undefined; + } + [cdims[crank - 2], cdims[crank - 1]] = cShapeMatMul; + } + + for (let i = isMatMul ? 3 : 1; i <= crank; i++) { + const aLen = arank - i < 0 ? 1 : adims[arank - i]; + const bLen = brank - i < 0 ? 1 : bdims[brank - i]; + + if (aLen !== bLen && aLen > 1 && bLen > 1) { + return undefined; + } + cdims[crank - i] = Math.max(aLen, bLen); + } + + return cdims; + } + + /** + * Given the indices of a broadcasted tensor, calculate the original indices + * @param broadcastedIndices The given indices of the broadcasted tensor. + * @param originalShape The original shape of the tensor before broadcas + * @returns The calculated indices that maps to the original tensor. + */ + static index(broadcastedIndices: readonly number[], originalShape: readonly number[]): number[] { + // NOTE 1: we assume the parameter broadcastedIndices is valid. ie. it should have the same + // length as the broadcasted shape, and for each dimension the index should + // not be out of range. + const originalIndices = new Array(originalShape.length); + BroadcastUtil.fillIndex(broadcastedIndices, originalShape, originalIndices); + return originalIndices; + } + + /** + * Given the indices of a broadcasted tensor, calculate the original indices + * @param broadcastedIndices The given indices of the broadcasted tensor. + * @param originalShape The original shape of the tensor before broadcast + * @param originalIndices The mapping of broadcastedIndices to the originalIndices (output parameter - will be + * mutated). + */ + static fillIndex(broadcastedIndices: readonly number[], originalShape: readonly number[], originalIndices: number[]): + void { + // NOTE 1: we assume the parameter broadcastedIndices is valid. ie. it should have the same length as the + // broadcasted shape, and for each dimension the index should not be out of range. + // NOTE 2: we assume the parameter originalIndices has the same length as the originalShape + const dimOffset = broadcastedIndices.length - originalShape.length; + for (let i = 0; i < originalShape.length; i++) { + originalIndices[i] = broadcastedIndices[dimOffset + i] % originalShape[i]; + } + } + + /** + * Determine if a shape is unidirectional broadcastable to another shape + * @param shape The input shape + * @param finalShape The desired shape after broadcasting + */ + static isValidBroadcast(shape: readonly number[], finalShape: readonly number[]): boolean { + // align shape to the right + const inputRank = shape.length; + const finalRank = finalShape.length; + if (inputRank > finalRank) { + return false; + } + for (let i = 1; i <= inputRank; i++) { + if (shape[inputRank - i] !== 1 && shape[inputRank - i] !== finalShape[finalRank - i]) { + return false; + } + } + return true; + } + + /** + * Determine the broadcasted dims in input shape based on the given output shape. + * Note that this function only returns the broadcasted dims. + * @param inputShape The input shape + * @param outputShape The output shape + * @returns The broadcasted dims in input shape. + */ + static getBroadcastDims(inputShape: readonly number[], outputShape: readonly number[]): number[] { + const inRank = inputShape.length; + const dims: number[] = []; + for (let i = 0; i < inRank; i++) { + const dim = inRank - 1 - i; + const a = inputShape[dim] || 1; + const b = outputShape[outputShape.length - 1 - i] || 1; + if (b > 1 && a === 1) { + dims.unshift(dim); + } + } + return dims; + } +} + + +export class ShapeUtil { + /** + * calculate the size (number of elements) + */ + static size(dims: readonly number[]): number { + return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length); + } + + /** + * calculate the size (number of elements) from the given axis (inclusive) + */ + static sizeFromDimension(dims: readonly number[], axis: number): number { + if (axis < 0 || axis > dims.length) { + throw new Error(`invalid dimension of ${axis} for sizeFromDimension as Tensor has ${dims.length} dimensions.`); + } + return ShapeUtil.getSizeFromDimensionRange(dims, axis, dims.length); + } + + /** + * calculate the size (number of elements) to the given axis (exclusive) + */ + static sizeToDimension(dims: readonly number[], axis: number): number { + if (axis < 0 || axis > dims.length) { + throw new Error(`invalid dimension of ${axis} for sizeToDimension as Tensor has ${dims.length} dimensions.`); + } + return ShapeUtil.getSizeFromDimensionRange(dims, 0, axis); + } + + /** + * calculate the size (number of elements) from and to the given axis [start, end) + */ + static getSizeFromDimensionRange(dims: readonly number[], start: number, end: number): number { + let size = 1; + for (let i = start; i < end; i++) { + // safety check as this method is called by multiple other methods requiring size. + // size cannot be 0 or negative. + if (dims[i] <= 0) { + throw new Error( + // eslint-disable-next-line max-len + 'cannot get valid size from specified dimension range. Most likely the range contains 0 or negative values in them.'); + } + size *= dims[i]; + } + return size; + } + + static computeStrides(dims: readonly number[]): readonly number[] { + const rank = dims.length; + if (rank === 0) { + return []; + } else if (rank === 1) { + return [1]; + } + const strides = new Array(rank); + strides[rank - 1] = 1; + strides[rank - 2] = dims[rank - 1]; + for (let i = rank - 3; i >= 0; --i) { + strides[i] = strides[i + 1] * dims[i + 1]; + } + return strides; + } + + static transpose(dims: readonly number[]): readonly number[] { + const copy = dims.slice(); + return copy.reverse(); + } + + static indicesToOffset(indices: readonly number[], strides: readonly number[], axis?: number): number { + if (axis === undefined) { + axis = indices.length; + } + let offset = 0; + for (let i = 0; i < axis; ++i) { + offset += strides[i] * indices[i]; + } + return offset; + } + + static offsetToIndices(offset: number, strides: readonly number[]): readonly number[] { + const rank = strides.length; + if (rank === 0) { + return []; + } else if (rank === 1) { + return [offset * strides[0]]; + } + const indices: number[] = new Array(strides.length); + for (let i = 0; i < indices.length - 1; ++i) { + indices[i] = Math.floor(offset / strides[i]); + offset -= indices[i] * strides[i]; + } + indices[indices.length - 1] = offset; + return indices; + } + + /** + * normailze axis of range [-r, r) into [0, r). + */ + static normalizeAxis(axis: number, tensorRank: number): number { + if (axis < -tensorRank && axis >= tensorRank) { + throw new Error('unsupported axis for this operation.'); + } + return axis < 0 ? axis + tensorRank : axis; + } + + static normalizeAxes(axes: readonly number[], tensorRank?: number): number[] { + return axes.map(x => this.normalizeAxis(x, tensorRank ?? axes.length)); + } + + /** + * Increment an index into a tensor (in lexicographic ordering), wrapping around the specified upper_bound. + * @param index Given index to increment (Will be mutated) + * @param dims The dimensions of the tensor for which the given index corresponds to + * @param axisToIncrementOn The 1-indexed axis to increment on. If undefined, axisToIncrementOn == rank + */ + static incrementIndex(index: number[], dims: readonly number[], axisToIncrementOn?: number): void { + if (dims.length === 0 || index.length === 0) { + throw new Error('Index incrementing unsupported for scalar Tensor'); + } + if (axisToIncrementOn === undefined) { + axisToIncrementOn = dims.length; + } else { + if (axisToIncrementOn <= 0 || axisToIncrementOn > dims.length) { + throw new Error('Incorrect axis to increment on'); + } + } + + for (let k = axisToIncrementOn - 1; k >= 0; --k) { + index[k]++; + if (index[k] < dims[k]) { + break; + } + index[k] = 0; + } + } + + /** + * Produces a new dimensions array based on the values in the 'originalDimensions' and 'shape' array + * Used in Reshape + * @param originalDims Original Shape array + * @param shapeHints array containing values to compute the new dimensions + * For example: + * originalDims = [2,2] and shapeHints = [0,-1] will return [2,2] + * originalDims = [2,2] and shapeHints = [4] will return [4] + * originalDims = [2,2] and shapeHints = [5] will throw an exception + * https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape + */ + + static calculateReshapedDims(originalDims: readonly number[], shapeHints: ArrayLike): number[] { + // reshape to a Scalar Tensor + if (shapeHints.length === 0) { + if (originalDims.length === 0 || ShapeUtil.size(originalDims) === 1) { + return []; + } else { + throw new Error('cannot reshape to a scalar Tensor'); + } + } + + const nDims = shapeHints.length; + const reshapedDims = new Array(nDims); + let unknownDimension = -1; + let newTensorSize = 1; + for (let i = 0; i < nDims; i++) { + if (shapeHints[i] < -1) { + throw new Error('a dimension in shape hints cannot be less than -1'); + } + if (shapeHints[i] === -1) { + if (unknownDimension !== -1) { + throw new Error('at most one dimension in shape hints can be -1'); + } + unknownDimension = i; + } else { + if (shapeHints[i] === 0) { + if (i >= originalDims.length) { + throw new Error('the dimension with value zero exceeds the dimension size of the input tensor'); + } + reshapedDims[i] = originalDims[i]; + } else { + reshapedDims[i] = shapeHints[i]; + } + newTensorSize *= reshapedDims[i]; + } + } + + const oldTensorSize = ShapeUtil.size(originalDims); + if (unknownDimension !== -1) { + if (oldTensorSize % newTensorSize !== 0) { + throw new Error(`the input tensor cannot be reshaped to the requested shape. Input shape: [${ + originalDims}] Output shape: [${shapeHints}]`); + } + reshapedDims[unknownDimension] = oldTensorSize / newTensorSize; + } + // validate sizes from originalDims and reshapedDims match + else { + if (newTensorSize !== oldTensorSize) { + throw new Error('reshapedDims and originalDims don\'t have matching sizes'); + } + } + return reshapedDims; + } + + /** + * Sorts a given array based on the indices in the Perm array + * Used in Transpose + * @param a Array to be sorted such as dims or strides + * @param perm Perm given; if null a will be reversed + */ + static sortBasedOnPerm(a: readonly number[], perm?: readonly number[]): readonly number[] { + if (perm) { + return perm.map((v) => a[v]); + } else { + return a.slice().reverse(); + } + } + + /** + * Pads a given shape according to the padding values + * @param dims shape of the Tensor to be padded + * @param pad pad values + */ + static padShape(dims: readonly number[], pad: readonly number[]): readonly number[] { + const rank = dims.length; + return dims.map((v, i) => v + pad[i] + pad[i + rank]); + } + + /** + * Determines if the two shapes are identical + * @param shape1 + * @param shape2 + */ + static areEqual(shape1: readonly number[], shape2: readonly number[]): boolean { + if (shape1.length !== shape2.length) { + return false; + } + return shape1.every((v, i) => v === shape2[i]); + } + + /** + * Validates if the given `dims` or `shape` is valid in ONNX.js context and returns data size + * @param dims - input `dims` that needs to be checked + */ + static validateDimsAndCalcSize(dims: readonly number[]): number { + if (dims.length > 6) { + throw new TypeError('Only rank 0 to 6 is supported for tensor shape.'); + } + let size = 1; + for (const n of dims) { + if (!Number.isInteger(n)) { + throw new TypeError(`Invalid shape: ${n} is not an integer`); + } + if (n < 0 || n > 2147483647) { + throw new TypeError(`Invalid shape: length ${n} is not allowed`); + } + size *= n; + } + return size; + } + + /** + * Determines the shape of output tensor y = flatten(x, axis) + * @param dims - shape of input tensor + * @param axis - flatten axis, in the range [-r, r] + */ + static flattenShape(dims: readonly number[], axis: number): readonly number[] { + if (axis < 0) { + axis += dims.length; + } + const total = dims.reduce((x, y) => x * y, 1); + const right = dims.slice(axis).reduce((x, y) => x * y, 1); + const outputDims = [total / right, right]; + + return outputDims; + } + + /** + * Determines the shape of output tensor y = squeeze(x, axes) + * @param dims - shape of input tensor + * @param axes - squeeze axes + */ + static squeezeShape(dims: readonly number[], axes: readonly number[]): readonly number[] { + const outputDims = new Array(); + + // sanity check + axes = ShapeUtil.normalizeAxes(axes, dims.length); + + for (let i = 0; i < dims.length; i++) { + const inSqueezeList = axes.indexOf(i) >= 0; + if (inSqueezeList && dims[i] !== 1) { + throw new Error('squeeze an axis of size different than 1'); + } + + if ((axes.length === 0 && dims[i] > 1) || (axes.length > 0 && !inSqueezeList)) { + outputDims.push(dims[i]); + } + } + + return outputDims; + } + + /** + * Determines the shape of output tensor y = unsqueeze(x, axes) + * @param dims - shape of input tensor + * @param axes - unsqueeze axes + */ + static unsqueezeShape(dims: readonly number[], axes: readonly number[]): readonly number[] { + const outputDims = new Array(dims.length + axes.length); + + // initialize the array elements to 0 + outputDims.fill(0); + + // set all axes indices to 1 in outputDims and check for duplicates + for (let i = 0; i < axes.length; i++) { + const axis = ShapeUtil.normalizeAxis(axes[i], outputDims.length); + if (axis >= outputDims.length) { + throw new Error('\'axes\' has an out of range axis'); + } + if (outputDims[axis] !== 0) { + throw new Error('\'axes\' has a duplicate axis'); + } + + outputDims[axis] = 1; + } + + // fill in the zero entries of outputDims with the input tensor's shape + let inputDimsIterator = 0; + for (let i = 0; i < outputDims.length; i++) { + if (outputDims[i] === 0) { + outputDims[i] = dims[inputDimsIterator++]; + } + } + + // sanity check assertion. 'inputDimsIterator' + // should be equal to the length of 'dims' + if (inputDimsIterator !== dims.length) { + throw new Error('the unsqueezed dimension could not be established'); + } + + return outputDims; + } +} + +export class PoolConvUtil { + /** + * Adjust the kernel, strides, pads to correct rank. Set to default value if not present + * @param isGlobalOperator If true, perform global pooling. + * @param inputDims The input tensor dimension. + * @param kernelShape The size of the kernel along each axis. + * @param strides Stride along each axis. + * @param dilations Dilation along each axis. + * @param pads Padding for the beginning and ending along each axis. + */ + static adjustPoolAttributes( + isGlobalOperator: boolean, inputDims: readonly number[], kernelShape: number[], strides: number[], + dilations: number[], pads: number[]): void { + if (!isGlobalOperator && kernelShape.length !== inputDims.length - 2) { + throw new Error('length of specified kernel shapes should be 2 less than length of input dimensions'); + } + + if (isGlobalOperator) { + // adjust kernel shape to cover the input dims + for (let dim = 0; dim < inputDims.length - 2; dim++) { + if (dim >= kernelShape.length) { + kernelShape.push(inputDims[dim + 2]); + } else { + kernelShape[dim] = inputDims[dim + 2]; + } + } + } + + // adjust strides length to match kernel shape length + for (let dim = 0; dim < kernelShape.length; dim++) { + if (dim < strides.length) { + if (strides[dim] < 0) { + throw new Error('strides should be greater than or equal to 1'); + } + } else { + strides.push(1); + } + } + + // adjust dilation value + for (let dim = 0; dim < kernelShape.length; dim++) { + if (dim < dilations.length) { + if (dilations[dim] < 0) { + throw new Error('dilations should be greater than or equal to 1'); + } + } else { + dilations.push(1); + } + } + + // adjust pads length to match 2 * kernel shape length + for (let dim = 0; dim < kernelShape.length * 2; dim++) { + if (dim < pads.length) { + if (pads[dim] < 0) { + throw new Error('pad should be greater than or equal to 1'); + } + } else { + pads.push(0); + } + } + + // sanity checks for values in kernel shapes and pads + for (let dim = 0; dim < kernelShape.length; dim++) { + if (kernelShape[dim] <= 0) { + throw new Error('kernel shapes need to be greater than 0'); + } + + if (pads[dim] >= kernelShape[dim] || pads[dim + kernelShape.length] >= kernelShape[dim]) { + throw new Error('pads should be smaller than kernel'); + } + } + } + + // adjust pad values based on 'autoPad' attribute + static adjustPadsBasedOnAutoPad( + inputDims: readonly number[], strides: readonly number[], dilations: readonly number[], + kernelShape: readonly number[], pads: number[], isChannelLast: boolean, autoPad?: string): void { + if (!autoPad) { + return; + } + + if (pads.length !== 2 * (inputDims.length - 2)) { + throw new Error('length of pads should be twice the length of data dimensions'); + } + + if (strides.length !== (inputDims.length - 2)) { + throw new Error('length of strides should be the length of data dimensions'); + } + + if (kernelShape.length !== (inputDims.length - 2)) { + throw new Error('length of kernel shapes should be the length of data dimensions'); + } + + for (let dim = 0; dim < inputDims.length - 2; dim++) { + PoolConvUtil.adjustPadAndReturnShape( + inputDims[dim + (isChannelLast ? 1 : 2)], strides[dim], dilations[dim], kernelShape[dim], pads, dim, + dim + inputDims.length - 2, autoPad); + } + } + + /** + * Calculate the output shape for Pool ops based on input attributes. (Should be used only for Pool ops) + * @param isGlobalOperator If true, perform global pooling. + * @param inputDims The input tensor dimension. (inputs[0].dims) + * @param strides Stride along each axis. + * @param dilations Dilation along each axis. + * @param kernelShape The size of the kernel along each axis. + * @param pads Padding for the beginning and ending along each axis. + * @param autoPad DEPRECATED attribute supported for legacy models. Specifies how to implicitly calculate pads in each + * dimension. Can take values NOTSET, SAME_UPPER, SAME_LOWER, or VALID. + */ + static computePoolOutputShape( + isGlobalOperator: boolean, inputDims: readonly number[], strides: number[], dilations: number[], + kernelShape: number[], pads: number[], autoPad?: string): number[] { + if (inputDims.length <= 0) { + throw new Error('input shape must be of size greater than 0'); + } + + // Add batch size and number of channels of output + const outputDims = [inputDims[0], inputDims[1]]; + + PoolConvUtil.computeShapeHelper( + isGlobalOperator, inputDims, outputDims, strides, dilations, kernelShape, pads, autoPad); + return outputDims; + } + + /** + * Calculate the output shape for Conv op based on input attributes. (Should be used only for Conv op) + * @param inputDims The input tensor dimension. (inputs[0].dims) + * @param filterDims The filter tensor dimension. (inputs[1].dims) + * @param strides Stride along each axis. + * @param kernelShape The size of the kernel along each axis. + * @param pads Padding for the beginning and ending along each axis. + * @param autoPad DEPRECATED attribute supported for legacy models. Specifies how to implicitly calculate pads in each + * dimension. Can take values NOTSET, SAME_UPPER, SAME_LOWER, or VALID. + */ + static computeConvOutputShape( + inputDims: readonly number[], filterDims: readonly number[], strides: number[], dilations: number[], + kernelShape: number[], pads: number[], autoPad?: string): number[] { + if (inputDims.length <= 0 || filterDims.length <= 0) { + throw new Error('invalid input tensor dims or invalid filter tensor dims'); + } + + // Add batch size and number of channels of output + const outputDims = [inputDims[0], filterDims[0]]; + + PoolConvUtil.computeShapeHelper(false, inputDims, outputDims, strides, dilations, kernelShape, pads, autoPad); + return outputDims; + } + + // will compute output shapes for data dimensions ONLY (i.e.) no batch size and channels + // called by computePoolOutputShape() and computeConvOutputShape() + // adjust pads based on 'autoPad' attribute prior to shape computation + private static computeShapeHelper( + isGlobalOperator: boolean, inputDims: readonly number[], outputDims: number[], strides: readonly number[], + dilations: readonly number[], kernelShape: readonly number[], pads: number[], autoPad?: string) { + if (isGlobalOperator) { + for (let dim = 0; dim < inputDims.length - 2; dim++) { + outputDims.push(1); + } + } else { + for (let dim = 0; dim < inputDims.length - 2; dim++) { + outputDims.push(PoolConvUtil.adjustPadAndReturnShape( + inputDims[dim + 2], strides[dim], dilations[dim], kernelShape[dim], pads, dim, dim + inputDims.length - 2, + autoPad)); + } + } + } + + // helper for computeShapeHelper() and adjustPadsBasedOnAutoPad() + // adjusts pad value for given 'autoPad' string and computes output shape along a particular dimension + private static adjustPadAndReturnShape( + inSize: number, stride: number, dilation: number, kernel: number, pads: number[], padHeadIndex: number, + padTailIndex: number, autoPad?: string): number { + const dkernel = dilation * (kernel - 1) + 1; + if (autoPad && autoPad !== 'NOTSET') { + switch (autoPad) { + case 'VALID': + pads[padHeadIndex] = 0; + pads[padTailIndex] = 0; + return Math.floor(((inSize - dkernel) / stride) + 1); + case 'SAME_LOWER': + case 'SAME_UPPER': + if (dilation !== 1) { + throw new Error('Dilation not supported for SAME_UPPER or SAME_LOWER'); + } else { + const legacyTargetSize = (inSize + stride - 1) / stride; + const padNeeded = (legacyTargetSize - 1) * stride + kernel - inSize; + pads[padHeadIndex] = + (autoPad === 'SAME_LOWER') ? Math.floor((padNeeded + 1) / 2) : Math.floor(padNeeded / 2); + pads[padTailIndex] = padNeeded - pads[padHeadIndex]; + return Math.floor(((inSize + padNeeded - kernel) / stride) + 1); + } + default: + throw new Error('Unsupported AutoPad type'); + } + } else { + return Math.floor(((inSize + pads[padHeadIndex] + pads[padTailIndex] - dkernel) / stride) + 1); + } + } +} + +export class GemmUtil { + // will make sure input shapes are compatible for this op + // and return back the shape of the output in the form of a tuple + // will throw exception if the input shapes are not compatible + static getShapeOfGemmResult( + leftShape: readonly number[], transLeft: boolean, rightShape: readonly number[], transRight: boolean, + biasShape?: readonly number[]): readonly number[] { + if (leftShape.length !== 2 || rightShape.length !== 2) { + throw new Error('shape need to be of size 2'); + } + + let M: number; + let K: number; + let N: number; + + if (transLeft) { + M = leftShape[1]; + K = leftShape[0]; + } else { + M = leftShape[0]; + K = leftShape[1]; + } + + let kDim = -1; + + if (transRight) { + N = rightShape[0]; + kDim = 1; + } else { + N = rightShape[1]; + kDim = 0; + } + + if (rightShape[kDim] !== K) { + throw new Error('dimension mismatch'); + } + + if (M <= 0 || N <= 0 || K <= 0) { + throw new Error('invalid shape specified'); + } + + if (biasShape && !BroadcastUtil.isValidBroadcast(biasShape, [M, N])) { + throw new Error('gemm: invalid bias shape for broadcast'); + } + + return [M, N, K]; + } +} + + +export const MIN_CLIP = -3.4028234663852886e+38; +export const MAX_CLIP = 3.4028234663852886e+38; diff --git a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts new file mode 100644 index 0000000000000..adba0fb9d022d --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +class AttributeWithCacheKeyImpl { + constructor(attribute: Record) { + Object.assign(this, attribute); + } + + private _cacheKey: string; + public get cacheKey(): string { + if (!this._cacheKey) { + this._cacheKey = + Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); + } + return this._cacheKey; + } +} + +export interface AttributeWithCacheKey { + readonly cacheKey: string; +} + +/** + * create a new object from the given attribute, and add a cacheKey property to it + */ +export const createAttributeWithCacheKey = >(attribute: T): T&AttributeWithCacheKey => + new AttributeWithCacheKeyImpl(attribute) as unknown as T & AttributeWithCacheKey; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts new file mode 100644 index 0000000000000..076ec8ca7b5ec --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -0,0 +1,231 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {WebGpuBackend} from '../backend-webgpu'; +import {LOG_DEBUG} from '../log'; + +import {GpuData, GpuDataId, GpuDataType} from './types'; + +/** + * manages GpuDataId -> GpuBuffer + */ +export interface GpuDataManager { + /** + * copy data from CPU to GPU. + */ + upload(id: GpuDataId, data: Uint8Array): void; + /** + * copy data from GPU to GPU. + */ + memcpy(sourceId: GpuDataId, destinationId: GpuDataId): void; + /** + * create new data on GPU. + */ + create(size: number, usage?: number): GpuData; + /** + * get GPU data by ID. + */ + get(id: GpuDataId): GpuData|undefined; + /** + * release the data on GPU by ID. + * + * @return size of the data released + */ + release(id: GpuDataId): number; + /** + * copy data from GPU to CPU. + */ + download(id: GpuDataId): Promise; + + /** + * refresh the buffers that marked for release. + * + * when release() is called, the buffer is not released immediately. this is because we need to wait for the commands + * to be submitted to the GPU. this function is called after the commands are submitted so that the buffers can be + * actually released. + */ + refreshPendingBuffers(): void; +} + +interface StorageCacheValue { + gpuData: GpuData; + originalSize: number; +} + +interface DownloadCacheValue { + data: Promise; +} + +/** + * normalize the buffer size so that it fits the 128-bits (16 bytes) alignment. + */ +const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16; + +let guid = 0; +const createNewGpuDataId = () => guid++; + +class GpuDataManagerImpl implements GpuDataManager { + // GPU Data ID => GPU Data ( storage buffer ) + storageCache: Map; + + // GPU Data ID => GPU Data ( read buffer ) + downloadCache: Map; + + // pending buffers for uploading ( data is unmapped ) + private buffersForUploadingPending: GPUBuffer[]; + // pending buffers for computing + private buffersPending: GPUBuffer[]; + + constructor(private backend: WebGpuBackend /* , private reuseBuffer: boolean */) { + this.storageCache = new Map(); + this.downloadCache = new Map(); + this.buffersForUploadingPending = []; + this.buffersPending = []; + } + + upload(id: GpuDataId, data: Uint8Array): void { + const srcArrayBuffer = data.buffer; + const srcOffset = data.byteOffset; + const srcLength = data.byteLength; + const size = calcNormalizedBufferSize(srcLength); + + // get destination gpu buffer + const gpuDataCache = this.storageCache.get(id); + if (!gpuDataCache) { + throw new Error('gpu data for uploading does not exist'); + } + if (gpuDataCache.originalSize !== srcLength) { + throw new Error(`inconsistent data size. gpu data size=${gpuDataCache.originalSize}, data size=${srcLength}`); + } + + // create gpu buffer + const gpuBufferForUploading = this.backend.device.createBuffer( + // eslint-disable-next-line no-bitwise + {mappedAtCreation: true, size, usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC}); + + // copy (upload) data + const arrayBuffer = gpuBufferForUploading.getMappedRange(); + new Uint8Array(arrayBuffer).set(new Uint8Array(srcArrayBuffer, srcOffset, srcLength)); + gpuBufferForUploading.unmap(); + + + // GPU copy + const commandEncoder = this.backend.getCommandEncoder(); + this.backend.endComputePass(); + commandEncoder.copyBufferToBuffer(gpuBufferForUploading, 0, gpuDataCache.gpuData.buffer, 0, size); + + LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.upload(id=${id})`); + + this.buffersForUploadingPending.push(gpuBufferForUploading); + } + + memcpy(sourceId: GpuDataId, destinationId: GpuDataId): void { + // get source gpu buffer + const sourceGpuDataCache = this.storageCache.get(sourceId); + if (!sourceGpuDataCache) { + throw new Error('source gpu data for memcpy does not exist'); + } + // get destination gpu buffer + const destinationGpuDataCache = this.storageCache.get(destinationId); + if (!destinationGpuDataCache) { + throw new Error('destination gpu data for memcpy does not exist'); + } + if (sourceGpuDataCache.originalSize !== destinationGpuDataCache.originalSize) { + throw new Error('inconsistent source and destination gpu data size'); + } + const size = calcNormalizedBufferSize(sourceGpuDataCache.originalSize); + // GPU copy + this.backend.getCommandEncoder().copyBufferToBuffer( + sourceGpuDataCache.gpuData.buffer, 0, destinationGpuDataCache.gpuData.buffer, 0, size); + } + + // eslint-disable-next-line no-bitwise + create(size: number, usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST): GpuData { + // !!! + // !!! IMPORTANT: TODO: whether we should keep the storage buffer every time, or always create new ones. + // !!! This need to be figured out by performance test results. + // !!! + + const bufferSize = calcNormalizedBufferSize(size); + + // create gpu buffer + const gpuBuffer = this.backend.device.createBuffer({size: bufferSize, usage}); + + const gpuData = {id: createNewGpuDataId(), type: GpuDataType.default, buffer: gpuBuffer}; + this.storageCache.set(gpuData.id, {gpuData, originalSize: size}); + + LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.create(size=${size}) => id=${gpuData.id}`); + return gpuData; + } + + get(id: GpuDataId): GpuData|undefined { + return this.storageCache.get(id)?.gpuData; + } + + release(id: GpuDataId): number { + const cachedData = this.storageCache.get(id); + if (!cachedData) { + throw new Error('releasing data does not exist'); + } + + LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.release(id=${id}), gpuDataId=${cachedData.gpuData.id}`); + + this.storageCache.delete(id); + this.buffersPending.push(cachedData.gpuData.buffer); + // cachedData.gpuData.buffer.destroy(); + + const downloadingData = this.downloadCache.get(id); + if (downloadingData) { + this.downloadCache.delete(id); + } + + return cachedData.originalSize; + } + + async download(id: GpuDataId): Promise { + const downloadData = this.downloadCache.get(id); + if (downloadData) { + return downloadData.data; + } + + const cachedData = this.storageCache.get(id); + if (!cachedData) { + throw new Error('data does not exist'); + } + + const commandEncoder = this.backend.getCommandEncoder(); + this.backend.endComputePass(); + const gpuReadBuffer = this.backend.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: cachedData.originalSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); + commandEncoder.copyBufferToBuffer( + cachedData.gpuData.buffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, + 0 /* destination offset */, cachedData.originalSize /* size */ + ); + this.backend.flush(); + + const readDataPromise = new Promise((resolve) => { + gpuReadBuffer.mapAsync(GPUMapMode.READ).then(() => { + const data = gpuReadBuffer.getMappedRange().slice(0); + gpuReadBuffer.destroy(); + resolve(data); + }); + }); + + this.downloadCache.set(id, {data: readDataPromise}); + + return readDataPromise; + } + + refreshPendingBuffers(): void { + for (const buffer of this.buffersForUploadingPending) { + buffer.destroy(); + } + for (const buffer of this.buffersPending) { + buffer.destroy(); + } + } +} + +export const createGpuDataManager = (...args: ConstructorParameters): GpuDataManager => + new GpuDataManagerImpl(...args); diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts new file mode 100644 index 0000000000000..10875b2abddc1 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import * as binaryOps from './ops/binary-op'; +import {conv, parseConvAttributes} from './ops/conv'; +import {gemm, parseGemmAttributes} from './ops/gemm'; +import {matMul} from './ops/matmul'; +import * as pool from './ops/pool'; +import {parseTransposeAttributes, transpose} from './ops/transpose'; +import * as unaryOps from './ops/unary-op'; +import {ComputeContext} from './types'; + +export type RunFunction = (context: ComputeContext, attribute?: unknown) => number; +export type ParseAttributeFunction = (attributeRaw: unknown) => unknown; +export type OperatorImplementation = [RunFunction]|[RunFunction, ParseAttributeFunction]; + +export const WEBGPU_OP_RESOLVE_RULES: Map = new Map([ + ['Abs', [unaryOps.abs]], + ['Acos', [unaryOps.acos]], + ['Acosh', [unaryOps.acosh]], + ['Add', [binaryOps.add]], + ['Asin', [unaryOps.asin]], + ['Asinh', [unaryOps.asinh]], + ['Atan', [unaryOps.atan]], + ['Atanh', [unaryOps.atanh]], + // TODO: support new attributes for AveragePool-10 + ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], + ['Ceil', [unaryOps.ceil]], + ['ClipV10', [unaryOps.clipV10]], + ['Clip', [unaryOps.clip]], + ['Conv', [conv, parseConvAttributes]], + ['Cos', [unaryOps.cos]], + ['Cosh', [unaryOps.cosh]], + ['Div', [binaryOps.div]], + ['Elu', [unaryOps.elu, unaryOps.parseEluAttributes]], + ['Erf', [unaryOps.erf]], + ['Floor', [unaryOps.floor]], + ['Gemm', [gemm, parseGemmAttributes]], + ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], + ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], + ['MatMul', [matMul]], + // TODO: support new attributes for MaxPool-8 and MaxPool-10 + ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], + ['Mul', [binaryOps.mul]], + ['Neg', [unaryOps.neg]], + ['Pow', [binaryOps.pow]], + ['Reciprocal', [unaryOps.reciprocal]], + ['Sigmoid', [unaryOps.sigmoid]], + ['Sin', [unaryOps.sin]], + ['Sinh', [unaryOps.sinh]], + ['Sqrt', [unaryOps.sqrt]], + ['Sub', [binaryOps.sub]], + ['Tan', [unaryOps.tan]], + ['Tanh', [unaryOps.tanh]], + ['Transpose', [transpose, parseTransposeAttributes]], +]); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts new file mode 100644 index 0000000000000..5345367eadfef --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/activation_util.ts +// +// modified to fit the needs of the project + +export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid'; + +export const typeSnippet = (component: number) => { + switch (component) { + case 1: + return 'f32'; + case 2: + return 'vec2'; + case 3: + return 'vec3'; + case 4: + return 'vec4'; + default: + throw new Error(`${component}-component is not supported.`); + } +}; + +export const activationFnSnippet = + (activation?: Activation, _hasPreluActivationWeights = false, _packed = false, _coordsLength = 3): string => { + if (!activation) { + return ''; + } + + // TODO: add implementations + return ''; + }; + +export const biasActivationSnippet = (hasBias: boolean, activation?: Activation): string => ` + ${hasBias ? 'value = value + getBiasByOutputCoords(coords);' : ''} + ${activation ? 'value = activation(value, coords);' : ''} + `; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts new file mode 100644 index 0000000000000..b77e9bea7b871 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -0,0 +1,250 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts +// +// modified to fit the needs of the project + +import {LOG_DEBUG} from '../../../log'; +import {TensorView} from '../../../tensor'; +import {ShapeUtil} from '../../../util'; +import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {ConvAttributes} from '../conv'; + +import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {utilFunctions} from './conv_util'; +import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; + +const conv2dCommonSnippet = + (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, + activation?: Activation, hasPreluActivationWeights = false, innerElementSizeX = 4, innerElementSizeW = 4, + innerElementSize = 4): string => { + const getXSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'resData = x[xIndex];'; + case 3: + return 'resData = vec3(x[xIndex], x[xIndex + 1], x[xIndex + 2]);'; + case 4: + return 'resData = x[xIndex / 4];'; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return w[row * wShape[3] + colIn];'; + case 4: + return 'return w[row * wShape[3] / 4 + colIn];'; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + const coordASnippet = isChannelsLast ? ` + let coord = vec4(batch, xRow, xCol, xCh); + ` : + ` + let coord = vec4(batch, xCh, xRow, xCol); + `; + + const coordResSnippet = isChannelsLast ? ` + let coords = vec4( + batch, + row / outWidth, + row % outWidth, + col); + ` : + ` + let coords = vec4( + batch, + row, + col / outWidth, + col % outWidth); + `; + + const xHeight = isChannelsLast ? 'xShape[1]' : 'xShape[2]'; + const xWidth = isChannelsLast ? 'xShape[2]' : 'xShape[3]'; + const row = isChannelsLast ? 'row' : 'col'; + const col = isChannelsLast ? 'col' : 'row'; + const readXSnippet = ` + let inChannels = wShape[2]; + let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outRow = ${row} / outWidth; + let outCol = ${row} % outWidth; + + let WRow = ${col} / (filterDims[1] * inChannels); + let WCol = ${col} / inChannels % filterDims[1]; + let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; + let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; + let xCh = ${col} % inChannels; + var resData = ${typeSnippet(innerElementSizeX)}(0.0); + // The bounds checking is always needed since we use it to pad zero for + // the 'same' padding type. + if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) { + ${coordASnippet} + let xIndex = getIndexFromCoords4D(coord, xShape); + ${getXSnippet(innerElementSizeX)} + } + return resData;`; + + const sampleX = isChannelsLast ? (fitAOuter && fitInner ? ` + let col = colIn * ${innerElementSizeX}; + ${readXSnippet}` : + ` + let col = colIn * ${innerElementSizeX}; + if (row < dimAOuter && col < dimInner) { + ${readXSnippet} + } + return ${typeSnippet(innerElementSizeX)}(0.0);`) : + (fitInner && fitBOuter ? ` + let col = colIn * ${innerElementSizeX}; + ${readXSnippet}` : + ` + let col = colIn * ${innerElementSizeX}; + if (row < dimInner && col < dimBOuter) { + ${readXSnippet} + } + return ${typeSnippet(innerElementSizeX)}(0.0);`); + + const sampleW = `${getWSnippet(innerElementSizeW)}`; + + const resType = typeSnippet(innerElementSize); + const aType = isChannelsLast ? typeSnippet(innerElementSizeX) : typeSnippet(innerElementSizeW); + const bType = isChannelsLast ? typeSnippet(innerElementSizeW) : typeSnippet(innerElementSizeX); + const userCode = ` + ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { + ${isChannelsLast ? sampleX : sampleW} + } + + fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${bType} { + ${isChannelsLast ? sampleW : sampleX} + } + + fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { + let col = colIn * ${innerElementSize}; + if (row < dimAOuter && col < dimBOuter) + { + var value = valueIn; + let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + ${coordResSnippet} + ${biasActivationSnippet(addBias, activation)} + setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value); + } + }`; + return userCode; + }; + +export const createConv2DMatMulProgramInfo = + (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvAttributes, + outputShape: readonly number[], dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfo => { + const isChannelsLast = attributes.format === 'NHWC'; + const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const batchSize = outputShape[0]; + const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; + const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; + const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; + const isVec4 = (((inChannels % 4 === 0 || inChannels % 3 === 0) && isChannelsLast) || + (outWidth % 4 === 0 && !isChannelsLast)) && + outChannels % 4 === 0; + + // TODO: fine tune size + const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; + const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; + const workGroupSize: [number, number, number] = + isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; + const elementsPerThread = + isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; + const dispatch = [ + Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), + Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[1]) + ]; + + LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); + + const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : elementsPerThread[0]; + + const tileAOuter = workGroupSize[1] * elementsPerThread[1]; + const tileBOuter = workGroupSize[0] * elementsPerThread[0]; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + + const fitAOuter = dimAOuter % tileAOuter === 0; + const fitBOuter = dimBOuter % tileBOuter === 0; + const fitInner = dimInner % tileInner === 0; + + const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; + + const declareInputs = [ + `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? 'vec4' : 'f32'}>;`, + `@group(0) @binding(1) var w: array<${isVec4 ? 'vec4' : 'f32'}>;` + ]; + let declareFunctions = ` + fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { + result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + } + fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { + let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); + setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); + }`; + if (hasBias) { + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), + getShaderSource: () => ` + ${utilFunctions} + //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, + // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, + // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; + ${declareInputs.join('')} + @group(0) @binding(${declareInputs.length}) var result: array<${ + isVec4 ? 'vec4' : 'f32'}>; + //@group(0) @binding(${declareInputs.length + 1}) var uniforms: Uniforms; + + const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); + const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); + const outShape : vec4 = vec4(${outputShape.join(',')}); + const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); + const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); + const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); + const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); + const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + const dimAOuter : i32 = ${dimAOuter}; + const dimBOuter : i32 = ${dimBOuter}; + const dimInner : i32 = ${dimInner}; + ${declareFunctions} + ${ + conv2dCommonSnippet( + isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], + elementsSize[1], elementsSize[2])} + ${ + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, !isChannelsLast, tileInner, false, undefined, + sequentialAccessByThreads)}` + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts new file mode 100644 index 0000000000000..0ba48a33fbc47 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// sampled from [@tensorflow/tfjs] tfjs-core/src/ops/conv_util.ts +// +// modified to fit the needs of the project + +export const utilFunctions = ` +fn getIndexFromCoords4D(coords : vec4, shape : vec4) -> i32 { + return dot(coords, vec4( + shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1)); +} +fn getOutputIndexFromCoords(coords : vec4) -> i32 { + return dot(coords, vec4( + outShapeStrides.x, outShapeStrides.y, outShapeStrides.z, 1)); +} +`; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts new file mode 100644 index 0000000000000..d30821e508083 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -0,0 +1,327 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/matmul_packed_webgpu.ts +// +// modified to fit the needs of the project + +const writeDataToSubAVec4Snippet = (transpose: boolean) => { + if (transpose) { + return ` + mm_Asub[inputRow][inputCol] = mm_readA(batch, + kStart + inputRow, + globalRowStart / innerElementSize + inputCol); + `; + + } else { + return ` + mm_Asub[inputRow][inputCol] = mm_readA(batch, + globalRow + innerRow, + kStart / innerElementSize + inputCol); + `; + } +}; + +const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) => { + if (transposeA) { + return ` + let ACached0 = mm_Asub[k * innerElementSize][localRow]; + let ACached1 = mm_Asub[k * innerElementSize + 1][localRow]; + let ACached2 = mm_Asub[k * innerElementSize + 2][localRow]; + ${innerElementSize === 3 ? '' : 'let ACached3 = mm_Asub[k * innerElementSize + 3][localRow];'} + for (var i = 0; i < rowPerThread; i = i + 1) { + acc[i] = BCached0 * ACached0[i] + acc[i]; + acc[i] = BCached1 * ACached1[i] + acc[i]; + acc[i] = BCached2 * ACached2[i] + acc[i]; + ${innerElementSize === 3 ? '' : 'acc[i] = BCached3 * ACached3[i] + acc[i];'} + }`; + } else { + return ` + for (var i = 0; i < rowPerThread; i = i + 1) { + let ACached = mm_Asub[tileRow + i][k]; + acc[i] = BCached0 * ACached.x + acc[i]; + acc[i] = BCached1 * ACached.y + acc[i]; + acc[i] = BCached2 * ACached.z + acc[i]; + ${innerElementSize === 3 ? '' : 'acc[i] = BCached3 * ACached.w + acc[i];'} + }`; + } +}; + +export const makeMatMulPackedVec4Source = + (workPerThread: number[], workgroupSize: [number, number, number], transposeA = false, tileInner = 32, + splitK = false, splitedDimInner = 32, isVectorA = false): string => { + const tileAOuter = workgroupSize[1] * workPerThread[1]; + const tileBOuter = workgroupSize[0] * workPerThread[0]; + const tileAWidth = transposeA ? tileAOuter : tileInner; + const tileAHight = transposeA ? tileInner : tileAOuter; + const innerElementSize = tileAWidth / workgroupSize[0]; + const rowPerThreadB = tileInner / workgroupSize[1]; + + if (!(((transposeA && innerElementSize === 4 && workPerThread[1] === 4) || + (!transposeA && (innerElementSize === 3 || innerElementSize === 4))) && + tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0 && workPerThread[0] === 4)) { + throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${ + innerElementSize} and workPerThread[1] ${workPerThread[1]} must be 4. + Otherwise, innerElementSize ${innerElementSize} must be 3 or 4. + tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}. tileInner ${ + tileInner} must be divisible by workgroupSize[1] ${workgroupSize[1]}. colPerThread ${ + workPerThread[0]} must be 4.`); + } + return ` +var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; +var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; + +const rowPerThread = ${workPerThread[1]}; +const colPerThread = ${workPerThread[0]}; +const innerElementSize = ${innerElementSize}; +const tileInner = ${tileInner}; + +@compute @workgroup_size(${workgroupSize[0]}, ${workgroupSize[1]}, ${workgroupSize[2]}) +fn main(@builtin(local_invocation_id) localId : vec3, + @builtin(global_invocation_id) globalId : vec3, + @builtin(workgroup_id) workgroupId : vec3) { + let localRow = i32(localId.y); + let tileRow = ${isVectorA ? '0' : 'localRow * rowPerThread'}; + let tileCol = i32(localId.x); + + let globalRow = ${isVectorA ? '0' : 'i32(globalId.y) * rowPerThread'}; + let globalCol = i32(globalId.x); + let batch = ${splitK ? '0' : 'i32(globalId.z)'}; + let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; + + let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; + var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; + + var acc: array, rowPerThread>; + + // Loop over shared dimension. + let tileRowB = localRow * ${rowPerThreadB}; + for (var t = 0; t < numTiles; t = t + 1) { + // Load one tile of A into local memory. + for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + let inputRow = tileRow + innerRow; + let inputCol = tileCol; + ${writeDataToSubAVec4Snippet(transposeA)} + } + + // Load one tile of B into local memory. + for (var innerRow = 0; innerRow < ${rowPerThreadB}; innerRow = innerRow + 1) { + let inputRow = tileRowB + innerRow; + let inputCol = tileCol; + mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol); + } + kStart = kStart + tileInner; + workgroupBarrier(); + + // Compute acc values for a single thread. + for (var k = 0; k < tileInner / innerElementSize; k = k + 1) { + let BCached0 = mm_Bsub[k * innerElementSize][tileCol]; + let BCached1 = mm_Bsub[k * innerElementSize + 1][tileCol]; + let BCached2 = mm_Bsub[k * innerElementSize + 2][tileCol]; + ${innerElementSize === 3 ? '' : 'let BCached3 = mm_Bsub[k * innerElementSize + 3][tileCol];'} + + ${calculateResultSnippet(transposeA, innerElementSize)} + } + + workgroupBarrier(); + } + + for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]); + } +}`; + }; + +const writeDataToSubASnippet = (transpose: boolean) => { + if (transpose) { + return ` + mm_Asub[inputRow][inputCol] = mm_readA(batch, + kStart + inputRow, + globalRowStart + inputCol); + `; + + } else { + return ` + mm_Asub[inputRow][inputCol] = mm_readA(batch, + globalRowStart + inputRow, + kStart + inputCol); + `; + } +}; + +const readDataFromSubASnippet = (transposeA: boolean) => + transposeA ? 'let ACached = mm_Asub[k][tileRow + innerRow];' : 'let ACached = mm_Asub[tileRow + innerRow][k];'; + +// sequentialAccessByThreads means sequential data in memory is accessed by +// threads, instead of a single thread (default behavior). +export const makeMatMulPackedSource = + (workPerThread: number[], workgroupSize: [number, number, number], transposeA = false, tileInner = 32, + splitK = false, splitedDimInner = 32, sequentialAccessByThreads = false): string => { + const tileAOuter = workPerThread[1] * workgroupSize[1]; + const tileBOuter = workPerThread[0] * workgroupSize[0]; + const tileAWidth = transposeA ? tileAOuter : tileInner; + const tileAHight = transposeA ? tileInner : tileAOuter; + + if (!(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 && + tileInner % workgroupSize[1] === 0)) { + throw new Error(`tileAHight ${tileAHight} must be divisible by workgroupSize[1]${ + workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${ + workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`); + } + const rowPerThreadA = tileAHight / workgroupSize[1]; + const colPerThreadA = tileAWidth / workgroupSize[0]; + const rowPerThreadB = tileInner / workgroupSize[1]; + const matmulSnippet = sequentialAccessByThreads ? + ` + let localRow = i32(localId.y); + let localCol = i32(localId.x); + let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; + let globalColStart = i32(workgroupId.x) * ${tileBOuter}; + + // Loop over shared dimension. + for (var t = 0; t < numTiles; t = t + 1) { + // Load one tile of A into local memory. + for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) { + for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) { + ${writeDataToSubASnippet(transposeA)} + } + } + // Load one tile of B into local memory. + for (var inputRow = localRow; inputRow < ${tileInner}; inputRow = inputRow + ${workgroupSize[1]}) { + for (var inputCol = localCol; inputCol < ${tileBOuter}; inputCol = inputCol + ${workgroupSize[0]}) { + mm_Bsub[inputRow][inputCol] = mm_readB(batch, + kStart + inputRow, + globalColStart + inputCol); + } + } + kStart = kStart + tileInner; + workgroupBarrier(); + + // Compute acc values for a single thread. + var BCached : array; + for (var k = 0; k < tileInner; k = k + 1) { + for (var inner = 0; inner < colPerThread; inner = inner + 1) { + BCached[inner] = mm_Bsub[k][localCol + inner * ${workgroupSize[0]}]; + } + for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + let ACached = ${ + transposeA ? `mm_Asub[k][localRow + innerRow * ${workgroupSize[1]}];` : + `mm_Asub[localRow + innerRow * ${workgroupSize[1]}][k];`} + for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + acc[innerRow][innerCol] = acc[innerRow][innerCol] + + ACached * BCached[innerCol]; + } + } + } + workgroupBarrier(); + } + for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + let gRow = globalRowStart + localRow + innerRow * ${workgroupSize[1]}; + for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + let gCol = globalColStart + localCol + innerCol * ${workgroupSize[0]}; + mm_write(batch, gRow, gCol, acc[innerRow][innerCol]); + } + } + ` : + ` +let tileRow = i32(localId.y) * rowPerThread; +let tileCol = i32(localId.x) * colPerThread; + +let globalRow = i32(globalId.y) * rowPerThread; +let globalCol = i32(globalId.x) * colPerThread; +let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; + +let tileRowA = i32(localId.y) * ${rowPerThreadA}; +let tileColA = i32(localId.x) * ${colPerThreadA}; +let tileRowB = i32(localId.y) * ${rowPerThreadB}; +// Loop over shared dimension. +for (var t = 0; t < numTiles; t = t + 1) { + // Load one tile of A into local memory. + for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow = innerRow + 1) { + for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) { + let inputRow = tileRowA + innerRow; + let inputCol = tileColA + innerCol; + ${writeDataToSubASnippet(transposeA)} + } + } + + // Load one tile of B into local memory. + for (var innerRow = 0; innerRow < ${rowPerThreadB}; innerRow = innerRow + 1) { + for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + let inputRow = tileRowB + innerRow; + let inputCol = tileCol + innerCol; + mm_Bsub[inputRow][inputCol] = mm_readB(batch, + kStart + inputRow, + globalCol + innerCol); + } + } + kStart = kStart + tileInner; + workgroupBarrier(); + + // Compute acc values for a single thread. + var BCached : array; + for (var k = 0; k < tileInner; k = k + 1) { + for (var inner = 0; inner < colPerThread; inner = inner + 1) { + BCached[inner] = mm_Bsub[k][tileCol + inner]; + } + + for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + ${readDataFromSubASnippet(transposeA)} + for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol]; + } + } + } + + workgroupBarrier(); +} + +for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + mm_write(batch, globalRow + innerRow, globalCol + innerCol, + acc[innerRow][innerCol]); + } +} +`; + + return ` + var mm_Asub : array, ${tileAHight}>; + var mm_Bsub : array, ${tileInner}>; + const rowPerThread = ${workPerThread[1]}; + const colPerThread = ${workPerThread[0]}; + const tileInner = ${tileInner}; + +@compute @workgroup_size(${workgroupSize[0]}, ${workgroupSize[1]}, ${workgroupSize[2]}) +fn main(@builtin(local_invocation_id) localId : vec3, + @builtin(global_invocation_id) globalId : vec3, + @builtin(workgroup_id) workgroupId : vec3) { + let batch = ${splitK ? '0' : 'i32(globalId.z)'}; + let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; + var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; + + var acc : array, rowPerThread>; + + // Without this initialization strange values show up in acc. + for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + acc[innerRow][innerCol] = 0.0; + } + } + ${matmulSnippet} + } +`; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts new file mode 100644 index 0000000000000..604f4fc66e1ac --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {BroadcastUtil, ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createIndicesHelper, ShaderHelper} from './common'; + +type BuiltinFunctionName = string; +type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; +type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ + scalar: BinaryCustomExpression; + vector: BinaryCustomExpression; +}; + +const createBinaryOpProgramShader = + (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], + vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, additionalImplementation?: string, + typeA = 'f32', typeB = 'f32', typeOutput = 'f32') => { + const outputSize = ShapeUtil.size(dimsOutput); + const vecSize = Math.ceil(outputSize / 4); + + let expressionScalar: BinaryCustomExpression; + let expressionVector: BinaryCustomExpression; + if (typeof funcCall === 'string') { + expressionScalar = expressionVector = (a, b) => `${funcCall}((${a}),(${b}))`; + } else if (typeof funcCall === 'function') { + expressionScalar = expressionVector = funcCall; + } else { + expressionScalar = funcCall.scalar; + expressionVector = funcCall.vector; + } + + let broadcastImpl = ''; + const outputIndicesHelper = createIndicesHelper('output', dimsOutput); + if (doBroadcast) { + const calcOffsetImpl = (dims: readonly number[]) => { + const strides = ShapeUtil.computeStrides(dims); + const offsets: string[] = []; + for (let i = dims.length - 1; i >= 0; i--) { + const idx = dimsOutput.length === 0 ? '0u' : + (dimsOutput.length === 1) ? '(*outputIndices)' : + `(*outputIndices)[${i + dimsOutput.length - dims.length}]`; + offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`); + } + return offsets.length > 0 ? offsets.join('+') : '0u'; + }; + + broadcastImpl = ` + ${outputIndicesHelper.o2iImpl} + + fn calcOffsetA(outputIndices: ptr) -> u32 { + return ${calcOffsetImpl(dimsA)}; + } + + fn calcOffsetB(outputIndices: ptr) -> u32 { + return ${calcOffsetImpl(dimsB)}; + } + `; + } + + let assignment: string; + if (vectorize) { + if (doBroadcast) { + assignment = ` + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} + ${outputIndicesHelper.o2iCall('global_idx * 4u', 'outputIndices')} + let offsetA = calcOffsetA(&outputIndices); + let offsetB = calcOffsetB(&outputIndices); + outputData[global_idx] = ${expressionVector('aData[offsetA / 4u]', 'bData[offsetB / 4u]')};`; + } else { + assignment = `outputData[global_idx] = ${expressionVector('aData[global_idx]', 'bData[global_idx]')};`; + } + } else { + if (!doBroadcast) { + throw new Error('no necessary to use scalar implementation for element-wise binary op implementation.'); + } + const singleAssignment = (x: number) => { + const expressionA = `aData[indexA${x}][componentA${x}]`; + const expressionB = `bData[indexB${x}][componentB${x}]`; + return ` + ${outputIndicesHelper.o2iCall(`global_idx * 4u + ${x}u`, 'outputIndices')} + let offsetA${x} = calcOffsetA(&outputIndices); + let offsetB${x} = calcOffsetB(&outputIndices); + let indexA${x} = offsetA${x} / 4u; + let indexB${x} = offsetB${x} / 4u; + let componentA${x} = offsetA${x} % 4u; + let componentB${x} = offsetB${x} % 4u; + outputData[global_idx][${x}] = ${expressionScalar(expressionA, expressionB)};`; + }; + + assignment = ` + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} + ${singleAssignment(0)} + ${singleAssignment(1)} + ${singleAssignment(2)} + ${singleAssignment(3)}`; + } + + return ` + @group(0) @binding(0) var aData : array>; + @group(0) @binding(1) var bData : array>; + @group(0) @binding(2) var outputData : array>; + + ${additionalImplementation ?? ''} + ${broadcastImpl} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${assignment} + }`; + }; + +const createBinaryOpProgramInfo = + (metadata: ProgramMetadata, a: TensorView, b: TensorView, funcCall: BinaryFunctionCall, + additionalImplementation?: string, outputDataType: number = a.dataType): ProgramInfo => { + const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims); + let outputShape = a.dims; + let outputSize = ShapeUtil.size(a.dims); + + let vectorize = false; + + // TODO: deal with zero-sized tensors (eg. dims=[1,0]) + + if (isBroadcast) { + const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); + if (!calculatedShape) { + throw new Error('Can\'t perform binary op on the given tensors'); + } + outputShape = calculatedShape; + outputSize = ShapeUtil.size(outputShape); + + // check whether vectorize can be enabled + let sharedDimension = 1; + for (let i = 0; i < outputShape.length; i++) { + const dimA = a.dims[a.dims.length - i] ?? 1; + const dimB = b.dims[b.dims.length - i] ?? 1; + if (dimA === dimB) { + sharedDimension *= dimA; + } else { + break; + } + } + if (sharedDimension % 4 === 0) { + vectorize = true; + } + + + } else { + // element-wise + vectorize = true; + } + + return { + ...metadata, + getShaderSource: (shaderHelper) => createBinaryOpProgramShader( + shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, additionalImplementation), + outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => + ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)}) + }; + }; + +const createBinaryOpProgramInfoLoader = + (inputs: readonly TensorView[], name: string, funcCall: BinaryFunctionCall, additionalImplementation?: string, + cacheKey?: string): ProgramInfoLoader => { + const metadata: + ProgramMetadata = {name, inputTypes: [GpuDataType.default, GpuDataType.default], cacheHint: cacheKey}; + return { + ...metadata, + get: () => createBinaryOpProgramInfo(metadata, inputs[0], inputs[1], funcCall, additionalImplementation) + }; + }; + +export const add = (context: ComputeContext): number => { + context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Add', (a, b) => `${a}+${b}`)); + return 0; +}; + +export const div = (context: ComputeContext): number => { + context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Div', (a, b) => `${a}/${b}`)); + return 0; +}; + +export const mul = (context: ComputeContext): number => { + context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Mul', (a, b) => `${a}*${b}`)); + return 0; +}; + +export const pow = (context: ComputeContext): number => { + context.compute(createBinaryOpProgramInfoLoader( + context.inputs, 'Pow', ({scalar: (a, b) => `pow_f32(${a},${b})`, vector: (a, b) => `pow_vf32(${a},${b})`}), ` + fn pow_f32(a : f32, b : f32) -> f32 { + if (b == 0.0) { + return 1.0; + } else if (a < 0.0 && b != floor(b)) { + return pow(a, b); // NaN + } + return select(sign(a), 1.0, round(abs(b) % 2.0) != 1.0) * pow(abs(a), b); + } + fn pow_vf32(a : vec4, b : vec4) -> vec4 { + // TODO: implement vectorized pow + return vec4(pow_f32(a.x, b.x), pow_f32(a.y, b.y), pow_f32(a.z, b.z), pow_f32(a.w, b.w)); + } + `)); + return 0; +}; + +export const sub = (context: ComputeContext): number => { + context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`)); + return 0; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts new file mode 100644 index 0000000000000..7305ab592d4a7 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {ShapeUtil} from '../../util'; + +/** + * constant value for a workgroup size. + * + * We definitely can do further optimization in future, but for now we use 64. + * + * rule of thumb: Use [a workgroup size of] 64 unless you know what GPU you are targeting or that your workload + * needs something different. + * + * from: https://surma.dev/things/webgpu/ + **/ +export const WORKGROUP_SIZE = 64; + +export interface IndicesHelper { + /** + * WGSL code of function implementation for offset-to-indices + */ + o2iImpl: string; + /** + * WGSL code of function call for offset-to-indices + */ + o2iCall: (varOffset: string, varIndices: string) => string; + /** + * WGSL code of function implementation for indices-to-offset + */ + i2oImpl: string; + /** + * WGSL code of function implementation for indices-to-offset + * + * @param isPtr - whether the variable is a pointer. default is false. + */ + i2oExpression: (varIndices: string, isPtr?: boolean) => string; + /** + * WGSL code of indices variable declaration + * + * @param v - variable name. + * @param init - initial value. + */ + indicesVariableDeclaration: (v: string, init?: string[]) => string; + /** + * data type of indices + */ + iType: string; +} + +export const createIndicesHelper = (name: string, shape: readonly number[]): IndicesHelper => { + const iType = shape.length < 2 ? 'u32' : `array`; + + const strides = ShapeUtil.computeStrides(shape); + let o2iSnippet = ''; + for (let i = 0; i < shape.length - 1; i++) { + o2iSnippet += ` + let dim${i} = current / ${strides[i]}u; + let rest${i} = current % ${strides[i]}u; + (*indices)[${i}] = dim${i}; + current = rest${i}; + `; + } + o2iSnippet += `(*indices)[${shape.length - 1}] = current;`; + + const o2iImpl = shape.length < 2 ? '' : ` + fn ih_o2i_${name}(offset: u32, indices: ptr) { + var current = offset; + ${o2iSnippet} + }`; + + const o2iCall = (varOffset: string, varIndices: string) => + shape.length < 2 ? `${varIndices}=${varOffset};` : `ih_o2i_${name}(${varOffset}, &${varIndices});`; + + const offsets: string[] = []; + if (shape.length === 0) { + offsets.push('0u'); + } else if (shape.length < 2) { + offsets.push('(*indices)'); + } else { + for (let i = shape.length - 1; i >= 0; i--) { + offsets.push(`${strides[i]}u * ((*indices)[${i}])`); + } + } + + const i2oImpl = shape.length < 2 ? '' : ` + fn ih_i2o_${name}(indices: ptr) -> u32 { + return ${offsets.join('+')}; + }`; + + const i2oExpression = (varIndices: string, isPtr?: boolean) => + shape.length < 2 ? `(${isPtr ? '*' : ''}${varIndices})` : `ih_i2o_${name}(${isPtr ? '' : '&'}${varIndices})`; + + const indicesVariableDeclaration = (v: string, init?: string[]) => + `var ${v}:${iType}${init ? `=${iType}(${init.join(',')})` : ''};`; + + return {o2iImpl, o2iCall, i2oImpl, i2oExpression, indicesVariableDeclaration, iType}; +}; + +/** + * A ShaderHelper is a helper class for generating WGSL code. + */ +export interface ShaderHelper { + mainStart(workgroupSize?: number|[number, number, number]): string; + guardAgainstOutOfBoundsWorkgroupSizes(size: unknown): string; +} + +class ShaderHelperImpl implements ShaderHelper { + constructor(private normalizedDispatchGroup: [number, number, number]) {} + guardAgainstOutOfBoundsWorkgroupSizes(size: number|string): string { + // Guard against out-of-bounds work group sizes + const sizeInCode = typeof size === 'number' ? `${size}u` : size; + return `if (global_idx >= ${sizeInCode}) { return; }`; + } + mainStart(workgroupSize: number|[number, number, number] = WORKGROUP_SIZE) { + const workgroupSizeX = typeof workgroupSize === 'number' ? workgroupSize : workgroupSize[0]; + const workgroupSizeY = typeof workgroupSize === 'number' ? 1 : workgroupSize[1]; + const workgroupSizeZ = typeof workgroupSize === 'number' ? 1 : workgroupSize[2]; + + const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1; + const paramList = is1DimensionDispatch ? '@builtin(global_invocation_id) global_id : vec3' : + `@builtin(local_invocation_index) local_index : u32, + @builtin(workgroup_id) workgroup_id : vec3`; + const globalIdxDefinition = is1DimensionDispatch ? + 'let global_idx = global_id.x;' : + `let global_idx = (workgroup_id.z * ${this.normalizedDispatchGroup[0] * this.normalizedDispatchGroup[1]}u + + workgroup_id.y * ${this.normalizedDispatchGroup[0]}u + workgroup_id.x) * ${ + workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`; + + return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ}) + fn main(${paramList}) { + ${globalIdxDefinition} + `; + } +} + +export const createShaderHelper = (dispatchGroup: [number, number, number]): ShaderHelper => + new ShaderHelperImpl(dispatchGroup); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts new file mode 100644 index 0000000000000..0f1381ed6bc21 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createIndicesHelper, IndicesHelper, ShaderHelper} from './common'; + +export interface ConcatAttributes extends AttributeWithCacheKey { + readonly axis: number; +} + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length < 1) { + throw new Error('too few inputs'); + } + + const inputType = inputs[0].dataType; + const inputDimensionality = inputs[0].dims.length; + + for (const input of inputs) { + // make sure types of all inputs match + if (input.dataType !== inputType) { + throw new Error('input tensors should be one type'); + } + + // make sure the dimensionality of all inputs are the same + if (input.dims.length !== inputDimensionality) { + throw new Error('input tensors should have the same shape'); + } + } +}; + +const createConcatProgramMetadata = (inputCount: number, cacheHint: string) => + ({name: 'Concat', inputTypes: Array(inputCount).fill(GpuDataType.default), cacheHint}); + +const calculateInputIndexImpl = (numberOfTensors: number): string => ` + fn calculateInputIndex(index: u32) -> u32 { + for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) { + if (index < sizeInConcatAxis[i]) { + return i; + } + } + return ${numberOfTensors}u; + }`; + +const readBufferDataImpl = (indicesHelper: readonly IndicesHelper[], tensorRank: number, dataType: string) => { + const numberOfTensors = indicesHelper.length; + const codeLines: string[] = []; + for (let i = 0; i < numberOfTensors; ++i) { + const returnSnippet = `return input${i}[${indicesHelper[i].i2oExpression('indices', true)}];`; + if (numberOfTensors === 1) { + codeLines.push(returnSnippet); + } else if (i === 0) { + codeLines.push(`if (textureIndex == ${i}u) { ${returnSnippet} }`); + } else if (i === numberOfTensors - 1) { + codeLines.push(`else { ${returnSnippet} }`); + } else { + codeLines.push(`else if (textureIndex == ${i}) { ${returnSnippet} }`); + } + } + return ` + fn readBufferData(textureIndex: u32, indices: ptr) -> ${dataType} { + ${codeLines.join('\n')} + }`; +}; + +const createConcatProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], axis: number, dataType = 'f32'): ProgramInfo => { + const inputShape = inputs[0].dims.slice(); + if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { + throw new Error('axis specified for concat doesn\'t match input dimensionality'); + } + const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; + // ensure all of the non-concatenated axes match each other + // calculate the shape of the output tensor while we do that + const outputShape = inputShape.slice(0); + for (let i = 1; i < inputs.length; i++) { + const dataNShape = inputs[i].dims.slice(); + for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { + // add to the placeholder for computing output shape + if (axisIndex === adjustedAxis) { + outputShape[adjustedAxis] += dataNShape[axisIndex]; + } + // ensure all non-cancatenated axes match each other + else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { + throw new Error('non concat dimensions must match'); + } + } + } + + const outputSize = ShapeUtil.size(outputShape); + const rank = outputShape.length; + + const sizeInConcatAxis = new Array(inputs.length); + const inputStorageBuffersDeclarations = new Array(inputs.length); + const inputIndicesHelpers = new Array(inputs.length); + + let previousSum = 0; + for (let i = 0; i < inputs.length; ++i) { + previousSum += inputs[i].dims[adjustedAxis]; + sizeInConcatAxis[i] = previousSum; + + inputStorageBuffersDeclarations[i] = + `@group(0) @binding(${i}) var input${i} : array<${dataType}>;`; + + inputIndicesHelpers[i] = createIndicesHelper(`input${i}`, inputs[i].dims); + } + + const outputIndicesHelper = createIndicesHelper('output', outputShape); + + const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + + ${inputStorageBuffersDeclarations.join('\n')} + @group(0) @binding(${inputs.length}) var output : array<${dataType}>; + + ${inputIndicesHelpers.map(i => i.i2oImpl).join('\n')} + ${outputIndicesHelper.o2iImpl} + + const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); + ${calculateInputIndexImpl(sizeInConcatAxis.length)} + ${readBufferDataImpl(inputIndicesHelpers, rank, dataType)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + ${outputIndicesHelper.indicesVariableDeclaration('indices')} + ${outputIndicesHelper.o2iCall('global_idx', 'indices')} + + let textureIndex = calculateInputIndex(${indicesAxis}); + if (textureIndex != 0u) { + ${indicesAxis} -= sizeInConcatAxis[textureIndex - 1u]; + } + + output[global_idx] = readBufferData(textureIndex, &indices); + }`; + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +const createConcatProgramInfoLoader = + (inputs: readonly TensorView[], attributes: ConcatAttributes): ProgramInfoLoader => { + const metadata = createConcatProgramMetadata(inputs.length, attributes.cacheKey); + return {...metadata, get: () => createConcatProgramInfo(metadata, inputs, attributes.axis)}; + }; + +export const concat = (context: ComputeContext, attributes: ConcatAttributes): number => { + validateInputs(context.inputs); + context.compute(createConcatProgramInfoLoader(context.inputs, attributes)); + return 0; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts new file mode 100644 index 0000000000000..ebf305a129ce9 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createIndicesHelper, ShaderHelper} from './common'; +import {calculateOutputShape, ConvAttributes} from './conv'; +import {getActicationSnippet} from './fuse-utils'; + +const createGroupedConvProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ + name: 'GroupedConv', + inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default], + cacheHint +}); + +const createGroupedConvProgramInfo = + (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvAttributes, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfo => { + const hasBias = inputs.length > 2; + const processBias = hasBias ? 'value += b[output_channel];' : ''; + const xShape = inputs[0].dims; + const wShape = inputs[1].dims; + const outputChannelsPerGroup = wShape[0] / attributes.group; + + const dataType = 'f32'; // TODO: support other data type + const {activationFunction, applyActivation} = getActicationSnippet(attributes); + const inputStorageBuffersDeclarations = [ + `@group(0) @binding(0) var x : array<${dataType}>;`, + `@group(0) @binding(1) var w : array<${dataType}>;` + ]; + if (hasBias) { + inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var b : array<${dataType}>;`); + } + + const isChannelLast = attributes.format === 'NHWC'; + const outputShape = calculateOutputShape( + xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); + const outputSize = ShapeUtil.size(outputShape); + const outputIndicesHelper = createIndicesHelper('output', outputShape); + const xIndicesHelper = createIndicesHelper('x', xShape); + const wIndicesHelper = createIndicesHelper('w', wShape); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const strides: vec2 = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u); + const pads: vec2 = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u); + + ${inputStorageBuffersDeclarations.join('\n')} + @group(0) @binding(${inputStorageBuffersDeclarations.length}) var output : array<${dataType}>; + + ${activationFunction} + ${outputIndicesHelper.o2iImpl} + ${xIndicesHelper.i2oImpl} + ${wIndicesHelper.i2oImpl} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} + ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')} + let batch: u32 = outputIndices[0]; + let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}]; + let xRCCorner: vec2 = vec2(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${ + isChannelLast ? 2 : 3}]) * strides - pads; + let group_id: u32 = output_channel / ${outputChannelsPerGroup}u; + + var value: ${dataType} = ${dataType}(0); + for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) { + let input_channel = group_id * ${wShape[1]}u + wInChannel; + for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) { + let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u; + + if (xHeight < 0u || xHeight >= ${xShape[isChannelLast ? 1 : 2]}u) { + continue; + } + + for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) { + let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u; + if (xWidth < 0u || xWidth >= ${xShape[isChannelLast ? 2 : 3]}u) { + continue; + } + + ${ + xIndicesHelper.indicesVariableDeclaration( + 'xIndices', + isChannelLast ? ['batch', 'xHeight', 'xWidth', 'input_channel'] : + [ + 'batch', 'input_channel', 'xHeight', 'xWidth' + ])} + let xVal = x[${xIndicesHelper.i2oExpression('xIndices')}]; + ${ + wIndicesHelper.indicesVariableDeclaration('wIndices', [ + 'output_channel', 'wInChannel', 'wHeight', 'wWidth' + ])} + let wVal = w[${wIndicesHelper.i2oExpression('wIndices')}]; + value += xVal*wVal; + } + } + } + ${processBias} + ${applyActivation} + output[global_idx] = value; + }`; + return { + ...metadata, + outputs: [{ + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType, + gpuDataType: GpuDataType.default + }], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +/** + * naive grouped conv implementation, supports 1d/2d conv + * @param squeezeOutputShapeFunction - an optional function to squeeze the output shape, only used in conv1d + */ +export const createGroupedConvProgramInfoLoader = + (inputs: readonly TensorView[], attributes: ConvAttributes, + squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfoLoader => { + const metadata = createGroupedConvProgramMetadata(inputs.length > 2, attributes.cacheKey); + return { + ...metadata, + get: () => createGroupedConvProgramInfo(inputs, metadata, attributes, squeezeOutputShapeFunction) + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts new file mode 100644 index 0000000000000..f333d44ea499d --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -0,0 +1,252 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {PoolConvUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext} from '../types'; + +import {createGroupedConvProgramInfoLoader} from './conv-grouped'; +import {createConv2DMatMulProgramInfoLoader} from './conv2d-mm'; +import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; +import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'; + +export const calculateOutputShape = + (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], + adjustPads: readonly number[], strides: readonly number[], isChannelLast: boolean): number[] => { + const batchSize = inputShape[0]; + const inputSpatialShape = inputShape.slice(isChannelLast ? 1 : 2, isChannelLast ? 3 : 4); + const spatialRank = inputSpatialShape.length; + const outChannels = kernelShape[0]; + const kernelSpatialShape = kernelShape.slice(2); + const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1)); + const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]); + const outputShape = + inputSpatialShapeWithPad.map((v, i) => Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i])); + outputShape.splice(0, 0, batchSize); + outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels); + return outputShape; + }; + +export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey { + readonly autoPad: string; + readonly dilations: readonly number[]; + readonly format: 'NHWC'|'NCHW'; + readonly group: number; + readonly kernelShape: readonly number[]; + readonly pads: readonly number[]; + readonly strides: readonly number[]; + readonly wIsConst: boolean; +} + +// for transposing weight tensor from [M, C/group, KH, KW] to [KH, KW, C/group, M] +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [2, 3, 1, 0]}); + +const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttributes): void => { + // Refer to the below link for all input checks + // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv + if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) { + throw new Error('Conv requires 2 or 3 inputs'); + } + + // TODO : Need to add support for multi-dimensional conv + if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) { + throw new Error('currently only support conv 1D and 2D'); + } + + if (inputs[0].dims.length !== inputs[1].dims.length) { + throw new Error('filter does not have same dimension as input'); + } + + // FILTER_IN_CHANNEL should be equal to DATA_CHANNEL + const dataChannel = inputs[0].dims[attributes.format === 'NHWC' ? inputs[0].dims.length - 1 : 1]; + const filterInChannel = inputs[1].dims[1] * attributes.group; + if (dataChannel !== filterInChannel) { + throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL'); + } + + // if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps + if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[1].dims[0] !== inputs[2].dims[0])) { + throw new Error('invalid bias'); + } + + const spatialRank = inputs[0].dims.length - 2; + // wrong dilations dimension + if (attributes.dilations.length !== spatialRank) { + throw new Error(`dilations should be ${spatialRank}D`); + } + + // Wrong strides dimension + if (attributes.strides.length !== spatialRank) { + throw new Error(`strides should be ${spatialRank}D`); + } + + // Wrong pads dimension + if (attributes.pads.length !== spatialRank * 2) { + throw new Error(`pads should be ${spatialRank * 2}D`); + } + + // if kernelShape is specified, it's data length must be 2 less than dims length of the weights tensor + // (the first 2 dims are batch_size and channels) + if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) { + throw new Error('invalid kernel shape'); + } + + // TODO : Need to add support for float64 + if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { + throw new Error('Conv input(X,W) should be float tensor'); + } + + if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { + throw new Error('Conv input(bias) should be float tensor'); + } +}; + +const getAdjustedConvAttributes = (attributes: T, inputs: readonly TensorView[]): T => { + const kernelShape = attributes.kernelShape.slice(); + // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims + for (let i = 2; i < inputs[1].dims.length; ++i) { + if (kernelShape[i - 2] === 0) { + kernelShape[i - 2] = inputs[1].dims[i]; + } + } + const pads = attributes.pads.slice(); + PoolConvUtil.adjustPadsBasedOnAutoPad( + inputs[0].dims, attributes.strides, attributes.dilations, kernelShape, pads, attributes.format === 'NHWC', + attributes.autoPad); + + // always return a new object so does not modify the original attributes + const newAttributes: T = Object.assign({}, attributes); + Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey}); + return newAttributes; +}; + +export const parseConvAttributes = (attributes: Record): ConvAttributes => { + const activationAttributes = parseInternalActivationAttributes(attributes); + // TODO : Make this generic enough to compute default attributes for multi-dimensional conv + const format = attributes.format as 'NHWC' | 'NCHW'; + const autoPad = ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number]; + const dilations = attributes.dilations as [number, number]; + const group = attributes.group as number; + const kernelShape = attributes.kernel_shape as [number, number]; + const pads = attributes.pads as [number, number, number, number]; + const strides = attributes.strides as [number, number]; + const wIsConst = (attributes.w_is_const as () => boolean)(); + + return createAttributeWithCacheKey( + {autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes}); +}; + +const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): number => { + const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs); + + // check attributes + + const hasBias = inputs.length === 3; + // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ + const isChannelsLast = attributes.format === 'NHWC'; + + // const batchSize = context.inputs[0].dims[0]; + const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2]; + const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3]; + const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + const weightHeight = inputs[1].dims[2]; + const weightWidth = inputs[1].dims[3]; + + const outputShape = calculateOutputShape( + inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, + isChannelsLast); + const outHeight = outputShape[isChannelsLast ? 1 : 2]; + const outWidth = outputShape[isChannelsLast ? 2 : 3]; + const outChannels = outputShape[isChannelsLast ? 3 : 1]; + + const sameSize = + isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && attributes.autoPad === 'VALID'; + if (sameSize || + (weightHeight === 1 && weightWidth === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1 && + attributes.strides[0] === 1 && attributes.strides[1] === 1 && + (attributes.autoPad === 'SAME_UPPER' || attributes.autoPad === 'SAME_LOWER' || + attributes.autoPad === 'VALID'))) { + // TODO: implement conv2dByMatMul() + context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes)); + return 0; + } + + if (!isChannelsLast || attributes.group !== 1) { + context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes)); + return 0; + } + + // TODO: implement conv2dWithIm2Col() + + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; + + const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; + + // STEP.1: transpose weight + const transposedWeight = (context.customData.wT as TensorView | undefined) ?? + context.compute( + { + ...transposeProgramMetadata, + cacheHint: weightTransposeAttribute.cacheKey, + get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) + }, + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.customData.wT) { + context.customData.wT = transposedWeight; + } + + // STEP.2: prepare reshaped inputs + const convInputs = [inputs[0], transposedWeight]; + if (hasBias) { + if (!isChannelsLast && inputs[2].dims.length === 1) { + convInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); + } else { + convInputs.push(inputs[2]); + } + } + + // STEP.3: compute matmul + context.compute( + createConv2DMatMulProgramInfoLoader( + convInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads), + {inputs: convInputs}); + return 0; +}; + +const conv1d = (context: ComputeContext, attributes: ConvAttributes): number => { + // extend the input to 2D by adding H dimension + const isChannelLast = attributes.format === 'NHWC'; + const inputs = [ + context.inputs[0].reshape( + isChannelLast ? + // [N, W, C] -> [N, H=1, W, C] + [context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]] : + // [N, C, W] -> [N, C, H=1, W] + [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]]), + //[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW] + context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]) + ]; + if (context.inputs.length === 3) { + inputs.push(context.inputs[2]); + } + const pads = [0, attributes.pads[0], 0, attributes.pads[1]]; + const strides = [1].concat(attributes.strides); + const dilations = [1].concat(attributes.dilations); + const kernelShape = [1].concat(attributes.kernelShape); + const adjustedAttributes = getAdjustedConvAttributes({...attributes, pads, strides, dilations, kernelShape}, inputs); + context.compute(createGroupedConvProgramInfoLoader( + inputs, adjustedAttributes, + outputShape => isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [])); + return 0; +}; + +export const conv = (context: ComputeContext, attributes: ConvAttributes): number => { + validateInputs(context.inputs, attributes); // currently will fail if not conv1D/2D + return context.inputs[0].dims.length === 3 ? conv1d(context, attributes) : + conv2d(context, context.inputs, attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts new file mode 100644 index 0000000000000..0abece9559630 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv2d-mm.ts @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; +import {ConvAttributes} from './conv'; + + +const createConv2DMatMulProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ + name: 'Conv2DMatMul', + inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default], + cacheHint +}); + +export const createConv2DMatMulProgramInfoLoader = + (inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[], dimAOuter: number, + dimBOuter: number, dimInner: number, hasBias: boolean, sequentialAccessByThreads: boolean): ProgramInfoLoader => { + const metadata = createConv2DMatMulProgramMetadata(hasBias, attributes.cacheKey); + return { + ...metadata, + get: () => createConv2DMatMulProgramInfo( + inputs, metadata, attributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads) + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts new file mode 100644 index 0000000000000..92105859a8c0e --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {MAX_CLIP, MIN_CLIP} from '../../util'; + +export interface InternalActivationAttributes { + readonly activation: string; + readonly clipMin?: number; + readonly clipMax?: number; + readonly activationCacheKey: string; +} + +export const getActicationSnippet = + (attributes: InternalActivationAttributes): {activationFunction: string; applyActivation: string} => { + switch (attributes.activation) { + case 'Relu': + return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; + case 'Sigmoid': + return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; + case 'Clip': + return { + activationFunction: + `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, + applyActivation: 'value = clamp(value, clip_min_, clip_max_);' + }; + // TODO: adding other activations that can be fused. + default: + return {activationFunction: '', applyActivation: ''}; + } + }; + +export const parseInternalActivationAttributes = + (attributes: Record|undefined): InternalActivationAttributes => { + const activation = attributes?.activation as string || ''; + + if (activation === 'Clip') { + const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; + return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`}; + } + return {activation, activationCacheKey: activation}; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts new file mode 100644 index 0000000000000..16327a18503ff --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {GemmUtil, ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {ShaderHelper} from './common'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs) { + throw new Error('Input is missing'); + } + if (inputs.length < 2 || inputs.length > 3) { + throw new Error('Invaid input number.'); + } + + // 'C' can be of dimensionality 0, 1 or 2 only + if (inputs.length === 3 && inputs[2].dims.length > 2) { + throw new Error('Invalid input shape of C'); + } + + if ((inputs[0].dataType !== DataType.float) || (inputs[1].dataType !== DataType.float) || + (inputs.length === 3 && inputs[2].dataType !== DataType.float)) { + throw new Error('Invalid input type.'); + } + + if ((inputs[0].dataType !== inputs[1].dataType) || + (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType)) { + throw new Error('Input types are mismatched'); + } +}; + +export interface GemmAttributes extends AttributeWithCacheKey { + transA: boolean; + transB: boolean; + alpha: number; + beta: number; +} + +const offsetC = (m: number, n: number, dims: readonly number[]): string => { + if (dims.length === 0) { + return '0u'; + } + + const broadcastM = (dims.length === 1 && m !== 1) || (dims.length === 2 && dims[0] !== m); + const broadcastN = dims[dims.length - 1] !== n; + + let offset = '0u'; + if (!broadcastM) { + offset += `+ m * ${dims[dims.length - 1]}u`; + } + if (!broadcastN) { + offset += '+n'; + } + + return offset; +}; + +const createGemmProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GemmAttributes): ProgramInfo => { + const aShape = inputs[0].dims.slice(); + const bShape = inputs[1].dims.slice(); + const [M, N, K] = GemmUtil.getShapeOfGemmResult( + aShape, attributes.transA, bShape, attributes.transB, inputs.length === 3 ? inputs[2].dims : undefined); + const outputShape = [M, N]; + if (!outputShape) { + throw new Error('Can\'t use gemm on the given tensors'); + } + const outputSize = ShapeUtil.size(outputShape); + let line = ''; + if (attributes.transA && attributes.transB) { + line = 'value += a[k * M + m] * b[n * K + k];'; + } else if (attributes.transA && !attributes.transB) { + line = 'value += a[k * M + m] * b[k * N + n];'; + } else if (!attributes.transA && attributes.transB) { + line = 'value += a[m * K + k] * b[n * K + k];'; + } else if (!attributes.transA && !attributes.transB) { + line = 'value += a[m * K + k] * b[k * N + n];'; + } + + const dataType = 'f32'; // TODO: support other data type + const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;'; + const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : ''; + const inputStorageBuffersDeclarations = [ + `@group(0) @binding(0) var a : array<${dataType}>;`, + `@group(0) @binding(1) var b : array<${dataType}>;` + ]; + if (inputs.length === 3) { + inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var c : array<${dataType}>;`); + } + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${M}u; + const N: u32 = ${N}u; + const K: u32 = ${K}u; + const alpha = ${dataType}(${attributes.alpha}); + const beta = ${dataType}(${attributes.beta}); + + ${inputStorageBuffersDeclarations.join('\n')} + @group(0) @binding(${inputs.length}) var output : array<${dataType}>; + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + let m = global_id.x / N; + let n = global_id.x % N; + + var value = ${dataType}(0); + for (var k: u32 = 0u; k<${K}u; k++) { + ${line} + } + + ${calculateAlpha} + ${calculateC} + output[global_id.x] = value; + + }`; + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +const createGemmProgramInfoLoader = (inputs: readonly TensorView[], attributes: GemmAttributes): ProgramInfoLoader => { + const metadata = { + name: 'Gemm', + inputTypes: inputs.length === 3 ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default], + cacheHint: attributes.cacheKey + }; + + return {...metadata, get: () => createGemmProgramInfo(metadata, inputs, attributes)}; +}; + +export const gemm = (context: ComputeContext, attributes: GemmAttributes): number => { + validateInputs(context.inputs); + context.compute(createGemmProgramInfoLoader(context.inputs, attributes)); + return 0; +}; + +export const parseGemmAttributes = (attributes: Record): GemmAttributes => + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts new file mode 100644 index 0000000000000..e78ecfa53d805 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {BroadcastUtil, ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {ShaderHelper} from './common'; +import {getActicationSnippet, InternalActivationAttributes} from './fuse-utils'; + + +const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ + name: 'MatMul', + inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default], + cacheHint +}); + +const createMatmulProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes): + ProgramInfo => { + const aShape = inputs[0].dims; + const bShape = inputs[1].dims; + const outputShape = BroadcastUtil.calcShape(aShape, bShape, true); + if (!outputShape) { + throw new Error('Can\'t use matmul on the given tensors'); + } + const outputSize = ShapeUtil.size(outputShape); + // TODO: support broadcasting + + const dataType = 'f32'; // TODO: support other data type + const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); + + const M = outputShape[outputShape.length - 2]; + const K = aShape[aShape.length - 1]; + const N = outputShape[outputShape.length - 1]; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${M}u; + const N: u32 = ${N}u; + const K: u32 = ${K}u; + + @group(0) @binding(0) var a : array<${dataType}>; + @group(0) @binding(1) var b : array<${dataType}>; + @group(0) @binding(2) var output : array<${dataType}>; + + ${activationFunction} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + let stack = global_idx / (M * N); + let mn = global_idx % (M * N); + let n = global_idx % N; + let m = mn / N; + + let offsetA = stack * (M * K); + let offsetB = stack * (K * N); + + var value = ${dataType}(0); + for (var k: u32 = 0u; k<${K}u; k++) { + value += a[offsetA + m * K + k] * b[offsetB + k * N + n]; + } + ${applyActivation} + output[global_idx] = value; + }`; + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +export const createMatmulProgramInfoLoader = + (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes): ProgramInfoLoader => { + const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey); + return {...metadata, get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes)}; + }; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 2) { + throw new Error('MatMul requires 2 inputs.'); + } + + if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) { + throw new Error('shared dimension does not match.'); + } + + if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { + throw new Error('inputs should be float type'); + } +}; + +export const matMul = (context: ComputeContext): number => { + validateInputs(context.inputs); + + context.compute(createMatmulProgramInfoLoader(context.inputs, {activation: '', activationCacheKey: ''})); + return 0; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts new file mode 100644 index 0000000000000..5c905ce1ce705 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -0,0 +1,379 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {PoolConvUtil, ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {createIndicesHelper, ShaderHelper} from './common'; + +// TODO: support: +// - ceil_mode "test_maxpool_2d_ceil" +// - storage_order "test_maxpool_with_argmax_2d_precomputed_strides" +// - [MaxPool] dilations "test_maxpool_2d_dilations" +// - [MaxPool] output[1] "test_maxpool_with_argmax_2d_precomputed_pads" + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 1) { + throw new Error('Pool ops requires 1 input.'); + } + if (inputs[0].dims.length !== 4) { + throw new Error('Pool ops supports 2-D inputs only for now.'); + } + if (inputs[0].dataType !== DataType.float) { + throw new Error('Invalid input type.'); + } +}; + +const getAdjustedPoolAttributesAndOutputShape = ( + inputs: readonly TensorView[], attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => { + const isChannelsLast = attributes.format === 'NHWC'; + const inputShapeAsChannelFirst = isChannelsLast ? + [inputs[0].dims[0], inputs[0].dims[3], inputs[0].dims[1], inputs[0].dims[2]] : + inputs[0].dims.slice(); + const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations'); + const kernelShape = attributes.kernelShape.slice(); + const strides = attributes.strides.slice(); + const dilations: number[] = hasDilations ? (attributes as MaxPoolAttributes).dilations.slice() : []; + const pads = attributes.pads.slice(); + PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShapeAsChannelFirst, kernelShape, strides, dilations, pads); + + const outputShapeAsChannelFirst = PoolConvUtil.computePoolOutputShape( + isGlobalOperator, inputShapeAsChannelFirst, strides, dilations, kernelShape, pads, attributes.autoPad); + + const newAttributes = Object.assign({}, attributes); + if (hasDilations) { + Object.assign(newAttributes, {kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey}); + } else { + Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey}); + } + return [ + newAttributes, + isChannelsLast ? + [ + outputShapeAsChannelFirst[0], outputShapeAsChannelFirst[2], outputShapeAsChannelFirst[3], + outputShapeAsChannelFirst[1] + ] : + outputShapeAsChannelFirst + ]; +}; + +const generatePoolingCode = ( + shaderHelper: ShaderHelper, inputDims: readonly number[], outputShape: readonly number[], attributes: AttributeType, + op1: string, op2: string, dataType: string, start: string): string => { + const isChannelsLast = attributes.format === 'NHWC'; + const rank = inputDims.length; + const outputSize = ShapeUtil.size(outputShape); + const outputIndicesHelper = createIndicesHelper('output', outputShape); + const xIndicesHelper = createIndicesHelper('x', inputDims); + + if (attributes.kernelShape.length <= 2) { + const kw = attributes.kernelShape[attributes.kernelShape.length - 1]; + const sw = attributes.strides[attributes.strides.length - 1]; + const pwStart = attributes.pads[attributes.pads.length / 2 - 1]; + const pwEnd = attributes.pads[attributes.pads.length - 1]; + const dimIdxW = rank - (isChannelsLast ? 2 : 1); + let codeW = ''; + let codeH = ''; + let codeHEnd = ''; + if (pwStart + pwEnd !== 0) { + codeW = ` + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) { + pad++; + continue; + } + let x_val = x[${xIndicesHelper.i2oExpression('xIndices')}]; + ${op1} + }`; + } else { + codeW = ` + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + let x_val = x[${xIndicesHelper.i2oExpression('xIndices')}]; + ${op1} + }`; + } + + if (attributes.kernelShape.length === 2) { + const kh = attributes.kernelShape[attributes.kernelShape.length - 2]; + const sh = attributes.strides[attributes.strides.length - 2]; + const phStart = attributes.pads[attributes.pads.length / 2 - 2]; + const phEnd = attributes.pads[attributes.pads.length - 2]; + const dimIdxH = rank - (isChannelsLast ? 3 : 2); + const dimH = inputDims[dimIdxH]; + if (phStart + phEnd !== 0) { + codeH = ` + for (var j: u32 = 0u; j < ${kh}u; j++) { + xIndices[${dimIdxH}] = indices[${dimIdxH}] * ${sh} - ${phStart} + j; + if (xIndices[${dimIdxH}] < 0 || xIndices[${dimIdxH}] >= ${dimH}) { + pad+= ${kw}; + continue; + } + `; + } else { + codeH = ` + for (var j: u32 = 0u; j < ${kh}u; j++) { + xIndices[${dimIdxH}] = indices[${dimIdxH}] * ${sh} - ${phStart} + j; + `; + } + codeHEnd = ` + } + `; + } + + const poolingCode = ` + @group(0) @binding(0) var x : array<${dataType}>; + @group(0) @binding(1) var output : array<${dataType}>; + + ${outputIndicesHelper.o2iImpl} + ${xIndicesHelper.i2oImpl} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + ${outputIndicesHelper.indicesVariableDeclaration('indices')} + ${outputIndicesHelper.o2iCall('global_idx', 'indices')} + ${outputIndicesHelper.indicesVariableDeclaration('xIndices')} + ${outputIndicesHelper.o2iCall('global_idx', 'xIndices')} + + var value: ${dataType} = ${dataType}(${start}); + var pad = 0; + ${codeH} + ${codeW} + ${codeHEnd} + ${op2} + + output[global_idx] = value; + }`; + return poolingCode; + } else { + if (isChannelsLast) { + throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.'); + } + const kernelSize = ShapeUtil.size(attributes.kernelShape); + const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape); + const stridesRank = kernelStrides.length; + const padsRank = attributes.pads.length; + const hasPads = attributes.pads.reduce((sum, cur) => sum + cur); + let padCode = ''; + if (hasPads) { + padCode = ` + if (xIndices[j] >= inputDims[j]) { + pad++; + isPad = true; + break; + } + } + if (!isPad) { + let x_val = x[${xIndicesHelper.i2oExpression('xIndices')}]; + ${op1} + }`; + } else { + padCode = ` + } + let x_val = x[${xIndicesHelper.i2oExpression('xIndices')}]; + ${op1} + `; + } + const poolingCode = ` + @group(0) @binding(0) var x : array<${dataType}>; + @group(0) @binding(1) var output : array<${dataType}>; + + ${outputIndicesHelper.o2iImpl} + ${xIndicesHelper.i2oImpl} + + const pads = array(${attributes.pads.map(i => `${i}u`).join(',')}); + const inputDims = array(${inputDims.map(i => `${i}u`).join(',')}); + const kernelStrides = array(${kernelStrides.map(i => `${i}u`).join(',')}); + const strides = array(${attributes.strides.map(i => `${i}u`).join(',')}); + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + ${outputIndicesHelper.indicesVariableDeclaration('indices')} + ${outputIndicesHelper.o2iCall('global_idx', 'indices')} + ${outputIndicesHelper.indicesVariableDeclaration('xIndices')} + ${outputIndicesHelper.o2iCall('global_idx', 'xIndices')} + + var offsets: array; + + var value = ${dataType}(${start}); + var pad = 0; + var isPad = false; + + for (var i: u32 = 0u; i < ${kernelSize}u; i++) { + var offset = i; + for (var j = 0u; j < ${stridesRank - 1}u; j++) { + offsets[j] = offset / kernelStrides[j]; + offset -= offsets[j] * kernelStrides[j]; + } + offsets[${stridesRank - 1}] = offset; + + isPad = false; + for (var j = ${rank - stridesRank}u; j < ${rank}u; j++) { + xIndices[j] = indices[j] * strides[j - ${rank - stridesRank}u] + + offsets[j - ${rank - stridesRank}u] - pads[j - 2u]; + ${padCode} + } + ${op2} + + output[global_idx] = value; + }`; + return poolingCode; + } +}; + +export interface FormatAttributes { + readonly format: 'NHWC'|'NCHW'; +} + +export interface PoolCommonAttributes extends FormatAttributes { + readonly autoPad: string; + readonly ceilMode: number; + readonly kernelShape: readonly number[]; + readonly strides: readonly number[]; + readonly pads: readonly number[]; +} + +const parsePoolCommonAttributes = (attributes: Record): PoolCommonAttributes => ({ + format: attributes.format as FormatAttributes['format'], + autoPad: ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number], + ceilMode: attributes.ceil_mode as number, + kernelShape: attributes.kernel_shape as [number, number], + strides: attributes.strides as [number, number], + pads: attributes.pads as [number, number, number, number] +}); + +export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey { + readonly countIncludePad: boolean; +} + +const createAveragePoolProgramInfo = + (inputs: readonly TensorView[], metadata: ProgramMetadata, isGlobalOperator: boolean, + attributes: AveragePoolAttributes): ProgramInfo => { + const [adjustedAttributes, outputShape] = + getAdjustedPoolAttributesAndOutputShape(inputs, attributes, isGlobalOperator); + const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape); + + const dataType = 'f32'; + + const op1 = 'value += x_val;'; + let op2 = ''; + if (adjustedAttributes.countIncludePad) { + op2 += `value /= ${dataType}(${kernelSize});`; + } else { + op2 += `value /= ${dataType}(${kernelSize} - pad);`; + } + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource: shaderHelper => generatePoolingCode( + shaderHelper, inputs[0].dims, outputShape, adjustedAttributes, op1, op2, dataType, '0.0'), + dispatchGroup: () => ({x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}) + }; + }; + +export const parseAveragePoolAttributes = (attributes: Record): AveragePoolAttributes => { + const countIncludePad = (attributes.count_include_pad as number) === 0 ? false : true; + + const attr = parsePoolCommonAttributes(attributes); + // TODO: support attribute 'ceil_mode' + if (attr.ceilMode !== 0) { + throw new Error('using ceil() in shape computation is not yet supported for AveragePool'); + } + + return createAttributeWithCacheKey({countIncludePad, ...attr}); +}; + +export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): number => { + validateInputs(context.inputs); + const metadata = {name: 'AveragePool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey}; + context.compute({...metadata, get: () => createAveragePoolProgramInfo(context.inputs, metadata, false, attributes)}); + return 0; +}; + +const globalPoolAttributes = { + autoPad: '', + ceilMode: 0, + countIncludePad: false, + kernelShape: [], + strides: [], + pads: [], + storageOrder: 0, + dilations: [], + cacheKey: '' +}; + +export const parseGlobalAveragePoolAttributes = (attributes: Record): AveragePoolAttributes => { + const format = attributes.format as FormatAttributes['format']; + return {format, ...globalPoolAttributes, cacheKey: format}; +}; + +export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): number => { + validateInputs(context.inputs); + const metadata = {name: 'GlobalAveragePool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey}; + context.compute({...metadata, get: () => createAveragePoolProgramInfo(context.inputs, metadata, true, attributes)}); + return 0; +}; + +export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey { + readonly storageOrder: number; + readonly dilations: number[]; +} + +const createMaxPoolProgramInfo = + (inputs: readonly TensorView[], metadata: ProgramMetadata, isGlobalOperator: boolean, + attributes: MaxPoolAttributes): ProgramInfo => { + const [adjustedAttributes, outputShape] = + getAdjustedPoolAttributesAndOutputShape(inputs, attributes, isGlobalOperator); + const op1 = ` + value = max(x_val, value); + `; + const op2 = ''; + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource: shaderHelper => + generatePoolingCode(shaderHelper, inputs[0].dims, outputShape, adjustedAttributes, op1, op2, 'f32', '-1e5'), + dispatchGroup: () => ({x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}) + }; + }; + +export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): number => { + validateInputs(context.inputs); + const metadata = {name: 'MaxPool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey}; + context.compute({...metadata, get: () => createMaxPoolProgramInfo(context.inputs, metadata, false, attributes)}); + return 0; +}; + +export const parseMaxPoolAttributes = (attributes: Record): MaxPoolAttributes => { + const storageOrder = attributes.storage_order as number; + const dilations = attributes.dilations as [number, number]; + + const attr = parsePoolCommonAttributes(attributes); + // TODO: support attribute 'ceil_mode' and 'storage_order' + if (storageOrder !== 0) { + throw new Error('column major storage order is not yet supported for MaxPool'); + } + if (attr.ceilMode !== 0) { + throw new Error('using ceil() in shape computation is not yet supported for MaxPool'); + } + + return createAttributeWithCacheKey({storageOrder, dilations, ...attr}); +}; + +export const parseGlobalMaxPoolAttributes = (attributes: Record): MaxPoolAttributes => { + const format = attributes.format as FormatAttributes['format']; + return {format, ...globalPoolAttributes, cacheKey: format}; +}; + +export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): number => { + validateInputs(context.inputs); + const metadata = {name: 'GlobalMaxPool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey}; + context.compute({...metadata, get: () => createMaxPoolProgramInfo(context.inputs, metadata, true, attributes)}); + return 0; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts new file mode 100644 index 0000000000000..24a2d7bf8c0e3 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo} from '../types'; + +import {createIndicesHelper, ShaderHelper} from './common'; + +export interface TransposeAttributes extends AttributeWithCacheKey { + readonly perm: number[]; +} + +export const transposeProgramMetadata = { + name: 'Transpose', + inputTypes: [GpuDataType.default] +}; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 1) { + throw new Error('Transpose requires 1 input.'); + } + + if (inputs[0].dataType !== DataType.float) { + throw new Error('input should be float tensor'); + } +}; + +const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] => + (perm && perm.length !== inputShape.length) ? [...(inputShape.keys())].reverse() : perm; + +const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] => + ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape, perm)); + +const permFunctionBody = (perm: number[], rank: number): string => { + const reverseFunc = []; + reverseFunc.push(`fn perm(a: ptr>, i: ptr>) {`); + for (let i = 0; i < rank; ++i) { + reverseFunc.push(`\t(*a)[${perm[i]}]=(*i)[${i}];`); + } + reverseFunc.push('\t}'); + return reverseFunc.join('\n'); +}; + +export const createTransposeProgramInfo = (input: TensorView, permAttr: number[]): ProgramInfo => { + const dataType = 'f32'; // TODO: support other data type + const inputShape = input.dims; + const perm = getAdjustedPerm(inputShape, permAttr); + const outputShape = getOutputShape(inputShape, perm); + const rank = inputShape.length; + const outputSize = ShapeUtil.size(outputShape); + // A dims=[${inputs[0].dims.toString()}] + // out Dims=[${unpackedOutputShape.toString()}] + // based on perm=[${perm.toString()}] + + const outputIndicesHelper = createIndicesHelper('output', outputShape); + const inputIndicesHelper = createIndicesHelper('a', inputShape); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + @group(0) @binding(0) var a : array<${dataType}>; + @group(0) @binding(1) var output : array<${dataType}>; + + ${permFunctionBody(perm, rank)} + ${outputIndicesHelper.o2iImpl} + ${inputIndicesHelper.i2oImpl} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + ${outputIndicesHelper.indicesVariableDeclaration('indices')} + ${outputIndicesHelper.o2iCall('global_idx', 'indices')} + ${inputIndicesHelper.indicesVariableDeclaration('aIndices')} + perm(&aIndices, &indices); + + output[global_idx] = a[${inputIndicesHelper.i2oExpression('aIndices')}]; + }`; + return { + ...transposeProgramMetadata, + outputs: [{dims: outputShape, dataType: input.dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; +}; + +export const transpose = (context: ComputeContext, attributes: TransposeAttributes): number => { + validateInputs(context.inputs); + context.compute({ + ...transposeProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createTransposeProgramInfo(context.inputs[0], attributes.perm) + }); + return 0; +}; + +export const parseTransposeAttributes = (attributes: Record): TransposeAttributes => + createAttributeWithCacheKey({perm: attributes.perm as number[]}); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts new file mode 100644 index 0000000000000..93643914609ee --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {ShaderHelper} from './common'; + +type BuiltinFunctionName = string; +type ElementwiseCustomExpression = (expression: string) => string; +type ElementwiseFunctionCall = BuiltinFunctionName|ElementwiseCustomExpression; + +const createElementwiseProgramShader = + (shaderHelper: ShaderHelper, datasize: number, funcCall: ElementwiseFunctionCall, + additionalImplementation?: string): string => { + const vecSize = Math.ceil(datasize / 4); + + let expression = ''; + if (typeof funcCall === 'string') { + expression = `${funcCall}(a)`; + } else { + expression = funcCall('a'); + } + return ` + @group(0) @binding(0) var inputData : array>; + @group(0) @binding(1) var outputData : array>; + + ${additionalImplementation ?? ''} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + + let a = inputData[global_idx]; + outputData[global_idx] = ${expression}; + }`; + }; + +const createElementwiseProgramInfo = + (metadata: ProgramMetadata, input: TensorView, funcCall: ElementwiseFunctionCall, + additionalImplementation?: string): ProgramInfo => ({ + ...metadata, + getShaderSource: shaderHelper => + createElementwiseProgramShader(shaderHelper, ShapeUtil.size(input.dims), funcCall, additionalImplementation), + outputs: [{dims: input.dims, dataType: input.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: (inputTensors) => + ({x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}) + }); + +const createElementwiseProgramInfoLoader = + (input: TensorView, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string, + cacheKey?: string): ProgramInfoLoader => { + const metadata: ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint: cacheKey}; + return { + ...metadata, + get: () => createElementwiseProgramInfo(metadata, input, funcCall, additionalImplementation) + }; + }; + +export const abs = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Abs', 'abs')); + return 0; +}; + +export const acos = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Acos', 'acos')); + return 0; +}; + +export const acosh = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Acosh', 'acosh')); + return 0; +}; + +export const asin = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Asin', 'asin')); + return 0; +}; + +export const asinh = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Asinh', 'asinh')); + return 0; +}; + +export const atan = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Atan', 'atan')); + return 0; +}; +export const atanh = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Atanh', 'atanh')); + return 0; +}; + +export interface ClipAttributes extends AttributeWithCacheKey { + readonly min: number; + readonly max: number; +} + +export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): number => { + context.compute( + createElementwiseProgramInfoLoader( + context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, ` + const clip_min_: vec4 = vec4(f32(${attributes.min})); + const clip_max_: vec4 = vec4(f32(${attributes.max})); +`, + attributes.cacheKey), + {inputs: [0]}); + return 0; +}; +const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { + const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; + const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; + return createAttributeWithCacheKey({min, max}); +}; + +export const clip = (context: ComputeContext): number => { + const attributes = generateClipAttributesFromInputs(context.inputs); + return clipV10(context, attributes); +}; + +export const ceil = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Ceil', 'ceil')); + return 0; +}; + +export const cos = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cos', 'cos')); + return 0; +}; + +export const cosh = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cosh', 'cosh')); + return 0; +}; + +export interface EluAttributes extends AttributeWithCacheKey { + readonly alpha: number; +} + +export const elu = (context: ComputeContext, attributes: EluAttributes): number => { + context.compute(createElementwiseProgramInfoLoader( + context.inputs[0], 'Elu', a => `elu_vf32(${a})`, ` + const elu_alpha_: f32 = f32(${attributes.alpha}); + + fn elu_f32(a: f32) -> f32 { + return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0); + } + + fn elu_vf32(v: vec4) -> vec4 { + return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w)); + }`, + attributes.cacheKey)); + return 0; +}; + +export const parseEluAttributes = (attributes: Record): EluAttributes => + createAttributeWithCacheKey(attributes as {alpha: number}); + +export const erf = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, ` + const r0: f32 = 0.3275911; + const r1: f32 = 0.254829592; + const r2: f32 = -0.284496736; + const r3: f32 = 1.421413741; + const r4: f32 = -1.453152027; + const r5: f32 = 1.061405429; + + fn erf_vf32(v: vec4) -> vec4 { + let absv = abs(v); + let x = 1.0 / (1.0 + r0 * absv); + return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); + }`)); + return 0; +}; + +export const floor = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Floor', 'floor')); + return 0; +}; + +export const neg = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Neg', a => `-${a}`)); + return 0; +}; + +export const reciprocal = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Reciprocal', a => `1.0/${a}`)); + return 0; +}; + +export const sigmoid = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); + return 0; +}; + +export const sin = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sin', 'sin')); + return 0; +}; + +export const sinh = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sinh', 'sinh')); + return 0; +}; + +export const sqrt = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sqrt', 'sqrt')); + return 0; +}; + +export const tan = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tan', 'tan')); + return 0; +}; + +export const tanh = (context: ComputeContext): number => { + context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tanh', 'tanh')); + return 0; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts new file mode 100644 index 0000000000000..951e76de5449e --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {WebGpuBackend} from '../backend-webgpu'; +import {LOG_DEBUG} from '../log'; + +import {createShaderHelper} from './ops/common'; +import {Artifact, GpuData, ProgramInfo} from './types'; + +/** + * ProgramManager is the main class behind running computations + * It builds ProgramInfo's into Artifacts + * It compiles given ProgramInfo's into WebGL Prorams (cached as Artifacts) + * Uses the artifact to run the computation by calling Draw on + * the WebGL drawing buffer + * ProgramManager automatically maps (binds) input variables to their + * corresponding Location's in the binary program + */ +export class ProgramManager { + repo: Map; // this should be per-session object + attributesBound: boolean; + + constructor(private backend: WebGpuBackend) { + this.repo = new Map(); + this.attributesBound = false; + } + getArtifact(key: unknown): Artifact|undefined { + return this.repo.get(key); + } + setArtifact(key: unknown, artifact: Artifact): void { + this.repo.set(key, artifact); + } + run(buildArtifact: Artifact, inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number]): void { + const device = this.backend.device; + const computePassEncoder = this.backend.getComputePassEncoder(); + + if (this.backend.profilingEnabled) { + // profiling write start timestamp + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 0); + } + + computePassEncoder.setPipeline(buildArtifact.computePipeline); + const entries = []; + for (const input of inputs) { + entries.push({binding: entries.length, resource: {buffer: input.buffer}}); + } + for (const output of outputs) { + entries.push({binding: entries.length, resource: {buffer: output.buffer}}); + } + const bindGroup = device.createBindGroup({layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries}); + computePassEncoder.setBindGroup(0, bindGroup); + + computePassEncoder.dispatchWorkgroups(...dispatchGroup); + + this.backend.pendingDispatchNumber++; + + if (this.backend.profilingEnabled) { + // profiling write end timestamp + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 1); + // eslint-disable-next-line no-bitwise + const queryData = this.backend.gpuDataManager.create(16, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE); + // eslint-disable-next-line no-bitwise + const syncData = this.backend.gpuDataManager.create(16, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); + + this.backend.endComputePass(); + this.backend.getCommandEncoder().resolveQuerySet(this.backend.profilingQuerySet, 0, 2, queryData.buffer, 0); + this.backend.getCommandEncoder().copyBufferToBuffer(queryData.buffer, 0, syncData.buffer, 0, 16); + this.backend.flush(); + + const kernelId = this.backend.currentKernelId!; + const kernelName = this.backend.kernels.get(kernelId)![0]; + + syncData.buffer.mapAsync(GPUMapMode.READ).then(() => { + const mappedData = new BigUint64Array(syncData.buffer.getMappedRange()); + const startTimeU64 = mappedData[0]; + const endTimeU64 = mappedData[1]; + + syncData.buffer.unmap(); + + if (typeof this.backend.profilingTimeBase === 'undefined') { + this.backend.profilingTimeBase = startTimeU64; + } + + const startTime = Number(startTimeU64 - this.backend.profilingTimeBase); + const endTime = Number(endTimeU64 - this.backend.profilingTimeBase); + + if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) { + throw new RangeError('incorrect timestamp range'); + } + + this.backend.gpuDataManager.release(queryData.id); + this.backend.gpuDataManager.release(syncData.id); + + // eslint-disable-next-line no-console + console.log(`[profiling] kernel "${kernelId}|${kernelName}" execution time: ${endTime - startTime} ns`); + }); + } + + if (this.backend.pendingDispatchNumber >= 16) { + this.backend.flush(); + } + } + dispose(): void { + // this.repo.forEach(a => this.glContext.deleteProgram(a.program)); + } + build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact { + const device = this.backend.device; + + const code = programInfo.getShaderSource(createShaderHelper(normalizedDispatchGroupSize)); + const shaderModule = device.createShaderModule({code}); + LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`); + + const computePipeline = + device.createComputePipeline({compute: {module: shaderModule, entryPoint: 'main'}, layout: 'auto'}); + + return {programInfo, computePipeline}; + } + + normalizeDispatchGroupSize(dispatchGroup: ReturnType): [number, number, number] { + const x = typeof dispatchGroup === 'number' ? dispatchGroup : dispatchGroup.x; + const y = typeof dispatchGroup === 'number' ? 1 : (dispatchGroup.y || 1); + const z = typeof dispatchGroup === 'number' ? 1 : (dispatchGroup.z || 1); + const limitPerDimension = this.backend.device.limits.maxComputeWorkgroupsPerDimension; + if (x <= limitPerDimension && y <= limitPerDimension && z <= limitPerDimension) { + return [x, y, z]; + } + const size = x * y * z; + let dispatchAverage = Math.ceil(Math.sqrt(size)); + if (dispatchAverage > limitPerDimension) { + dispatchAverage = Math.ceil(Math.cbrt(size)); + if (dispatchAverage > limitPerDimension) { + throw new Error('Total dispatch size exceeds WebGPU maximum.'); + } + return [dispatchAverage, dispatchAverage, dispatchAverage]; + } else { + return [dispatchAverage, dispatchAverage, 1]; + } + } +} diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts new file mode 100644 index 0000000000000..634e3a167184f --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {Tensor, TensorView} from '../tensor'; + +import {ShaderHelper} from './ops/common'; + +export enum GpuDataType { + default = 0, + upload = 1, + profile = 2 +} +export type GpuDataId = number; + +export interface GpuData { + type: GpuDataType; + id: GpuDataId; + buffer: GPUBuffer; +} + +export interface TensorInfo { + id?: Tensor.Id; + dims: readonly number[]; + dataType: number; + gpuDataType: GpuDataType; +} + + +export interface ProgramVariable { + type: 'float'|'int'; + name: string; + arrayLength?: number; + data: number|number[]; +} + + +export interface ProgramMetadata { + /** + * the name of the program. used for debugging and profiling + */ + name: string; + + /** + * gpu data types for each input + */ + inputTypes: GpuDataType[]; + /** + * an optional string as a cache hint in the artifact cache + */ + cacheHint?: string; +} + +/** + * A ProgramInfoLoader allows + */ +export interface ProgramInfoLoader extends ProgramMetadata { + /** + * a function to get the program info + */ + get(): ProgramInfo; +} + +/** + * A set of data that represent a shader program + */ +export interface ProgramInfo extends ProgramMetadata { + /** + * information of uniform variables + */ + variables?: ProgramVariable[]; + /** + * tensor info for outputs + */ + outputs: TensorInfo[]; + /** + * the shader's processing source code + */ + getShaderSource: (shaderHelper: ShaderHelper) => string; + /** + * default is "main" + */ + // entryPoint: string; + + dispatchGroup: (inputs: readonly TensorView[]) => { + x: number; + y?: number; + z?: number; + }; +} + +export interface Artifact { + programInfo: ProgramInfo; + computePipeline: GPUComputePipeline; +} + +export interface ComputeContextInputsOutputsMapping { + /** + * specify the mapping to the program's inputs. the value can be a number or a tensor view. + * - if it's a number, it's the index of the kernel's input + * - if it's a tensor view, it's an existing tensor view that will be used as the input + * + * if inputs is not specified, the mapping will be the kernel's inputs in order. + */ + readonly inputs?: ReadonlyArray; + /** + * specify the mapping to the program's outputs. the value must be a number. + * - if it's a non-negative number, it's the index of the kernel's output + * - if it's -1, it's an output that will be created as a temporary value. this value will be released after + * the kernel is executed. + * - if it's -2, it's an output that will be created as a persistent value. this value will be released when the + * kernel is released. + * + * if outputs is not specified, the mapping will be the kernel's outputs in order. + */ + readonly outputs?: readonly number[]; +} + +/** + * A ComputeContext instance carries the states that representing the current running of a kernel. + */ +export interface ComputeContext { + /** + * stores the pointer to OpKernelContext + */ + readonly opKernelContext: number; + + /** + * a list of inputs, each input is an instance of TensorView + */ + readonly inputs: readonly TensorView[]; + + /** + * a custom data object that can be used to store any data that is needed by the kernel + */ + readonly customData: {[key: string]: unknown}; + + compute(program: ProgramInfoLoader|ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): + TensorView[]; + output(index: number, dims: readonly number[]): number; +} diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index b72bfe42c6812..9a247c56189a8 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -63,8 +63,14 @@ self.onmessage = (ev: MessageEvent): void => { case 'run': try { const {sessionId, inputIndices, inputs, outputIndices, options} = ev.data.in!; - const outputs = run(sessionId, inputIndices, inputs, outputIndices, options); - postMessage({type: 'run', out: outputs} as OrtWasmMessage, extractTransferableBuffers(outputs)); + run(sessionId, inputIndices, inputs, outputIndices, options) + .then( + outputs => { + postMessage({type: 'run', out: outputs} as OrtWasmMessage, extractTransferableBuffers(outputs)); + }, + err => { + postMessage({type: 'run', err} as OrtWasmMessage); + }); } catch (err) { postMessage({type: 'run', err} as OrtWasmMessage); } diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 1e04fadd908b8..02c61800d9e66 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -3,9 +3,10 @@ import {env, InferenceSession} from 'onnxruntime-common'; +import {init as initJsep} from './jsep/init'; import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; import * as core from './wasm-core-impl'; -import {initializeWebAssembly} from './wasm-factory'; +import {getInstance, initializeWebAssembly} from './wasm-factory'; const isProxy = (): boolean => !!env.wasm.proxy && typeof document !== 'undefined'; let proxyWorker: Worker|undefined; @@ -141,9 +142,14 @@ export const initOrt = async(numThreads: number, loggingLevel: number): Promise< initOrtCallbacks = [resolve, reject]; const message: OrtWasmMessage = {type: 'init-ort', in : {numThreads, loggingLevel}}; proxyWorker!.postMessage(message); + + // TODO: support JSEP in worker }); } else { core.initOrt(numThreads, loggingLevel); + + // init JSEP if available + await initJsep(getInstance()); } }; diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index a507b09e89315..038a46d82e61a 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -7,28 +7,10 @@ import {promisify} from 'util'; import {SerializableModeldata} from './proxy-messages'; import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initOrt, releaseSession, run} from './proxy-wrapper'; +import {logLevelStringToEnum} from './wasm-common'; let ortInit: boolean; - -const getLogLevel = (logLevel: 'verbose'|'info'|'warning'|'error'|'fatal'): number => { - switch (logLevel) { - case 'verbose': - return 0; - case 'info': - return 1; - case 'warning': - return 2; - case 'error': - return 3; - case 'fatal': - return 4; - default: - throw new Error(`unsupported logging level: ${logLevel}`); - } -}; - - export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { private sessionId: number; @@ -45,7 +27,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { if (!ortInit) { - await initOrt(env.wasm.numThreads!, getLogLevel(env.logLevel!)); + await initOrt(env.wasm.numThreads!, logLevelStringToEnum(env.logLevel!)); ortInit = true; } diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 4d32ced2391db..10cc48257dc52 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -45,6 +45,12 @@ const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => // eslint-disable-next-line camelcase session.use_ort_model_bytes_directly = '1'; } + + // if using JSEP with WebGPU, always disable memory pattern + if (options.executionProviders && + options.executionProviders.some(ep => (typeof ep === 'string' ? ep : ep.name) === 'webgpu')) { + options.enableMemPattern = false; + } }; const setExecutionProviders = @@ -58,6 +64,9 @@ const setExecutionProviders = case 'xnnpack': epName = 'XNNPACK'; break; + case 'webgpu': + epName = 'JS'; + break; case 'wasm': case 'cpu': continue; diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts new file mode 100644 index 0000000000000..d0df08419fb5d --- /dev/null +++ b/js/web/lib/wasm/wasm-common.ts @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {Tensor} from 'onnxruntime-common'; + +/** + * Copied from ONNX definition. Use this to drop dependency 'onnx_proto' to decrease compiled .js file size. + */ +export const enum DataType { + undefined = 0, + float = 1, + uint8 = 2, + int8 = 3, + uint16 = 4, + int16 = 5, + int32 = 6, + int64 = 7, + string = 8, + bool = 9, + float16 = 10, + double = 11, + uint32 = 12, + uint64 = 13, + complex64 = 14, + complex128 = 15, + bfloat16 = 16 +} + +/** + * Map string tensor data to enum value + */ +export const tensorDataTypeStringToEnum = (type: string): DataType => { + switch (type) { + case 'int8': + return DataType.int8; + case 'uint8': + return DataType.uint8; + case 'bool': + return DataType.bool; + case 'int16': + return DataType.int16; + case 'uint16': + return DataType.uint16; + case 'int32': + return DataType.int32; + case 'uint32': + return DataType.uint32; + case 'float32': + return DataType.float; + case 'float64': + return DataType.double; + case 'string': + return DataType.string; + case 'int64': + return DataType.int64; + case 'uint64': + return DataType.uint64; + + default: + throw new Error(`unsupported data type: ${type}`); + } +}; + +/** + * Map enum value to string tensor data + */ +export const tensorDataTypeEnumToString = (typeProto: DataType): Tensor.Type => { + switch (typeProto) { + case DataType.int8: + return 'int8'; + case DataType.uint8: + return 'uint8'; + case DataType.bool: + return 'bool'; + case DataType.int16: + return 'int16'; + case DataType.uint16: + return 'uint16'; + case DataType.int32: + return 'int32'; + case DataType.uint32: + return 'uint32'; + case DataType.float: + return 'float32'; + case DataType.double: + return 'float64'; + case DataType.string: + return 'string'; + case DataType.int64: + return 'int64'; + case DataType.uint64: + return 'uint64'; + + default: + throw new Error(`unsupported data type: ${typeProto}`); + } +}; + +/** + * get tensor element size in bytes by the given data type + * @returns size in integer or undefined if the data type is not supported + */ +export const getTensorElementSize = (dateType: number): number| + undefined => [undefined, 4, 1, 1, 2, 2, 4, 8, undefined, 1, 2, 8, 4, 8, undefined, undefined, undefined][dateType]; + +/** + * get typed array constructor by the given tensor type + */ +export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32ArrayConstructor|Uint8ArrayConstructor| + Int8ArrayConstructor|Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor| + Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => { + switch (type) { + case 'float32': + return Float32Array; + case 'uint8': + return Uint8Array; + case 'int8': + return Int8Array; + case 'uint16': + return Uint16Array; + case 'int16': + return Int16Array; + case 'int32': + return Int32Array; + case 'bool': + return Uint8Array; + case 'float64': + return Float64Array; + case 'uint32': + return Uint32Array; + case 'int64': + return BigInt64Array; + case 'uint64': + return BigUint64Array; + default: + throw new Error(`unsupported type: ${type}`); + } + }; + +/** + * Map string log level to integer value + */ +export const logLevelStringToEnum = (logLevel: 'verbose'|'info'|'warning'|'error'|'fatal'): number => { + switch (logLevel) { + case 'verbose': + return 0; + case 'info': + return 1; + case 'warning': + return 2; + case 'error': + return 3; + case 'fatal': + return 4; + default: + throw new Error(`unsupported logging level: ${logLevel}`); + } +}; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 3334a56d488e3..16291aeb500ac 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -7,6 +7,7 @@ import {SerializableModeldata, SerializableSessionMetadata, SerializableTensor} import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; import {allocWasmString} from './string-utils'; +import {tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; /** @@ -118,292 +119,179 @@ export const releaseSession = (sessionId: number): void => { }; /** - * Copied from ONNX definition. Use this to drop dependency 'onnx_proto' to decrease compiled .js file size. + * perform inference run */ -const enum DataType { - undefined = 0, - float = 1, - uint8 = 2, - int8 = 3, - uint16 = 4, - int16 = 5, - int32 = 6, - int64 = 7, - string = 8, - bool = 9, - float16 = 10, - double = 11, - uint32 = 12, - uint64 = 13, - complex64 = 14, - complex128 = 15, - bfloat16 = 16 -} - - -const tensorDataTypeStringToEnum = (type: string): DataType => { - switch (type) { - case 'int8': - return DataType.int8; - case 'uint8': - return DataType.uint8; - case 'bool': - return DataType.bool; - case 'int16': - return DataType.int16; - case 'uint16': - return DataType.uint16; - case 'int32': - return DataType.int32; - case 'uint32': - return DataType.uint32; - case 'float32': - return DataType.float; - case 'float64': - return DataType.double; - case 'string': - return DataType.string; - case 'int64': - return DataType.int64; - case 'uint64': - return DataType.uint64; - - default: - throw new Error(`unsupported data type: ${type}`); +export const run = async( + sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], + options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); + const session = activeSessions.get(sessionId); + if (!session) { + throw new Error('invalid session id'); } -}; + const sessionHandle = session[0]; + const inputNamesUTF8Encoded = session[1]; + const outputNamesUTF8Encoded = session[2]; -const tensorDataTypeEnumToString = (typeProto: DataType): Tensor.Type => { - switch (typeProto) { - case DataType.int8: - return 'int8'; - case DataType.uint8: - return 'uint8'; - case DataType.bool: - return 'bool'; - case DataType.int16: - return 'int16'; - case DataType.uint16: - return 'uint16'; - case DataType.int32: - return 'int32'; - case DataType.uint32: - return 'uint32'; - case DataType.float: - return 'float32'; - case DataType.double: - return 'float64'; - case DataType.string: - return 'string'; - case DataType.int64: - return 'int64'; - case DataType.uint64: - return 'uint64'; - - default: - throw new Error(`unsupported data type: ${typeProto}`); - } -}; + const inputCount = inputIndices.length; + const outputCount = outputIndices.length; -const numericTensorTypeToTypedArray = (type: Tensor.Type): Float32ArrayConstructor|Uint8ArrayConstructor| - Int8ArrayConstructor|Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor| - Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => { - switch (type) { - case 'float32': - return Float32Array; - case 'uint8': - return Uint8Array; - case 'int8': - return Int8Array; - case 'uint16': - return Uint16Array; - case 'int16': - return Int16Array; - case 'int32': - return Int32Array; - case 'bool': - return Uint8Array; - case 'float64': - return Float64Array; - case 'uint32': - return Uint32Array; - case 'int64': - return BigInt64Array; - case 'uint64': - return BigUint64Array; - default: - throw new Error(`unsupported type: ${type}`); - } - }; + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; -/** - * perform inference run - */ -export const run = - (sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): SerializableTensor[] => { - const wasm = getInstance(); - const session = activeSessions.get(sessionId); - if (!session) { - throw new Error('invalid session id'); - } - const sessionHandle = session[0]; - const inputNamesUTF8Encoded = session[1]; - const outputNamesUTF8Encoded = session[2]; + const inputValues: number[] = []; + const inputAllocs: number[] = []; - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; + try { + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; + // create input tensors + for (let i = 0; i < inputCount; i++) { + const dataType = inputs[i][0]; + const dims = inputs[i][1]; + const data = inputs[i][2]; - const inputValues: number[] = []; - const inputAllocs: number[] = []; + let dataOffset: number; + let dataByteLength: number; - try { - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // create input tensors - for (let i = 0; i < inputCount; i++) { - const dataType = inputs[i][0]; - const dims = inputs[i][1]; - const data = inputs[i][2]; - - let dataOffset: number; - let dataByteLength: number; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - let dataIndex = dataOffset / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], inputAllocs); - } - } else { - dataByteLength = data.byteLength; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), dataOffset); + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + dataOffset = wasm._malloc(dataByteLength); + inputAllocs.push(dataOffset); + let dataIndex = dataOffset / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); } + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], inputAllocs); + } + } else { + dataByteLength = data.byteLength; + dataOffset = wasm._malloc(dataByteLength); + inputAllocs.push(dataOffset); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), dataOffset); + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), dataOffset, dataByteLength, dimsOffset, dims.length); - if (tensor === 0) { - throw new Error('Can\'t create a tensor'); - } - inputValues.push(tensor); - } finally { - wasm.stackRestore(stack); - } + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), dataOffset, dataByteLength, dimsOffset, dims.length); + if (tensor === 0) { + throw new Error('Can\'t create a tensor'); } + inputValues.push(tensor); + } finally { + wasm.stackRestore(stack); + } + } - const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const inputNamesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); - const outputNamesOffset = wasm.stackAlloc(outputCount * 4); - - try { - let inputValuesIndex = inputValuesOffset / 4; - let inputNamesIndex = inputNamesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - let outputNamesIndex = outputNamesOffset / 4; - for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputValues[i]; - wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; - } - for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = 0; - wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; - } + const beforeRunStack = wasm.stackSave(); + const inputValuesOffset = wasm.stackAlloc(inputCount * 4); + const inputNamesOffset = wasm.stackAlloc(inputCount * 4); + const outputValuesOffset = wasm.stackAlloc(outputCount * 4); + const outputNamesOffset = wasm.stackAlloc(outputCount * 4); + + try { + let inputValuesIndex = inputValuesOffset / 4; + let inputNamesIndex = inputNamesOffset / 4; + let outputValuesIndex = outputValuesOffset / 4; + let outputNamesIndex = outputNamesOffset / 4; + for (let i = 0; i < inputCount; i++) { + wasm.HEAPU32[inputValuesIndex++] = inputValues[i]; + wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + } + for (let i = 0; i < outputCount; i++) { + wasm.HEAPU32[outputValuesIndex++] = 0; + wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; + } + + // support RunOptions + let errorCode = wasm._OrtRun( + sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, + outputValuesOffset, runOptionsHandle); + + // eslint-disable-next-line @typescript-eslint/naming-convention + const runPromise = wasm.jsepRunPromise; + if (runPromise && typeof runPromise.then !== 'undefined') { + errorCode = await runPromise; + } + + const output: SerializableTensor[] = []; + + if (errorCode === 0) { + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); - // support RunOptions - let errorCode = wasm._OrtRun( - sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, - outputValuesOffset, runOptionsHandle); - - const output: SerializableTensor[] = []; - - if (errorCode === 0) { - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); - - let type: Tensor.Type|undefined, dataOffset = 0; - try { - errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - throw new Error(`Can't access output tensor data. error code = ${errorCode}`); - } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.length === 0 ? 1 : dims.reduce((a, b) => a * b); - type = tensorDataTypeEnumToString(dataType); - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData]); - } else { - const typedArrayConstructor = numericTensorTypeToTypedArray(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength) - .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data]); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - wasm._OrtReleaseTensor(tensor); + let type: Tensor.Type|undefined, dataOffset = 0; + try { + errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + throw new Error(`Can't access output tensor data. error code = ${errorCode}`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.length === 0 ? 1 : dims.reduce((a, b) => a * b); + type = tensorDataTypeEnumToString(dataType); + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } + output.push([type, dims, stringData]); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + output.push([type, dims, data]); } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + wasm._OrtReleaseTensor(tensor); } - - if (errorCode === 0) { - return output; - } else { - throw new Error(`failed to call OrtRun(). error code = ${errorCode}.`); - } - } finally { - wasm.stackRestore(beforeRunStack); } - } finally { - inputValues.forEach(wasm._OrtReleaseTensor); - inputAllocs.forEach(wasm._free); + } - wasm._OrtReleaseRunOptions(runOptionsHandle); - runOptionsAllocs.forEach(wasm._free); + if (errorCode === 0) { + return output; + } else { + throw new Error(`failed to call OrtRun(). error code = ${errorCode}.`); } - }; + } finally { + wasm.stackRestore(beforeRunStack); + } + } finally { + inputValues.forEach(wasm._OrtReleaseTensor); + inputAllocs.forEach(wasm._free); + + wasm._OrtReleaseRunOptions(runOptionsHandle); + runOptionsAllocs.forEach(wasm._free); + } +}; /** * end profiling diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 656ef4258a51a..ad4456629041e 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -27,6 +27,7 @@ "@types/minimist": "^1.2.2", "@types/mocha": "^10.0.1", "@types/platform": "^1.3.4", + "@webgpu/types": "^0.1.30", "base64-js": "^1.5.1", "chai": "^4.3.7", "electron": "^23.1.2", @@ -335,6 +336,12 @@ "@types/node": "*" } }, + "node_modules/@webgpu/types": { + "version": "0.1.30", + "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.30.tgz", + "integrity": "sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg==", + "dev": true + }, "node_modules/accepts": { "version": "1.3.8", "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.8.tgz", @@ -4277,6 +4284,12 @@ "@types/node": "*" } }, + "@webgpu/types": { + "version": "0.1.30", + "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.30.tgz", + "integrity": "sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg==", + "dev": true + }, "accepts": { "version": "1.3.8", "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.8.tgz", diff --git a/js/web/package.json b/js/web/package.json index ad17158b57499..9b18dcbd237c8 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -43,6 +43,7 @@ "@types/minimist": "^1.2.2", "@types/mocha": "^10.0.1", "@types/platform": "^1.3.4", + "@webgpu/types": "^0.1.30", "base64-js": "^1.5.1", "chai": "^4.3.7", "electron": "^23.1.2", diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 8d02f81e11dbd..e20c391513c67 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -34,6 +34,7 @@ Options: -b=<...>, --backend=<...> Specify one or more backend(s) to run the test upon. Backends can be one or more of the following, splitted by comma: webgl + webgpu wasm xnnpack -e=<...>, --env=<...> Specify the environment to run the test. Should be one of the following: @@ -72,6 +73,7 @@ Options: --webgl-matmul-max-batch-size Set the WebGL matmulMaxBatchSize --webgl-texture-cache-mode Set the WebGL texture cache mode (initializerOnly/full) --webgl-texture-pack-mode Set the WebGL texture pack mode (true/false) + --webgpu-profiling-mode Set the WebGPU profiling mode (off/default) *** Browser Options *** @@ -102,7 +104,7 @@ Examples: export declare namespace TestRunnerCliArgs { type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; - type Backend = 'cpu'|'webgl'|'wasm'|'onnxruntime'|'xnnpack'; + type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'prod'|'dev'|'perf'; } @@ -316,11 +318,20 @@ function parseWebglFlags(args: minimist.ParsedArgs): Env.WebGLFlags { return {contextId, matmulMaxBatchSize, textureCacheMode, pack}; } +function parseWebgpuFlags(args: minimist.ParsedArgs): Env.WebGpuFlags { + const profilingMode = args['webgpu-profiling-mode']; + if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') { + throw new Error('Flag "webgpu-profiling-mode" is invalid'); + } + return {profilingMode}; +} + function parseGlobalEnvFlags(args: minimist.ParsedArgs): Env { - const wasmFlags = parseWasmFlags(args); - const webglFlags = parseWebglFlags(args); + const wasm = parseWasmFlags(args); + const webgl = parseWebglFlags(args); + const webgpu = parseWebgpuFlags(args); const cpuFlags = parseCpuFlags(args); - return {webgl: webglFlags, wasm: wasmFlags, cpuFlags}; + return {webgl, wasm, webgpu, ...cpuFlags}; } export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs { @@ -348,11 +359,16 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } // Option: -b=<...>, --backend=<...> - const browserBackends = ['webgl', 'wasm', 'xnnpack']; + const browserBackends = ['webgl', 'webgpu', 'wasm', 'xnnpack']; + + // TODO: remove this when Chrome support WebGPU. + // we need this for now because Chrome does not support webgpu yet, + // and ChromeCanary is not in CI. + const defaultBrowserBackends = ['webgl', /* 'webgpu', */ 'wasm', 'xnnpack']; const nodejsBackends = ['cpu', 'wasm']; const backendArgs = args.backend || args.b; - const backend = - (typeof backendArgs !== 'string') ? (env === 'node' ? nodejsBackends : browserBackends) : backendArgs.split(','); + const backend = (typeof backendArgs !== 'string') ? (env === 'node' ? nodejsBackends : defaultBrowserBackends) : + backendArgs.split(','); for (const b of backend) { if ((env !== 'node' && browserBackends.indexOf(b) === -1) || (env === 'node' && nodejsBackends.indexOf(b) === -1)) { throw new Error(`backend ${b} is not supported in env ${env}`); diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index a2cea3e71db36..72938789bc2df 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -53,8 +53,10 @@ async function main() { // The default backends and opset version lists. Those will be used in suite tests. const DEFAULT_BACKENDS: readonly TestRunnerCliArgs.Backend[] = - args.env === 'node' ? ['cpu', 'wasm'] : ['wasm', 'webgl']; - const DEFAULT_OPSET_VERSIONS: readonly number[] = [13, 12, 11, 10, 9, 8, 7]; + args.env === 'node' ? ['cpu', 'wasm'] : ['wasm', 'webgl', 'webgpu']; + const DEFAULT_OPSET_VERSIONS = fs.readdirSync(TEST_DATA_MODEL_NODE_ROOT, {withFileTypes: true}) + .filter(dir => dir.isDirectory() && dir.name.startsWith('opset')) + .map(dir => dir.name.slice(5)); const FILE_CACHE_ENABLED = args.fileCache; // whether to enable file cache const FILE_CACHE_MAX_FILE_SIZE = 1 * 1024 * 1024; // The max size of the file that will be put into file cache @@ -205,7 +207,7 @@ async function main() { } } - function loadNodeTests(backend: string, version: number): Test.ModelTestGroup { + function loadNodeTests(backend: string, version: string): Test.ModelTestGroup { return suiteFromFolder( `node-opset_v${version}-${backend}`, path.join(TEST_DATA_MODEL_NODE_ROOT, `opset${version}`), backend, testlist[backend].node); @@ -333,7 +335,7 @@ async function main() { [searchPattern, path.join(TEST_DATA_MODEL_NODE_ROOT, '**', searchPattern).replace(/\\/g, '/')]; // 4 - check the globby result of NODE root combined with opset versions and searchPattern globbyPattern.push(...DEFAULT_OPSET_VERSIONS.map( - v => path.join(TEST_DATA_MODEL_NODE_ROOT, `v${v}`, '**', searchPattern).replace(/\\/g, '/'))); + v => path.join(TEST_DATA_MODEL_NODE_ROOT, `opset${v}`, '**', searchPattern).replace(/\\/g, '/'))); folderCandidates.push(...globbySync(globbyPattern, {onlyDirectories: true, absolute: true})); @@ -456,11 +458,13 @@ async function main() { } else { // STEP 5. use Karma to run test npmlog.info('TestRunnerCli.Run', '(4/4) Running karma to start test runner...'); + const webgpu = args.backends.indexOf('webgpu') > -1; const browser = getBrowserNameFromEnv( args.env, args.bundleMode === 'perf' ? 'perf' : args.debug ? 'debug' : - 'test'); + 'test', + webgpu, config.options.globalEnvFlags?.webgpu?.profilingMode === 'default'); const karmaArgs = ['karma', 'start', `--browsers ${browser}`]; if (args.debug) { karmaArgs.push('--log-level info --timeout-mocha 9999999'); @@ -470,6 +474,9 @@ async function main() { if (args.noSandbox) { karmaArgs.push('--no-sandbox'); } + if (webgpu) { + karmaArgs.push('--force-localhost'); + } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); if (browser === 'Edge') { // There are currently 2 Edge browser launchers: @@ -561,10 +568,11 @@ async function main() { fs.writeJSONSync(path.join(TEST_ROOT, './testdata-config.json'), config); } - function getBrowserNameFromEnv(env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test') { + function getBrowserNameFromEnv( + env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean, profile: boolean) { switch (env) { case 'chrome': - return selectChromeBrowser(mode); + return selectChromeBrowser(mode, webgpu, profile); case 'edge': return 'Edge'; case 'firefox': @@ -580,14 +588,23 @@ async function main() { } } - function selectChromeBrowser(mode: 'debug'|'perf'|'test') { - switch (mode) { - case 'debug': - return 'ChromeDebug'; - case 'perf': - return 'ChromePerf'; - default: - return 'ChromeTest'; + function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean, profile: boolean) { + if (webgpu) { + switch (mode) { + case 'debug': + return profile ? 'ChromeCanaryProfileDebug' : 'ChromeCanaryDebug'; + default: + return profile ? 'ChromeCanaryProfileTest' : 'ChromeCanaryDebug'; + } + } else { + switch (mode) { + case 'debug': + return 'ChromeDebug'; + case 'perf': + return 'ChromePerf'; + default: + return 'ChromeTest'; + } } } } diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index cb59689c4b027..17928899c91b1 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -160,36 +160,36 @@ "test_sum_example", "test_sum_one_input", "test_sum_two_inputs", - "test_reduce_log_sum_asc_axes", - "test_reduce_log_sum_default", - "test_reduce_log_sum_desc_axes", - "test_reduce_max_default_axes_keepdim_example", - "test_reduce_max_default_axes_keepdims_random", - "test_reduce_max_do_not_keepdims_example", - "test_reduce_max_do_not_keepdims_random", - "test_reduce_max_keepdims_example", - "test_reduce_max_keepdims_random", - "test_reduce_mean_default_axes_keepdims_example", - "test_reduce_mean_default_axes_keepdims_random", - "test_reduce_mean_do_not_keepdims_example", - "test_reduce_mean_do_not_keepdims_random", - "test_reduce_mean_keepdims_example", - "test_reduce_mean_keepdims_random", - "test_reduce_min_default_axes_keepdims_example", - "test_reduce_min_default_axes_keepdims_random", - "test_reduce_min_do_not_keepdims_example", - "test_reduce_min_do_not_keepdims_random", - "test_reduce_min_keepdims_example", - "test_reduce_min_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_log_sum_asc_axes", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_log_sum_default", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_log_sum_desc_axes", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_max_default_axes_keepdim_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_max_default_axes_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_max_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_max_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_max_keepdims_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_max_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_mean_default_axes_keepdims_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_mean_default_axes_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_mean_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_mean_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_mean_keepdims_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_mean_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_min_default_axes_keepdims_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_min_default_axes_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_min_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_min_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_min_keepdims_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_min_keepdims_random", { - "name": "test_reduce_prod_default_axes_keepdims_example", + "name": "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_prod_default_axes_keepdims_example", "condition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment }, - "test_reduce_prod_default_axes_keepdims_random", - "test_reduce_prod_do_not_keepdims_example", - "test_reduce_prod_do_not_keepdims_random", - "test_reduce_prod_keepdims_example", - "test_reduce_prod_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_prod_default_axes_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_prod_do_not_keepdims_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_prod_do_not_keepdims_random", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_prod_keepdims_example", + "opset{7,8,9,10,11,12,13,14,15,16,17}/test_reduce_prod_keepdims_random", "opset{7,8,9,10,11,12}/test_reduce_sum_default_axes_keepdims_example", "opset{7,8,9,10,11,12}/test_reduce_sum_default_axes_keepdims_random", "opset{7,8,9,10,11,12}/test_reduce_sum_do_not_keepdims_example", @@ -282,6 +282,1045 @@ "xor.jsonc" ] }, + "webgpu": { + "onnx": [], + "node": [ + "test_abs", + "test_acos_example", + "test_acos", + "test_acosh_example", + "test_acosh", + // // "test_adagrad_multiple", + // // "test_adagrad", + // // "test_adam_multiple", + // // "test_adam", + "test_add_bcast", + // "test_add_uint8", + "test_add", + // "test_and_bcast3v1d", + // "test_and_bcast3v2d", + // "test_and_bcast4v2d", + // "test_and_bcast4v3d", + // "test_and_bcast4v4d", + // "test_and2d", + // "test_and3d", + // "test_and4d", + // // "test_argmax_default_axis_example_select_last_index", + // // "test_argmax_default_axis_example", + // // "test_argmax_default_axis_random_select_last_index", + // // "test_argmax_default_axis_random", + // // "test_argmax_keepdims_example_select_last_index", + // // "test_argmax_keepdims_example", + // // "test_argmax_keepdims_random_select_last_index", + // // "test_argmax_keepdims_random", + // // "test_argmax_negative_axis_keepdims_example_select_last_index", + // // "test_argmax_negative_axis_keepdims_example", + // // "test_argmax_negative_axis_keepdims_random_select_last_index", + // // "test_argmax_negative_axis_keepdims_random", + // // "test_argmax_no_keepdims_example_select_last_index", + // // "test_argmax_no_keepdims_example", + // // "test_argmax_no_keepdims_random_select_last_index", + // // "test_argmax_no_keepdims_random", + // // "test_argmin_default_axis_example_select_last_index", + // // "test_argmin_default_axis_example", + // // "test_argmin_default_axis_random_select_last_index", + // // "test_argmin_default_axis_random", + // // "test_argmin_keepdims_example_select_last_index", + // // "test_argmin_keepdims_example", + // // "test_argmin_keepdims_random_select_last_index", + // // "test_argmin_keepdims_random", + // // "test_argmin_negative_axis_keepdims_example_select_last_index", + // // "test_argmin_negative_axis_keepdims_example", + // // "test_argmin_negative_axis_keepdims_random_select_last_index", + // // "test_argmin_negative_axis_keepdims_random", + // // "test_argmin_no_keepdims_example_select_last_index", + // // "test_argmin_no_keepdims_example", + // // "test_argmin_no_keepdims_random_select_last_index", + // // "test_argmin_no_keepdims_random", + "test_asin_example", + "test_asin", + "test_asinh_example", + "test_asinh", + "test_atan_example", + "test_atan", + "test_atanh_example", + "test_atanh", + // "test_averagepool_1d_default", + // "test_averagepool_2d_ceil", + "test_averagepool_2d_default", + "test_averagepool_2d_pads_count_include_pad", + "test_averagepool_2d_pads", + "test_averagepool_2d_precomputed_pads_count_include_pad", + "test_averagepool_2d_precomputed_pads", + "test_averagepool_2d_precomputed_same_upper", + "test_averagepool_2d_precomputed_strides", + "test_averagepool_2d_same_lower", + "test_averagepool_2d_same_upper", + "test_averagepool_2d_strides", + // "test_averagepool_3d_default", + "test_basic_conv_with_padding", + "test_basic_conv_without_padding", + // "test_basic_convinteger", + "test_batchnorm_epsilon_training_mode", + "test_batchnorm_epsilon", + "test_batchnorm_example_training_mode", + "test_batchnorm_example", + // // "test_bernoulli_double_expanded", + // // "test_bernoulli_double", + // // "test_bernoulli_expanded", + // // "test_bernoulli_seed_expanded", + // // "test_bernoulli_seed", + // // "test_bernoulli", + // // "test_bitshift_left_uint16", + // // "test_bitshift_left_uint32", + // // "test_bitshift_left_uint64", + // // "test_bitshift_left_uint8", + // // "test_bitshift_right_uint16", + // // "test_bitshift_right_uint32", + // // "test_bitshift_right_uint64", + // // "test_bitshift_right_uint8", + // // "test_blackmanwindow_expanded", + // // "test_blackmanwindow_symmetric_expanded", + // // "test_blackmanwindow_symmetric", + // // "test_blackmanwindow", + // // "test_cast_BFLOAT16_to_FLOAT", + // // "test_cast_DOUBLE_to_FLOAT", + // // "test_cast_DOUBLE_to_FLOAT16", + // // "test_cast_FLOAT_to_BFLOAT16", + // // "test_cast_FLOAT_to_DOUBLE", + // // "test_cast_FLOAT_to_FLOAT16", + // // "test_cast_FLOAT_to_STRING", + // // "test_cast_FLOAT16_to_DOUBLE", + // // "test_cast_FLOAT16_to_FLOAT", + // // "test_cast_STRING_to_FLOAT", + // // "test_castlike_BFLOAT16_to_FLOAT_expanded", + // // "test_castlike_BFLOAT16_to_FLOAT", + // // "test_castlike_DOUBLE_to_FLOAT_expanded", + // // "test_castlike_DOUBLE_to_FLOAT", + // // "test_castlike_DOUBLE_to_FLOAT16_expanded", + // // "test_castlike_DOUBLE_to_FLOAT16", + // // "test_castlike_FLOAT_to_BFLOAT16_expanded", + // // "test_castlike_FLOAT_to_BFLOAT16", + // // "test_castlike_FLOAT_to_DOUBLE_expanded", + // // "test_castlike_FLOAT_to_DOUBLE", + // // "test_castlike_FLOAT_to_FLOAT16_expanded", + // // "test_castlike_FLOAT_to_FLOAT16", + // // "test_castlike_FLOAT_to_STRING_expanded", + // // "test_castlike_FLOAT_to_STRING", + // // "test_castlike_FLOAT16_to_DOUBLE_expanded", + // // "test_castlike_FLOAT16_to_DOUBLE", + // // "test_castlike_FLOAT16_to_FLOAT_expanded", + // // "test_castlike_FLOAT16_to_FLOAT", + // // "test_castlike_STRING_to_FLOAT_expanded", + // // "test_castlike_STRING_to_FLOAT", + "test_ceil_example", + "test_ceil", + // "test_celu_expanded", + // "test_celu", + // "test_clip_default_inbounds", + // "test_clip_default_int8_inbounds", + // "test_clip_default_int8_max", + // "test_clip_default_int8_min", + // "test_clip_default_max", + // "test_clip_default_min", + // "test_clip_example", + // "test_clip_inbounds", + // "test_clip_outbounds", + // "test_clip_splitbounds", + // "test_clip", + // // "test_compress_0", + // // "test_compress_1", + // // "test_compress_default_axis", + // // "test_compress_negative_axis", + // "test_concat_1d_axis_0", + // "test_concat_1d_axis_negative_1", + // "test_concat_2d_axis_0", + // "test_concat_2d_axis_1", + // "test_concat_2d_axis_negative_1", + // "test_concat_2d_axis_negative_2", + // "test_concat_3d_axis_0", + // "test_concat_3d_axis_1", + // "test_concat_3d_axis_2", + // "test_concat_3d_axis_negative_1", + // "test_concat_3d_axis_negative_2", + // "test_concat_3d_axis_negative_3", + "test_conv_with_autopad_same", + "test_conv_with_strides_and_asymmetric_padding", + "test_conv_with_strides_no_padding", + "test_conv_with_strides_padding", + // // "test_convinteger_with_padding", + // // "test_convinteger_without_padding", + // // "test_convtranspose_1d", + // // "test_convtranspose_3d", + // // "test_convtranspose_autopad_same", + // // "test_convtranspose_dilations", + // // "test_convtranspose_kernel_shape", + // // "test_convtranspose_output_shape", + // // "test_convtranspose_pad", + // // "test_convtranspose_pads", + // // "test_convtranspose_with_kernel", + // // "test_convtranspose", + "test_cos_example", + "test_cos", + "test_cosh_example", + "test_cosh", + // "test_cumsum_1d_exclusive", + // "test_cumsum_1d_reverse_exclusive", + // "test_cumsum_1d_reverse", + // "test_cumsum_1d", + // "test_cumsum_2d_axis_0", + // "test_cumsum_2d_axis_1", + // "test_cumsum_2d_negative_axis", + // "test_depthtospace_crd_mode_example", + // "test_depthtospace_crd_mode", + // "test_depthtospace_dcr_mode", + // "test_depthtospace_example", + // "test_depthtospace", + // // "test_dequantizelinear_axis", + // // "test_dequantizelinear", + // // "test_det_2d", + // // "test_det_nd", + // // "test_dft_axis", + // // "test_dft_inverse", + // // "test_dft", + "test_div_bcast", + "test_div_example", + // "test_div_uint8", + "test_div", + // // "test_dropout_default_mask_ratio", + // // "test_dropout_default_mask", + // // "test_dropout_default_old", + // // "test_dropout_default_ratio", + // // "test_dropout_default", + // // "test_dropout_random_old", + // // "test_dropout_random", + // // "test_dynamic_slice_default_axes", + // // "test_dynamic_slice_end_out_of_bounds", + // // "test_dynamic_slice_neg", + // // "test_dynamic_slice_start_out_of_bounds", + // // "test_dynamic_slice", + // // "test_dynamicquantizelinear_expanded", + // // "test_dynamicquantizelinear_max_adjusted_expanded", + // // "test_dynamicquantizelinear_max_adjusted", + // // "test_dynamicquantizelinear_min_adjusted_expanded", + // // "test_dynamicquantizelinear_min_adjusted", + // // "test_dynamicquantizelinear", + // // "test_edge_pad", + // "test_einsum_batch_diagonal", + // "test_einsum_batch_matmul", + // "test_einsum_inner_prod", + // "test_einsum_sum", + // "test_einsum_transpose", + "test_elu_default", + "test_elu_example", + "test_elu", + // "test_equal_bcast", + // "test_equal", + "test_erf", + // "test_exp_example", + // "test_exp", + // "test_expand_dim_changed", + // "test_expand_dim_unchanged", + // "test_eyelike_populate_off_main_diagonal", + // "test_eyelike_with_dtype", + // "test_eyelike_without_dtype", + // "test_flatten_axis0", + // "test_flatten_axis1", + // "test_flatten_axis2", + // "test_flatten_axis3", + // "test_flatten_default_axis", + // "test_flatten_negative_axis1", + // "test_flatten_negative_axis2", + // "test_flatten_negative_axis3", + // "test_flatten_negative_axis4", + "test_floor_example", + "test_floor", + // "test_gather_0", + // "test_gather_1", + // "test_gather_2d_indices", + // "test_gather_elements_0", + // "test_gather_elements_1", + // "test_gather_elements_negative_indices", + // "test_gather_negative_indices", + // // "test_gathernd_example_float32", + // // "test_gathernd_example_int32_batch_dim1", + // // "test_gathernd_example_int32", + "test_gemm_all_attributes", + "test_gemm_alpha", + "test_gemm_beta", + "test_gemm_broadcast", + "test_gemm_default_matrix_bias", + "test_gemm_default_no_bias", + "test_gemm_default_scalar_bias", + "test_gemm_default_single_elem_vector_bias", + "test_gemm_default_vector_bias", + "test_gemm_default_zero_bias", + "test_gemm_nobroadcast", + "test_gemm_transposeA", + "test_gemm_transposeB", + "test_globalaveragepool_precomputed", + "test_globalaveragepool", + "test_globalmaxpool_precomputed", + "test_globalmaxpool", + // "test_greater_bcast", + // "test_greater_equal_bcast_expanded", + // "test_greater_equal_bcast", + // "test_greater_equal_expanded", + // "test_greater_equal", + // "test_greater", + // // "test_gridsample_aligncorners_true", + // // "test_gridsample_bicubic", + // // "test_gridsample_bilinear", + // // "test_gridsample_border_padding", + // // "test_gridsample_nearest", + // // "test_gridsample_reflection_padding", + // // "test_gridsample_zeros_padding", + // // "test_gridsample", + // // "test_gru_batchwise", + // // "test_gru_defaults", + // // "test_gru_seq_length", + // // "test_gru_with_initial_bias", + // // "test_hammingwindow_expanded", + // // "test_hammingwindow_symmetric_expanded", + // // "test_hammingwindow_symmetric", + // // "test_hammingwindow", + // // "test_hannwindow_expanded", + // // "test_hannwindow_symmetric_expanded", + // // "test_hannwindow_symmetric", + // // "test_hannwindow", + // // "test_hardmax_axis_0", + // // "test_hardmax_axis_1", + // // "test_hardmax_axis_2", + // // "test_hardmax_default_axis", + // // "test_hardmax_example", + // // "test_hardmax_negative_axis", + // // "test_hardmax_one_hot", + // // "test_hardsigmoid_default", + // // "test_hardsigmoid_example", + // // "test_hardsigmoid", + // // "test_hardswish_expanded", + // // "test_hardswish", + // // "test_instancenorm_epsilon", + // // "test_instancenorm_example", + // "test_isinf_negative", + // "test_isinf_positive", + // "test_isinf", + // "test_isnan", + // // "test_layer_normalization_2d_axis_negative_1_expanded", + // // "test_layer_normalization_2d_axis_negative_1", + // // "test_layer_normalization_2d_axis_negative_2_expanded", + // // "test_layer_normalization_2d_axis_negative_2", + // // "test_layer_normalization_2d_axis0_expanded", + // // "test_layer_normalization_2d_axis0", + // // "test_layer_normalization_2d_axis1_expanded", + // // "test_layer_normalization_2d_axis1", + // // "test_layer_normalization_3d_axis_negative_1_epsilon_expanded", + // // "test_layer_normalization_3d_axis_negative_1_epsilon", + // // "test_layer_normalization_3d_axis_negative_2_epsilon_expanded", + // // "test_layer_normalization_3d_axis_negative_2_epsilon", + // // "test_layer_normalization_3d_axis_negative_3_epsilon_expanded", + // // "test_layer_normalization_3d_axis_negative_3_epsilon", + // // "test_layer_normalization_3d_axis0_epsilon_expanded", + // // "test_layer_normalization_3d_axis0_epsilon", + // // "test_layer_normalization_3d_axis1_epsilon_expanded", + // // "test_layer_normalization_3d_axis1_epsilon", + // // "test_layer_normalization_3d_axis2_epsilon_expanded", + // // "test_layer_normalization_3d_axis2_epsilon", + // // "test_layer_normalization_4d_axis_negative_1_expanded", + // // "test_layer_normalization_4d_axis_negative_1", + // // "test_layer_normalization_4d_axis_negative_2_expanded", + // // "test_layer_normalization_4d_axis_negative_2", + // // "test_layer_normalization_4d_axis_negative_3_expanded", + // // "test_layer_normalization_4d_axis_negative_3", + // // "test_layer_normalization_4d_axis_negative_4_expanded", + // // "test_layer_normalization_4d_axis_negative_4", + // // "test_layer_normalization_4d_axis0_expanded", + // // "test_layer_normalization_4d_axis0", + // // "test_layer_normalization_4d_axis1_expanded", + // // "test_layer_normalization_4d_axis1", + // // "test_layer_normalization_4d_axis2_expanded", + // // "test_layer_normalization_4d_axis2", + // // "test_layer_normalization_4d_axis3_expanded", + // // "test_layer_normalization_4d_axis3", + // // "test_layer_normalization_default_axis_expanded", + // // "test_layer_normalization_default_axis", + // "test_leakyrelu_default", + // "test_leakyrelu_example", + // "test_leakyrelu", + // "test_less_bcast", + // "test_less_equal_bcast_expanded", + // "test_less_equal_bcast", + // "test_less_equal_expanded", + // "test_less_equal", + // "test_less", + "test_log_example", + "test_log", + // // "test_logsoftmax_axis_0_expanded", + // // "test_logsoftmax_axis_0", + // // "test_logsoftmax_axis_1_expanded", + // // "test_logsoftmax_axis_1", + // // "test_logsoftmax_axis_2_expanded", + // // "test_logsoftmax_axis_2", + // // "test_logsoftmax_default_axis_expanded", + // // "test_logsoftmax_default_axis", + // // "test_logsoftmax_example_1_expanded", + // // "test_logsoftmax_example_1", + // // "test_logsoftmax_large_number_expanded", + // // "test_logsoftmax_large_number", + // // "test_logsoftmax_negative_axis_expanded", + // // "test_logsoftmax_negative_axis", + // "test_lrn_default", + // "test_lrn", + // // "test_lstm_batchwise", + // // "test_lstm_defaults", + // // "test_lstm_with_initial_bias", + // // "test_lstm_with_peepholes", + "test_matmul_2d", + "test_matmul_3d", + "test_matmul_4d", + // // "test_matmulinteger", + // "test_max_example", + // "test_max_float16", + // "test_max_float32", + // "test_max_float64", + // "test_max_int16", + // "test_max_int32", + // "test_max_int64", + // "test_max_int8", + // "test_max_one_input", + // "test_max_two_inputs", + // "test_max_uint16", + // "test_max_uint32", + // "test_max_uint64", + // "test_max_uint8", + // "test_maxpool_1d_default", + // "test_maxpool_2d_ceil", + "test_maxpool_2d_default", + // "test_maxpool_2d_dilations", + "test_maxpool_2d_pads", + "test_maxpool_2d_precomputed_pads", + "test_maxpool_2d_precomputed_same_upper", + "test_maxpool_2d_precomputed_strides", + "test_maxpool_2d_same_lower", + "test_maxpool_2d_same_upper", + "test_maxpool_2d_strides", + // "test_maxpool_2d_uint8", + // "test_maxpool_3d_default", + // "test_maxpool_with_argmax_2d_precomputed_pads", + // "test_maxpool_with_argmax_2d_precomputed_strides", + // // "test_maxunpool_export_with_output_shape", + // // "test_maxunpool_export_without_output_shape", + // // "test_mean_example", + // // "test_mean_one_input", + // // "test_mean_two_inputs", + // // "test_melweightmatrix", + // "test_min_example", + // "test_min_float16", + // "test_min_float32", + // "test_min_float64", + // "test_min_int16", + // "test_min_int32", + // "test_min_int64", + // "test_min_int8", + // "test_min_one_input", + // "test_min_two_inputs", + // "test_min_uint16", + // "test_min_uint32", + // "test_min_uint64", + // "test_min_uint8", + // "test_mod_bcast", + // "test_mod_broadcast", + // "test_mod_float_mixed_sign_example", + // "test_mod_fmod_mixed_sign_example", + // "test_mod_int64_fmod", + // "test_mod_int64_mixed_sign_example", + // "test_mod_mixed_sign_float16", + // "test_mod_mixed_sign_float32", + // "test_mod_mixed_sign_float64", + // "test_mod_mixed_sign_int16", + // "test_mod_mixed_sign_int32", + // "test_mod_mixed_sign_int64", + // "test_mod_mixed_sign_int8", + // "test_mod_uint16", + // "test_mod_uint32", + // "test_mod_uint64", + // "test_mod_uint8", + // // "test_momentum_multiple", + // // "test_momentum", + "test_mul_bcast", + "test_mul_example", + // "test_mul_uint8", + "test_mul", + // "test_mvn_expanded", + // "test_mvn", + "test_neg_example", + "test_neg", + // // "test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NC_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NC", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_no_weight_reduction_mean_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_no_weight_reduction_mean_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight", + // // "test_nesterov_momentum", + // // "test_nllloss_NC_expanded", + // // "test_nllloss_NC", + // // "test_nllloss_NCd1_expanded", + // // "test_nllloss_NCd1_ii_expanded", + // // "test_nllloss_NCd1_ii", + // // "test_nllloss_NCd1_mean_weight_negative_ii_expanded", + // // "test_nllloss_NCd1_mean_weight_negative_ii", + // // "test_nllloss_NCd1_weight_expanded", + // // "test_nllloss_NCd1_weight_ii_expanded", + // // "test_nllloss_NCd1_weight_ii", + // // "test_nllloss_NCd1_weight", + // // "test_nllloss_NCd1", + // // "test_nllloss_NCd1d2_expanded", + // // "test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded", + // // "test_nllloss_NCd1d2_no_weight_reduction_mean_ii", + // // "test_nllloss_NCd1d2_reduction_mean_expanded", + // // "test_nllloss_NCd1d2_reduction_mean", + // // "test_nllloss_NCd1d2_reduction_sum_expanded", + // // "test_nllloss_NCd1d2_reduction_sum", + // // "test_nllloss_NCd1d2_with_weight_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_mean_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_mean", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_ii", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum", + // // "test_nllloss_NCd1d2_with_weight", + // // "test_nllloss_NCd1d2", + // // "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded", + // // "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", + // // "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded", + // // "test_nllloss_NCd1d2d3_sum_weight_high_ii", + // // "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_nllloss_NCd1d2d3d4d5_mean_weight", + // // "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_nllloss_NCd1d2d3d4d5_none_no_weight", + // "test_nonmaxsuppression_center_point_box_format", + // "test_nonmaxsuppression_flipped_coordinates", + // "test_nonmaxsuppression_identical_boxes", + // "test_nonmaxsuppression_limit_output_size", + // "test_nonmaxsuppression_single_box", + // "test_nonmaxsuppression_suppress_by_IOU_and_scores", + // "test_nonmaxsuppression_suppress_by_IOU", + // "test_nonmaxsuppression_two_batches", + // "test_nonmaxsuppression_two_classes", + // "test_nonzero_example", + // "test_not_2d", + // "test_not_3d", + // "test_not_4d", + // // "test_onehot_negative_indices", + // // "test_onehot_with_axis", + // // "test_onehot_with_negative_axis", + // // "test_onehot_without_axis", + // // "test_optional_get_element_sequence", + // // "test_optional_get_element", + // // "test_optional_has_element_empty", + // // "test_optional_has_element", + // "test_or_bcast3v1d", + // "test_or_bcast3v2d", + // "test_or_bcast4v2d", + // "test_or_bcast4v3d", + // "test_or_bcast4v4d", + // "test_or2d", + // "test_or3d", + // "test_or4d", + "test_pow_bcast_array", + "test_pow_bcast_scalar", + "test_pow_example", + // "test_pow_types_float", + // "test_pow_types_float32_int32", + // "test_pow_types_float32_int64", + // "test_pow_types_float32_uint32", + // "test_pow_types_float32_uint64", + // "test_pow_types_int", + // "test_pow_types_int32_float32", + // "test_pow_types_int32_int32", + // "test_pow_types_int64_float32", + // "test_pow_types_int64_int64", + "test_pow", + // "test_prelu_broadcast", + // "test_prelu_example", + // // "test_qlinearconv", + // // "test_qlinearmatmul_2D", + // // "test_qlinearmatmul_3D", + // // "test_quantizelinear_axis", + // // "test_quantizelinear", + // "test_range_float_type_positive_delta_expanded", + // "test_range_float_type_positive_delta", + // "test_range_int32_type_negative_delta_expanded", + // "test_range_int32_type_negative_delta", + "test_reciprocal_example", + "test_reciprocal", + // "test_reduce_l1_default_axes_keepdims_example", + // "test_reduce_l1_default_axes_keepdims_random", + // "test_reduce_l1_do_not_keepdims_example", + // "test_reduce_l1_do_not_keepdims_random", + // "test_reduce_l1_keep_dims_example", + // "test_reduce_l1_keep_dims_random", + // "test_reduce_l1_negative_axes_keep_dims_example", + // "test_reduce_l1_negative_axes_keep_dims_random", + // "test_reduce_l2_default_axes_keepdims_example", + // "test_reduce_l2_default_axes_keepdims_random", + // "test_reduce_l2_do_not_keepdims_example", + // "test_reduce_l2_do_not_keepdims_random", + // "test_reduce_l2_keep_dims_example", + // "test_reduce_l2_keep_dims_random", + // "test_reduce_l2_negative_axes_keep_dims_example", + // "test_reduce_l2_negative_axes_keep_dims_random", + // "test_reduce_log_sum_asc_axes", + // "test_reduce_log_sum_default", + // "test_reduce_log_sum_desc_axes", + // "test_reduce_log_sum_exp_default_axes_keepdims_example", + // "test_reduce_log_sum_exp_default_axes_keepdims_random", + // "test_reduce_log_sum_exp_do_not_keepdims_example", + // "test_reduce_log_sum_exp_do_not_keepdims_random", + // "test_reduce_log_sum_exp_keepdims_example", + // "test_reduce_log_sum_exp_keepdims_random", + // "test_reduce_log_sum_exp_negative_axes_keepdims_example", + // "test_reduce_log_sum_exp_negative_axes_keepdims_random", + // "test_reduce_log_sum_negative_axes", + // "test_reduce_log_sum", + // "test_reduce_max_default_axes_keepdim_example", + // "test_reduce_max_default_axes_keepdims_random", + // "test_reduce_max_do_not_keepdims_example", + // "test_reduce_max_do_not_keepdims_random", + // "test_reduce_max_keepdims_example", + // "test_reduce_max_keepdims_random", + // "test_reduce_max_negative_axes_keepdims_example", + // "test_reduce_max_negative_axes_keepdims_random", + // "test_reduce_mean_default_axes_keepdims_example", + // "test_reduce_mean_default_axes_keepdims_random", + // "test_reduce_mean_do_not_keepdims_example", + // "test_reduce_mean_do_not_keepdims_random", + // "test_reduce_mean_keepdims_example", + // "test_reduce_mean_keepdims_random", + // "test_reduce_mean_negative_axes_keepdims_example", + // "test_reduce_mean_negative_axes_keepdims_random", + // "test_reduce_min_default_axes_keepdims_example", + // "test_reduce_min_default_axes_keepdims_random", + // "test_reduce_min_do_not_keepdims_example", + // "test_reduce_min_do_not_keepdims_random", + // "test_reduce_min_keepdims_example", + // "test_reduce_min_keepdims_random", + // "test_reduce_min_negative_axes_keepdims_example", + // "test_reduce_min_negative_axes_keepdims_random", + // "test_reduce_prod_default_axes_keepdims_example", + // "test_reduce_prod_default_axes_keepdims_random", + // "test_reduce_prod_do_not_keepdims_example", + // "test_reduce_prod_do_not_keepdims_random", + // "test_reduce_prod_keepdims_example", + // "test_reduce_prod_keepdims_random", + // "test_reduce_prod_negative_axes_keepdims_example", + // "test_reduce_prod_negative_axes_keepdims_random", + // "test_reduce_sum_default_axes_keepdims_example", + // "test_reduce_sum_default_axes_keepdims_random", + // "test_reduce_sum_do_not_keepdims_example", + // "test_reduce_sum_do_not_keepdims_random", + // "test_reduce_sum_empty_axes_input_noop_example", + // "test_reduce_sum_empty_axes_input_noop_random", + // "test_reduce_sum_keepdims_example", + // "test_reduce_sum_keepdims_random", + // "test_reduce_sum_negative_axes_keepdims_example", + // "test_reduce_sum_negative_axes_keepdims_random", + // "test_reduce_sum_square_default_axes_keepdims_example", + // "test_reduce_sum_square_default_axes_keepdims_random", + // "test_reduce_sum_square_do_not_keepdims_example", + // "test_reduce_sum_square_do_not_keepdims_random", + // "test_reduce_sum_square_keepdims_example", + // "test_reduce_sum_square_keepdims_random", + // "test_reduce_sum_square_negative_axes_keepdims_example", + // "test_reduce_sum_square_negative_axes_keepdims_random", + // // "test_reflect_pad", + "test_relu", + // "test_reshape_allowzero_reordered", + "test_reshape_extended_dims", + "test_reshape_negative_dim", + "test_reshape_negative_extended_dims", + "test_reshape_one_dim", + "test_reshape_reduced_dims", + "test_reshape_reordered_all_dims", + "test_reshape_reordered_dims", + "test_reshape_reordered_last_dims", + "test_reshape_zero_and_negative_dim", + "test_reshape_zero_dim", + // "test_resize_downsample_linear", + // "test_resize_downsample_nearest", + // "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside", + // "test_resize_downsample_scales_cubic_align_corners", + // "test_resize_downsample_scales_cubic", + // "test_resize_downsample_scales_linear_align_corners", + // "test_resize_downsample_scales_linear", + // "test_resize_downsample_scales_nearest", + // "test_resize_downsample_sizes_cubic", + // "test_resize_downsample_sizes_linear_pytorch_half_pixel", + // "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn", + // "test_resize_downsample_sizes_nearest", + // "test_resize_nearest", + // "test_resize_tf_crop_and_resize", + // "test_resize_upsample_linear", + // "test_resize_upsample_nearest", + // "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside", + // "test_resize_upsample_scales_cubic_align_corners", + // "test_resize_upsample_scales_cubic_asymmetric", + // "test_resize_upsample_scales_cubic", + // "test_resize_upsample_scales_linear_align_corners", + // "test_resize_upsample_scales_linear", + // "test_resize_upsample_scales_nearest", + // "test_resize_upsample_sizes_cubic", + // "test_resize_upsample_sizes_nearest_ceil_half_pixel", + // "test_resize_upsample_sizes_nearest_floor_align_corners", + // "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", + // "test_resize_upsample_sizes_nearest", + // // "test_reversesequence_batch", + // // "test_reversesequence_time", + // // "test_rnn_seq_length", + // // "test_roialign_aligned_false", + // // "test_roialign_aligned_true", + // // "test_roialign", + // // "test_round", + // // "test_scan_sum", + // // "test_scan9_sum", + // // "test_scatter_elements_with_axis", + // // "test_scatter_elements_with_duplicate_indices", + // // "test_scatter_elements_with_negative_indices", + // // "test_scatter_elements_without_axis", + // // "test_scatter_with_axis", + // // "test_scatter_without_axis", + // // "test_scatternd_add", + // // "test_scatternd_multiply", + // // "test_scatternd", + // // "test_sce_mean_3d_expanded", + // // "test_sce_mean_3d_log_prob_expanded", + // // "test_sce_mean_3d_log_prob", + // // "test_sce_mean_3d", + // // "test_sce_mean_expanded", + // // "test_sce_mean_log_prob_expanded", + // // "test_sce_mean_log_prob", + // // "test_sce_mean_no_weight_ii_3d_expanded", + // // "test_sce_mean_no_weight_ii_3d_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_3d_log_prob", + // // "test_sce_mean_no_weight_ii_3d", + // // "test_sce_mean_no_weight_ii_4d_expanded", + // // "test_sce_mean_no_weight_ii_4d_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_4d_log_prob", + // // "test_sce_mean_no_weight_ii_4d", + // // "test_sce_mean_no_weight_ii_expanded", + // // "test_sce_mean_no_weight_ii_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_log_prob", + // // "test_sce_mean_no_weight_ii", + // // "test_sce_mean_weight_expanded", + // // "test_sce_mean_weight_ii_3d_expanded", + // // "test_sce_mean_weight_ii_3d_log_prob_expanded", + // // "test_sce_mean_weight_ii_3d_log_prob", + // // "test_sce_mean_weight_ii_3d", + // // "test_sce_mean_weight_ii_4d_expanded", + // // "test_sce_mean_weight_ii_4d_log_prob_expanded", + // // "test_sce_mean_weight_ii_4d_log_prob", + // // "test_sce_mean_weight_ii_4d", + // // "test_sce_mean_weight_ii_expanded", + // // "test_sce_mean_weight_ii_log_prob_expanded", + // // "test_sce_mean_weight_ii_log_prob", + // // "test_sce_mean_weight_ii", + // // "test_sce_mean_weight_log_prob_expanded", + // // "test_sce_mean_weight_log_prob", + // // "test_sce_mean_weight", + // // "test_sce_mean", + // // "test_sce_NCd1_mean_weight_negative_ii_expanded", + // // "test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded", + // // "test_sce_NCd1_mean_weight_negative_ii_log_prob", + // // "test_sce_NCd1_mean_weight_negative_ii", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_expanded", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob", + // // "test_sce_NCd1d2d3_sum_weight_high_ii", + // // "test_sce_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded", + // // "test_sce_NCd1d2d3d4d5_mean_weight_log_prob", + // // "test_sce_NCd1d2d3d4d5_mean_weight", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob", + // // "test_sce_NCd1d2d3d4d5_none_no_weight", + // // "test_sce_none_expanded", + // // "test_sce_none_log_prob_expanded", + // // "test_sce_none_log_prob", + // // "test_sce_none_weights_expanded", + // // "test_sce_none_weights_log_prob_expanded", + // // "test_sce_none_weights_log_prob", + // // "test_sce_none_weights", + // // "test_sce_none", + // // "test_sce_sum_expanded", + // // "test_sce_sum_log_prob_expanded", + // // "test_sce_sum_log_prob", + // // "test_sce_sum", + // "test_selu_default", + // "test_selu_example", + // "test_selu", + // // "test_sequence_insert_at_back", + // // "test_sequence_insert_at_front", + // // "test_sequence_map_add_1_sequence_1_tensor_expanded", + // // "test_sequence_map_add_1_sequence_1_tensor", + // // "test_sequence_map_add_2_sequences_expanded", + // // "test_sequence_map_add_2_sequences", + // // "test_sequence_map_extract_shapes_expanded", + // // "test_sequence_map_extract_shapes", + // // "test_sequence_map_identity_1_sequence_1_tensor_expanded", + // // "test_sequence_map_identity_1_sequence_1_tensor", + // // "test_sequence_map_identity_1_sequence_expanded", + // // "test_sequence_map_identity_1_sequence", + // // "test_sequence_map_identity_2_sequences_expanded", + // // "test_sequence_map_identity_2_sequences", + // "test_shrink_hard", + // "test_shrink_soft", + "test_sigmoid_example", + "test_sigmoid", + // "test_sign", + // "test_simple_rnn_batchwise", + // "test_simple_rnn_defaults", + // "test_simple_rnn_with_initial_bias", + "test_sin_example", + "test_sin", + "test_sinh_example", + "test_sinh", + // // "test_size_example", + // // "test_size", + // "test_slice_default_axes", + // "test_slice_default_steps", + // "test_slice_end_out_of_bounds", + // "test_slice_neg_steps", + // "test_slice_neg", + // "test_slice_negative_axes", + // "test_slice_start_out_of_bounds", + // "test_slice", + // "test_softmax_axis_0_expanded", + // "test_softmax_axis_0", + // "test_softmax_axis_1_expanded", + // "test_softmax_axis_1", + // "test_softmax_axis_2_expanded", + // "test_softmax_axis_2", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight", + // "test_softmax_cross_entropy_mean_3d_expanded", + // "test_softmax_cross_entropy_mean_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_3d_log_prob", + // "test_softmax_cross_entropy_mean_3d", + // "test_softmax_cross_entropy_mean_expanded", + // "test_softmax_cross_entropy_mean_log_prob_expanded", + // "test_softmax_cross_entropy_mean_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index", + // "test_softmax_cross_entropy_mean_weight_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d", + // "test_softmax_cross_entropy_mean_weight_ignore_index_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index", + // "test_softmax_cross_entropy_mean_weight_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_log_prob", + // "test_softmax_cross_entropy_mean_weight", + // "test_softmax_cross_entropy_mean", + // "test_softmax_cross_entropy_none_expanded", + // "test_softmax_cross_entropy_none_log_prob_expanded", + // "test_softmax_cross_entropy_none_log_prob", + // "test_softmax_cross_entropy_none_weights_expanded", + // "test_softmax_cross_entropy_none_weights_log_prob_expanded", + // "test_softmax_cross_entropy_none_weights_log_prob", + // "test_softmax_cross_entropy_none_weights", + // "test_softmax_cross_entropy_none", + // "test_softmax_cross_entropy_sum_expanded", + // "test_softmax_cross_entropy_sum_log_prob_expanded", + // "test_softmax_cross_entropy_sum_log_prob", + // "test_softmax_cross_entropy_sum", + // "test_softmax_default_axis_expanded", + // "test_softmax_default_axis", + // "test_softmax_example_expanded", + // "test_softmax_example", + // "test_softmax_large_number_expanded", + // "test_softmax_large_number", + // "test_softmax_negative_axis_expanded", + // "test_softmax_negative_axis", + // // "test_softplus_example", + // // "test_softplus", + // // "test_softsign_example", + // // "test_softsign", + // "test_spacetodepth_example", + // "test_spacetodepth", + // // "test_split_equal_parts_1d", + // // "test_split_equal_parts_2d", + // // "test_split_equal_parts_default_axis", + // // "test_split_variable_parts_1d", + // // "test_split_variable_parts_2d", + // // "test_split_variable_parts_default_axis", + // // "test_split_zero_size_splits", + "test_sqrt_example", + "test_sqrt", + // "test_squeeze_negative_axes", + // "test_squeeze", + // // "test_stft_with_window", + // // "test_stft", + // // "test_strnormalizer_export_monday_casesensintive_lower", + // // "test_strnormalizer_export_monday_casesensintive_nochangecase", + // // "test_strnormalizer_export_monday_casesensintive_upper", + // // "test_strnormalizer_export_monday_empty_output", + // // "test_strnormalizer_export_monday_insensintive_upper_twodim", + // // "test_strnormalizer_nostopwords_nochangecase", + "test_sub_bcast", + "test_sub_example", + // "test_sub_uint8", + "test_sub", + // "test_sum_example", + // "test_sum_one_input", + // "test_sum_two_inputs", + "test_tan_example", + "test_tan", + "test_tanh_example", + "test_tanh", + // // "test_tfidfvectorizer_tf_batch_onlybigrams_skip0", + // // "test_tfidfvectorizer_tf_batch_onlybigrams_skip5", + // // "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5", + // // "test_tfidfvectorizer_tf_only_bigrams_skip0", + // // "test_tfidfvectorizer_tf_onlybigrams_levelempty", + // // "test_tfidfvectorizer_tf_onlybigrams_skip5", + // // "test_tfidfvectorizer_tf_uniandbigrams_skip5", + // "test_thresholdedrelu_default", + // "test_thresholdedrelu_example", + // "test_thresholdedrelu", + // // "test_tile_precomputed", + // // "test_tile", + // // "test_top_k_negative_axis", + // // "test_top_k_smallest", + // // "test_top_k", + // // "test_training_dropout_default_mask", + // // "test_training_dropout_default", + // // "test_training_dropout_mask", + // // "test_training_dropout_zero_ratio_mask", + // // "test_training_dropout_zero_ratio", + // // "test_training_dropout", + "test_transpose_all_permutations_0", + "test_transpose_all_permutations_1", + "test_transpose_all_permutations_2", + "test_transpose_all_permutations_3", + "test_transpose_all_permutations_4", + "test_transpose_all_permutations_5", + "test_transpose_default" + // "test_tril_neg", + // "test_tril_one_row_neg", + // "test_tril_out_neg", + // "test_tril_out_pos", + // "test_tril_pos", + // "test_tril_square_neg", + // "test_tril_square", + // "test_tril_zero", + // "test_tril", + // "test_triu_neg", + // "test_triu_one_row", + // "test_triu_out_neg_out", + // "test_triu_out_pos", + // "test_triu_pos", + // "test_triu_square_neg", + // "test_triu_square", + // "test_triu_zero", + // "test_triu", + // // "test_unique_not_sorted_without_axis", + // // "test_unique_sorted_with_axis_3d", + // // "test_unique_sorted_with_axis", + // // "test_unique_sorted_with_negative_axis", + // // "test_unique_sorted_without_axis", + // "test_unsqueeze_axis_0", + // "test_unsqueeze_axis_1", + // "test_unsqueeze_axis_2", + // "test_unsqueeze_axis_3", + // "test_unsqueeze_negative_axes", + // "test_unsqueeze_three_axes", + // "test_unsqueeze_two_axes", + // "test_unsqueeze_unsorted_axes", + // "test_unsqueeze", + // "test_upsample_nearest", + // "test_where_example", + // "test_where_long_example", + // "test_xor_bcast3v1d", + // "test_xor_bcast3v2d", + // "test_xor_bcast4v2d", + // "test_xor_bcast4v3d", + // "test_xor_bcast4v4d", + // "test_xor2d", + // "test_xor3d", + // "test_xor4d" + ], + "ops": [] + }, "wasm": { "onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"], "node": [ diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 2610cbe1d82e6..b4099aa1269aa 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -54,6 +54,9 @@ if (options.globalEnvFlags) { if (flags.wasm?.initTimeout !== undefined) { ort.env.wasm.initTimeout = flags.wasm.initTimeout; } + if (flags.webgpu?.profilingMode !== undefined) { + ort.env.webgpu.profilingMode = flags.webgpu.profilingMode; + } } // Set logging configuration diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 3f1fa0e0f8c81..26ebcbbd6e212 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -3,7 +3,7 @@ import {expect} from 'chai'; import {readFile} from 'fs'; -import {onnx as onnxProto} from 'onnx-proto'; +import {onnx} from 'onnx-proto'; import * as ort from 'onnxruntime-common'; import {extname} from 'path'; import {inspect, promisify} from 'util'; @@ -14,6 +14,7 @@ import {createWebGLContext} from '../lib/onnxjs/backends/webgl/webgl-context-fac import {Logger, Profiler} from '../lib/onnxjs/instrument'; import {Operator} from '../lib/onnxjs/operators'; import {Tensor} from '../lib/onnxjs/tensor'; +import {ProtoUtil} from '../lib/onnxjs/util'; import {base64toBuffer, createMockGraph} from './test-shared'; import {Test} from './test-types'; @@ -25,6 +26,8 @@ const WEBGL_THRESHOLD_ABSOLUTE_ERROR = 1.0e-3; const WEBGL_THRESHOLD_RELATIVE_ERROR = 1.00001; const WEBGL_HALF_FLOAT_THRESHOLD_ABSOLUTE_ERROR = 0.1; const WEBGL_HALF_FLOAT_THRESHOLD_RELATIVE_ERROR = 1.02; +const WEBGPU_THRESHOLD_ABSOLUTE_ERROR = 1.0e-3; +const WEBGPU_THRESHOLD_RELATIVE_ERROR = 1.00001; const WASM_THRESHOLD_ABSOLUTE_ERROR = 1.0e-4; const WASM_THRESHOLD_RELATIVE_ERROR = 1.000001; const ONNXRUNTIME_THRESHOLD_ABSOLUTE_ERROR = 1.0e-3; @@ -54,12 +57,40 @@ async function loadFile(uri: string): Promise { } } -async function loadTensorProto(uriOrData: string|Uint8Array): Promise { +async function loadTensorProto(uriOrData: string|Uint8Array, allowInt64 = false): Promise { const buf = (typeof uriOrData === 'string') ? await loadFile(uriOrData) : uriOrData; - const tensorProto = onnxProto.TensorProto.decode(buf); - const tensor = Tensor.fromProto(tensorProto); + const tensorProto = onnx.TensorProto.decode(buf); + + let tensor: ort.Tensor; + + // by default, we don't allow (u)int64. this is for backward compatibility. + if (allowInt64 && tensorProto && tensorProto.dataType && + ((tensorProto.dataType === onnx.TensorProto.DataType.INT64 || + tensorProto.dataType === onnx.TensorProto.DataType.UINT64))) { + const signed = tensorProto.dataType === onnx.TensorProto.DataType.INT64; + const dataConstructor = signed ? BigInt64Array : BigUint64Array; + const length = tensorProto.rawData.byteLength / 8; + const data = new dataConstructor(length); + + if (tensorProto.rawData && typeof tensorProto.rawData.byteLength === 'number' && + tensorProto.rawData.byteLength > 0) { + const dataSource = + new DataView(tensorProto.rawData.buffer, tensorProto.rawData.byteOffset, tensorProto.rawData.byteLength); + for (let i = 0; i < length; i++) { + data[i] = signed ? dataSource.getBigInt64(i * 8, true) : dataSource.getBigUint64(i * 8, true); + } + } else { + for (let i = 0; i < length; i++) { + data[i] = BigInt((signed ? tensorProto.int64Data : tensorProto.uint64Data)![i].toString()); + } + } + tensor = new ort.Tensor(signed ? 'int64' : 'uint64', data, ProtoUtil.tensorDimsFromProto(tensorProto.dims)); + } else { + const internalTensor = Tensor.fromProto(tensorProto); + tensor = fromInternalTensor(internalTensor); + } // add property 'name' to the tensor object. - const namedTensor = fromInternalTensor(tensor) as unknown as Test.NamedTensor; + const namedTensor = tensor as unknown as Test.NamedTensor; namedTensor.name = tensorProto.name; return namedTensor; } @@ -70,11 +101,13 @@ async function loadMlProto(_uriOrData: string|Uint8Array): Promise = {}; testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); const start = now(); + Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); const outputs = await context.session.run(feeds); const end = now(); + Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); if (context.perfData.count === 0) { context.perfData.firstRun = end - start; } else { @@ -519,7 +557,7 @@ export class OpTestContext { inferenceHandler: InferenceHandler; constructor(protected opTest: Test.OperatorTest) { - this.backendHint = opTest.backend === 'webgl' ? 'webgl' : 'cpu'; + this.backendHint = opTest.backend ?? 'cpu'; } createOperator(): Operator { return initializeOperator( @@ -558,9 +596,14 @@ async function runOpTestcase( testcase.inputs.map(input => createTensor(input.dims, input.type as Tensor.DataType, input.data)); const results = operator.impl(inferenceHandler, inputTensors, operator.context); - // if ('then' in results) { - // results = await results; - // } + + // try async data read. + for (const result of results) { + try { + await result.getData(); + } catch { + } + } results.forEach((output, i) => { Logger.verbose('TestOpRunner', ` Result'${i}': ${output.type}[${output.dims.join(',')}]`); diff --git a/js/web/tsconfig.json b/js/web/tsconfig.json index b9dc974997b28..865c393b5b2b6 100644 --- a/js/web/tsconfig.json +++ b/js/web/tsconfig.json @@ -4,6 +4,7 @@ "module": "CommonJS", "downlevelIteration": true, "declarationDir": "./types", + "typeRoots": ["./node_modules/@webgpu/types", "./node_modules/@types"] }, "include": ["lib", "script", "test"], "exclude": ["lib/wasm/proxy-worker"] diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index d69c6e3b94060..1c842ddced25a 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -57,6 +57,7 @@ function defaultTerserPluginOptions(target) { const DEFAULT_BUILD_DEFS = { DISABLE_WEBGL: false, + DISABLE_WEBGPU: false, DISABLE_WASM: false, DISABLE_WASM_PROXY: false, DISABLE_WASM_THREAD: false, diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc index 8eca964a40b69..42cc24a7964d7 100644 --- a/onnxruntime/core/providers/get_execution_providers.cc +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -98,6 +98,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] = true, #else false, +#endif + }, + { + kJsExecutionProvider, +#ifdef USE_JS + true, +#else + false, #endif }, { diff --git a/onnxruntime/core/providers/js/allocator.cc b/onnxruntime/core/providers/js/allocator.cc new file mode 100644 index 0000000000000..c1d0aa9abbf6b --- /dev/null +++ b/onnxruntime/core/providers/js/allocator.cc @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/framework/session_state.h" +#include "core/providers/js/allocator.h" + +namespace onnxruntime { +namespace js { + +void* JsCustomAllocator::Alloc(size_t size) { + void* p = EM_ASM_PTR({ return Module.jsepAlloc($0); }, size); + stats_.num_allocs++; + stats_.bytes_in_use += size; + return p; +} + +void JsCustomAllocator::Free(void* p) { + size_t size = (size_t)(void*)EM_ASM_PTR({ return Module.jsepFree($0); }, p); + stats_.bytes_in_use -= size; +} + +void JsCustomAllocator::GetStats(AllocatorStats* stats) { + *stats = stats_; +} + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/allocator.h b/onnxruntime/core/providers/js/allocator.h new file mode 100644 index 0000000000000..6aa8313c01f38 --- /dev/null +++ b/onnxruntime/core/providers/js/allocator.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" +#include "core/framework/ortdevice.h" + +namespace onnxruntime { +namespace js { + +class JsCPUAllocator : public CPUAllocator { + public: + JsCPUAllocator() + : CPUAllocator( + OrtMemoryInfo("JsCPUAllocator", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0), + 0, OrtMemTypeCPU)){}; +}; + +class JsCustomAllocator : public IAllocator { + public: + JsCustomAllocator() + : IAllocator( + OrtMemoryInfo("JsCustomAllocator", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), + 0, OrtMemTypeDefault)) { + } + + virtual void* Alloc(size_t size) override; + virtual void Free(void* p) override; + void GetStats(AllocatorStats* stats) override; + + private: + AllocatorStats stats_; +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc new file mode 100644 index 0000000000000..c62362d90867f --- /dev/null +++ b/onnxruntime/core/providers/js/data_transfer.cc @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/js/data_transfer.h" + +EM_ASYNC_JS(void, jsepDownload, (const void* src_data, void* dst_data, size_t bytes), { + await Module.jsepCopyAsync(src_data, dst_data, bytes); +}); + +namespace onnxruntime { +namespace js { + +bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || + (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::GPU) || + (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); +} + +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + size_t bytes = src.SizeInBytes(); + const void* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::GPU) { + // copy from GPU to GPU + EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes); + } else { + // copy from CPU to GPU + EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes); + } + } else /* if (src_device.Type() == OrtDevice::GPU) */ { + // copy from GPU to CPU + jsepDownload(src_data, dst_data, bytes); + } + + return Status::OK(); +} + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/data_transfer.h b/onnxruntime/core/providers/js/data_transfer.h new file mode 100644 index 0000000000000..3dfb19cfde5ac --- /dev/null +++ b/onnxruntime/core/providers/js/data_transfer.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/data_transfer.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace js { + +class DataTransfer : public IDataTransfer { + public: + DataTransfer(){}; + ~DataTransfer(){}; + + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; + + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc new file mode 100644 index 0000000000000..d1308da7f888c --- /dev/null +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -0,0 +1,339 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include "js_execution_provider.h" + +#include "core/graph/function_utils.h" +#include "core/framework/compute_capability.h" +#include "core/framework/data_transfer_manager.h" +#include "core/framework/kernel_registry.h" +#include "core/providers/shared/node_unit/node_unit.h" +#include "allocator.h" +#include "data_transfer.h" + +namespace onnxruntime { + +namespace js { +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} + +class Memcpy final : public OpKernel { + public: + Memcpy(const OpKernelInfo& info) : OpKernel(info) {} + + Status Compute(OpKernelContext* ctx) const override { + const auto* X = ctx->Input(0); + Tensor* Y = ctx->Output(0, X->Shape()); + return Info().GetDataTransferManager().CopyTensor(*X, *Y); + } +}; + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, MemcpyToHost); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPU, 0) + .ExecQueueId(0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .ExecQueueId(1) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, Start, End, Op)> + +#define KERNEL_CREATE_INFO(Start, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, Start, Op)> + +#define KERNEL_CREATE_INFO_TYPED(Start, type, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, Start, type, Op)> + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Abs); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Abs); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Neg); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Neg); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Floor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Floor); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Ceil); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Ceil); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Reciprocal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Reciprocal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sqrt); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sqrt); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Exp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Exp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 12, Erf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, Sin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, Cos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, Tan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, Asin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, Acos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, Atan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, Sinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, Cosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, Asinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, Acosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, Atanh); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 10, Clip); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, Clip); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, Clip); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Clip); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, Elu); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 12, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Add); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 12, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Sub); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 12, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Mul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 12, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Div); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 11, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Pow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, Pow); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 14, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 15, Shape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 5, 12, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Reshape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Reshape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, float, MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, AveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalAveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalMaxPool); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, float, AveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, AveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, float, MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, float, MaxPool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, float, MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, MaxPool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); + +std::unique_ptr RegisterKernels() { + auto kernel_registry = std::make_unique(); + + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, // default entry to avoid the list becoming empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // element-wise operators + // unary - math + KERNEL_CREATE_INFO_VERSIONED(6, 12, Abs), + KERNEL_CREATE_INFO(13, Abs), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Neg), + KERNEL_CREATE_INFO(13, Neg), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Floor), + KERNEL_CREATE_INFO(13, Floor), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Ceil), + KERNEL_CREATE_INFO(13, Ceil), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Reciprocal), + KERNEL_CREATE_INFO(13, Reciprocal), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sqrt), + KERNEL_CREATE_INFO(13, Sqrt), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Exp), + KERNEL_CREATE_INFO(13, Exp), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Erf), + KERNEL_CREATE_INFO(13, Erf), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), + KERNEL_CREATE_INFO(13, Sigmoid), + + KERNEL_CREATE_INFO(7, Sin), + KERNEL_CREATE_INFO(7, Cos), + KERNEL_CREATE_INFO(7, Tan), + KERNEL_CREATE_INFO(7, Asin), + KERNEL_CREATE_INFO(7, Acos), + KERNEL_CREATE_INFO(7, Atan), + KERNEL_CREATE_INFO(9, Sinh), + KERNEL_CREATE_INFO(9, Cosh), + KERNEL_CREATE_INFO(9, Asinh), + KERNEL_CREATE_INFO(9, Acosh), + KERNEL_CREATE_INFO(9, Atanh), + + // activations + KERNEL_CREATE_INFO_VERSIONED(6, 10, Clip), + KERNEL_CREATE_INFO_VERSIONED(11, 11, Clip), + KERNEL_CREATE_INFO_VERSIONED(12, 12, Clip), + KERNEL_CREATE_INFO(13, Clip), + KERNEL_CREATE_INFO(6, Elu), + + // binary - math + KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Add), + KERNEL_CREATE_INFO(14, Add), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub), + KERNEL_CREATE_INFO(14, Sub), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul), + KERNEL_CREATE_INFO(14, Mul), + KERNEL_CREATE_INFO_VERSIONED(7, 12, Div), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Div), + KERNEL_CREATE_INFO(14, Div), + KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow), + KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow), + KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow), + KERNEL_CREATE_INFO(15, Pow), + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; + + for (auto& function_table_entry : function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_THROW_IF_ERROR(kernel_registry->Register(std::move(info))); + } + } + + return kernel_registry; +} + +} // namespace js + +using namespace js; + +JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) + : IExecutionProvider{kJsExecutionProvider, true} { +} + +// implement RegisterAllocator to test/validate sharing the CPU EP's allocator +void JsExecutionProvider::RegisterAllocator(AllocatorManager& allocator_manager) { + OrtDevice cpu_device{OrtDevice::CPU, OrtDevice::MemType::DEFAULT, DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; + auto cpu_alloc = GetAllocator(OrtMemTypeCPU); + if (!cpu_alloc) { + cpu_alloc = allocator_manager.GetAllocator(OrtMemTypeCPU, cpu_device); + if (!cpu_alloc) { + AllocatorCreationInfo cpuAllocatorCreationInfo([&](int) { + return std::make_unique(); + }); + cpu_alloc = CreateAllocator(cpuAllocatorCreationInfo); + allocator_manager.InsertAllocator(cpu_alloc); + } + InsertAllocator(cpu_alloc); + } + + OrtDevice custom_device{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0}; + auto custom_alloc = GetAllocator(OrtMemTypeDefault); + if (!custom_alloc) { + custom_alloc = allocator_manager.GetAllocator(OrtMemTypeDefault, custom_device); + if (!custom_alloc) { + AllocatorCreationInfo customAllocatorCreationInfo([&](int) { + return std::make_unique(); + }, + 0, false); + custom_alloc = CreateAllocator(customAllocatorCreationInfo); + allocator_manager.InsertAllocator(custom_alloc); + } + InsertAllocator(custom_alloc); + } +} + +std::vector> JsExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const { + return IExecutionProvider::GetCapability(graph, kernel_lookup); +} + +std::shared_ptr JsExecutionProvider::GetKernelRegistry() const { + static std::shared_ptr registry = js::RegisterKernels(); + return registry; +} + +std::unique_ptr JsExecutionProvider::GetDataTransfer() const { + return std::make_unique(); +} + +JsExecutionProvider::~JsExecutionProvider() { +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h new file mode 100644 index 0000000000000..ce8ec53eca1f6 --- /dev/null +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocatormgr.h" +#include "core/framework/execution_provider.h" +#include "core/graph/constants.h" +#include "core/providers/providers.h" + +struct pthreadpool; +namespace onnxruntime { + +namespace js { + +// forward declaration for this EP's namespace. +template +KernelCreateInfo BuildKernelCreateInfo(); + +} // namespace js + +// placeholder for future use. no options currently +struct JsExecutionProviderInfo { + JsExecutionProviderInfo() = default; + + JsExecutionProviderInfo(const ProviderOptions& po) { + } +}; + +class JsExecutionProvider : public IExecutionProvider { + public: + JsExecutionProvider(const JsExecutionProviderInfo& info); + ~JsExecutionProvider() override; + + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_lookup*/) const override; + + std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; + + void RegisterAllocator(AllocatorManager& /*allocator_manager*/) override; + + DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; } + + FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } + + bool ConcurrentRunSupported() const override { return false; } +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_export.cc b/onnxruntime/core/providers/js/js_export.cc new file mode 100644 index 0000000000000..ca0527a2ef89b --- /dev/null +++ b/onnxruntime/core/providers/js/js_export.cc @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "js_export.h" + +#include "core/framework/op_kernel.h" + +const void* JsepOutput(void* context, int index, void* data) { + uint32_t* data_offset = reinterpret_cast(data); + uint32_t dim = *data_offset++; + size_t dim_size = static_cast(dim); + std::vector dims; + dims.reserve(dim_size); + dims.resize(dim_size); + for (size_t i = 0; i < dim_size; i++) { + dims[i] = static_cast(*data_offset++); + } + + LOGF_DEFAULT(VERBOSE, "JsepOutput(%d, %s)", index, onnxruntime::TensorShape(dims).ToString().c_str()); + + auto output = reinterpret_cast(context)->Output(index, onnxruntime::TensorShape(dims)); + auto r = output->DataRaw(); + + LOGF_DEFAULT(VERBOSE, "JsepOutput -- data=%zu", (size_t)(r)); + return r; +} diff --git a/onnxruntime/core/providers/js/js_export.h b/onnxruntime/core/providers/js/js_export.h new file mode 100644 index 0000000000000..bb1eb356cc9d5 --- /dev/null +++ b/onnxruntime/core/providers/js/js_export.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include + +// TODO: Move to api.h + +extern "C" { +const void* EMSCRIPTEN_KEEPALIVE JsepOutput(void* context, int index, void* data); +}; diff --git a/onnxruntime/core/providers/js/js_kernel.cc b/onnxruntime/core/providers/js/js_kernel.cc new file mode 100644 index 0000000000000..34f592814c1e4 --- /dev/null +++ b/onnxruntime/core/providers/js/js_kernel.cc @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "js_kernel.h" + +namespace onnxruntime { +namespace js { +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h new file mode 100644 index 0000000000000..15fce3727b8b3 --- /dev/null +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#ifndef NDEBUG +#include +#endif + +#include "core/framework/op_kernel.h" +#include "core/providers/js/js_execution_provider.h" + +struct pthreadpool; + +namespace onnxruntime { +namespace js { + +// This macro is defined to bypass the code format from clang-format, which will overwrite "=>" into "= >" +// We can use it to write JS inline code with arrow functions. + +// clang-format off +#define JS_ARROW => +// clang-format on + +#define JSEP_INIT_KERNEL(optype) EM_ASM({ Module.jsepCreateKernel(#optype, $0, undefined); }, this) +#define JSEP_INIT_KERNEL_ATTRIBUTE(optype, attr, ...) EM_ASM({ Module.jsepCreateKernel(#optype, $0, attr); }, this, __VA_ARGS__) + +#define JSEP_KERNEL_IMPL(classname, optype) \ + class classname : public JsKernel { \ + public: \ + classname(const OpKernelInfo& info) : JsKernel(info) { \ + JSEP_INIT_KERNEL(optype); \ + } \ + }; + +#define JSEP_KERNEL_TYPED_IMPL(classname, optype) \ + template \ + class classname : public JsKernel { \ + public: \ + classname(const OpKernelInfo& info) : JsKernel(info) { \ + JSEP_INIT_KERNEL(optype); \ + } \ + }; + +#define JSEP_CLASS_IMPL_ATTRIBUTE(classname, optype, attr_pre, attr, ...) \ + class classname : public JsKernel { \ + public: \ + classname(const OpKernelInfo& info) : JsKernel(info) { \ + attr_pre \ + JSEP_INIT_KERNEL_ATTRIBUTE(optype, attr, __VA_ARGS__); \ + } \ + }; + +#define JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(classname, optype, attr_name, default_value, ...) \ + JSEP_CLASS_IMPL_ATTRIBUTE(classname, optype, , ({#attr_name : $1}), static_cast(info.GetAttrOrDefault(#attr_name, default_value))) + +#define JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(classname, optype, attr_name_1, default_value_1, attr_name_2, default_value_2, ...) \ + JSEP_CLASS_IMPL_ATTRIBUTE(classname, optype, , ({#attr_name_1 : $1, #attr_name_2 : $2}), \ + static_cast(info.GetAttrOrDefault(#attr_name_1, default_value_1)), \ + static_cast(info.GetAttrOrDefault(#attr_name_2, default_value_2))) + +#define JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT(classname, optype, attr_name, ...) \ + JSEP_CLASS_IMPL_ATTRIBUTE(classname, optype, \ + float value; \ + ORT_ENFORCE(info.GetAttr(#attr_name, &value));, \ + , ({#attr_name : $1}), static_cast(value)) + +// TODO: +// class JsMultiProgramKernel : public OpKernel { /* TBD */ }; + +class JsKernel : public OpKernel { + public: + explicit JsKernel(const OpKernelInfo& info) + : OpKernel(info) {} + ~JsKernel() override { + EM_ASM({ Module.jsepReleaseKernel($0); }, this); + } + + void* SerializeKernelContext(OpKernelContext* context, AllocatorPtr alloc) const { + // + // temp_data_format (every item is (u)int32_t): + // context_prt | input_count | [input_data_0] ... [input_data_N-1] + // + // input_data_format: + // type | data_ptr | dim_size | dim[0] ... dim[N-1] + // + size_t temp_data_size = sizeof(size_t) * 2; + for (int i = 0; i < context->InputCount(); i++) { + temp_data_size += sizeof(size_t) * (3 + context->Input(i)->Shape().NumDimensions()); + } + uint32_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); + if (p_serialized_kernel_context == nullptr) { + return nullptr; + } + + p_serialized_kernel_context[0] = reinterpret_cast(context); + p_serialized_kernel_context[1] = static_cast(context->InputCount()); + size_t index = 2; + for (int i = 0; i < context->InputCount(); i++) { + p_serialized_kernel_context[index++] = static_cast(context->Input(i)->GetElementType()); + p_serialized_kernel_context[index++] = reinterpret_cast(context->Input(i)->DataRaw()); + p_serialized_kernel_context[index++] = static_cast(context->Input(i)->Shape().NumDimensions()); + for (size_t d = 0; d < context->Input(i)->Shape().NumDimensions(); d++) { + p_serialized_kernel_context[index++] = static_cast(context->Input(i)->Shape()[d]); + } + } + +#ifndef NDEBUG + std::ostringstream os; + os << "temp data size: " << temp_data_size << ". Data:"; + size_t temp_data_count = temp_data_size >> 2; + for (size_t i = 0; i < temp_data_count; i++) { + os << " " << p_serialized_kernel_context[i]; + } + LOGS_DEFAULT(VERBOSE) << os.str(); +#endif + + return p_serialized_kernel_context; + } + + virtual Status ComputeInternal(OpKernelContext* context) const { + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceCPUAllocator(&alloc)); + + auto p_serialized_kernel_context = SerializeKernelContext(context, alloc); + + int status = EM_ASM_INT({ return Module.jsepRun($0, $1); }, this, p_serialized_kernel_context); + + LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" + << (size_t)(context->Output(0)->DataRaw()) << "."; + + alloc->Free(p_serialized_kernel_context); + + if (status == 0) { + return Status::OK(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to run JSEP kernel"); + } + } + + Status Compute(OpKernelContext* context) const override { + return ComputeInternal(context); + } +}; +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory.cc b/onnxruntime/core/providers/js/js_provider_factory.cc new file mode 100644 index 0000000000000..5b7329a87cf6a --- /dev/null +++ b/onnxruntime/core/providers/js/js_provider_factory.cc @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/error_code_helper.h" +#include "core/providers/js/js_execution_provider.h" +#include "core/providers/js/js_provider_factory_creator.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +struct JsProviderFactory : IExecutionProviderFactory { + JsProviderFactory(const ProviderOptions& provider_options) + : info_{provider_options} { + } + + std::unique_ptr CreateProvider() override { + return std::make_unique(info_); + } + + private: + JsExecutionProviderInfo info_; +}; + +std::shared_ptr JsProviderFactoryCreator::Create( + const ProviderOptions& provider_options) { + return std::make_shared(provider_options); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory_creator.h b/onnxruntime/core/providers/js/js_provider_factory_creator.h new file mode 100644 index 0000000000000..dbabe255c2d7b --- /dev/null +++ b/onnxruntime/core/providers/js/js_provider_factory_creator.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/framework/provider_options.h" +#include "core/providers/providers.h" + +namespace onnxruntime { + +struct JsProviderFactoryCreator { + static std::shared_ptr Create(const ProviderOptions& provider_options); +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/binary.cc b/onnxruntime/core/providers/js/operators/binary.cc new file mode 100644 index 0000000000000..ffad51f7e5af0 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/binary.cc @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + KERNEL_CLASS); + +#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + KERNEL_CLASS); + +JSEP_KERNEL_IMPL(Add, Add) +REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 7, 12, float, Add); +REG_ELEMENTWISE_VERSIONED_KERNEL(Add, 13, 13, float, Add); +REG_ELEMENTWISE_KERNEL(Add, 14, float, Add); + +JSEP_KERNEL_IMPL(Sub, Sub) +REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 7, 12, float, Sub); +REG_ELEMENTWISE_VERSIONED_KERNEL(Sub, 13, 13, float, Sub); +REG_ELEMENTWISE_KERNEL(Sub, 14, float, Sub); + +JSEP_KERNEL_IMPL(Mul, Mul) +REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 7, 12, float, Mul); +REG_ELEMENTWISE_VERSIONED_KERNEL(Mul, 13, 13, float, Mul); +REG_ELEMENTWISE_KERNEL(Mul, 14, float, Mul); + +JSEP_KERNEL_IMPL(Div, Div) +REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 7, 12, float, Div); +REG_ELEMENTWISE_VERSIONED_KERNEL(Div, 13, 13, float, Div); +REG_ELEMENTWISE_KERNEL(Div, 14, float, Div); + +JSEP_KERNEL_IMPL(Pow, Pow) +REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 7, 11, float, Pow); +REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 12, 12, float, Pow); +REG_ELEMENTWISE_VERSIONED_KERNEL(Pow, 13, 14, float, Pow); +REG_ELEMENTWISE_KERNEL(Pow, 15, float, Pow); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv.cc b/onnxruntime/core/providers/js/operators/conv.cc new file mode 100644 index 0000000000000..c7c9f7f7c3f0e --- /dev/null +++ b/onnxruntime/core/providers/js/operators/conv.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" + +#include "conv.h" + +namespace onnxruntime { +namespace js { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Conv, \ + kMSInternalNHWCDomain, \ + 11, \ + T, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Conv); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Conv, \ + kOnnxDomain, \ + 11, \ + T, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Conv); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Conv, \ + kOnnxDomain, \ + 1, 10, \ + T, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Conv); + +REGISTER_KERNEL_TYPED(float) + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h new file mode 100644 index 0000000000000..22f7721276677 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" + +namespace onnxruntime { +namespace js { + +template +class Conv : public JsKernel { + public: + Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { + TensorShapeVector kernel_shape; + if (conv_attrs_.kernel_shape_specified) { + ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); + } + + int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); + + // currently only support Conv 1D/2D. TODO: support Conv3D and other + if (conv_attrs_.dilations.size() == 1 || + (conv_attrs_.kernel_shape_specified && kernel_shape.size() == 1) || + conv_attrs_.strides.size() == 1) { + JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ + "format" : $8 ? "NHWC" : "NCHW", + "auto_pad" : $1, + "dilations" : [$2], + "group" : $3, + "kernel_shape" : [$4], + "pads" : [ $5, $6 ], + "strides" : [$7], + "w_is_const" : () JS_ARROW(!!HEAP8[$9]) + }), + static_cast(conv_attrs_.auto_pad), + static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), + static_cast(conv_attrs_.group), + static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0), + static_cast(conv_attrs_.pads.size() > 0 ? conv_attrs_.pads[0] : 0), + static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), + static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), + static_cast(channels_last), + reinterpret_cast(&w_is_const_)); + } else { + JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ + "format" : $13 ? "NHWC" : "NCHW", + "auto_pad" : $1, + "dilations" : [ $2, $3 ], + "group" : $4, + "kernel_shape" : [ $5, $6 ], + "pads" : [ $7, $8, $9, $10 ], + "strides" : [ $11, $12 ], + "w_is_const" : () JS_ARROW(!!HEAP8[$14]) + }), + static_cast(conv_attrs_.auto_pad), + static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), + static_cast(conv_attrs_.dilations.size() > 1 ? conv_attrs_.dilations[1] : 0), + static_cast(conv_attrs_.group), + static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0), + static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0), + static_cast(conv_attrs_.pads.size() > 0 ? conv_attrs_.pads[0] : 0), + static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), + static_cast(conv_attrs_.pads.size() > 2 ? conv_attrs_.pads[2] : 0), + static_cast(conv_attrs_.pads.size() > 3 ? conv_attrs_.pads[3] : 0), + static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), + static_cast(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0), + static_cast(channels_last), + reinterpret_cast(&w_is_const_)); + } + } + + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /* prepacked_weights */) override { + is_packed = false; + + if (input_idx == 1) { + // Only handle the common case of conv2D + if (tensor.Shape().NumDimensions() != 4 || tensor.SizeInBytes() == 0) { + return Status::OK(); + } + + w_is_const_ = true; + } + + return Status::OK(); + } + + protected: + ConvAttributes conv_attrs_; + bool w_is_const_; + // Tensor w_transposed_; +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/gemm.cc b/onnxruntime/core/providers/js/operators/gemm.cc new file mode 100644 index 0000000000000..f579d62bdfb5f --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gemm.cc @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +#include "gemm.h" + +namespace onnxruntime { +namespace js { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Gemm, \ + kOnnxDomain, \ + 11, \ + T, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Gemm); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Gemm, \ + kOnnxDomain, \ + 9, 10, \ + T, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Gemm); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Gemm, \ + kOnnxDomain, \ + 7, 8, \ + T, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Gemm); + +REGISTER_KERNEL_TYPED(float) + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/gemm.h b/onnxruntime/core/providers/js/operators/gemm.h new file mode 100644 index 0000000000000..27c41788ccfbd --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gemm.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +template +class Gemm : public JsKernel { + public: + Gemm(const OpKernelInfo& info) : JsKernel(info) { + float alpha = info.GetAttrOrDefault("alpha", 1.0f); + float beta = info.GetAttrOrDefault("beta", 1.0f); + int64_t transA = info.GetAttrOrDefault("transA", 0); + int64_t transB = info.GetAttrOrDefault("transB", 0); + + // currently only support Conv2D. TODO: support other + JSEP_INIT_KERNEL_ATTRIBUTE(Gemm, ({ + "alpha" : $1, + "beta" : $2, + "transA" : $3, + "transB" : $4 + }), + static_cast(alpha), + static_cast(beta), + static_cast(transA), + static_cast(transB)); + } +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/matmul.cc b/onnxruntime/core/providers/js/operators/matmul.cc new file mode 100644 index 0000000000000..ddfbb454def07 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/matmul.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +JSEP_KERNEL_IMPL(MatMul, MatMul) + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 12, kJsExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_KERNEL_EX(MatMul, kOnnxDomain, 13, kJsExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/pool.cc b/onnxruntime/core/providers/js/operators/pool.cc new file mode 100644 index 0000000000000..03e6caef7e5b8 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/pool.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +#include "pool.h" + +namespace onnxruntime { +namespace js { + +#define POOLING_KERNEL(op_name, domain, is_channels_last, data_type, pool_type, since_version) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + data_type, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pool); + +#define POOLING_KERNEL_VERSIONED(op_name, domain, is_channels_last, data_type, pool_type, since_version, end_version) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + end_version, \ + data_type, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pool); + +#define POOLING_KERNEL_WITH_INDICES(op_name, domain, is_channels_last, data_type, pool_type, since_version) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + data_type, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); + +#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, domain, is_channels_last, data_type, pool_type, since_version, end_version) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + end_version, \ + data_type, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); + +POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, float, AveragePool, 7, 9) +POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, float, AveragePool, 10, 10) +POOLING_KERNEL(AveragePool, kOnnxDomain, false, float, AveragePool, 11) +POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, float, AveragePool, 11) +POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, float, AveragePool, 1) +POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, float, AveragePool, 1) + +POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, float, MaxPool<1>, 1, 7) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 8, 9) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 10, 10) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 11, 11) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, float, MaxPool<8>, 11, 11) +POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 12) +POOLING_KERNEL_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, float, MaxPool<8>, 12) +POOLING_KERNEL(GlobalMaxPool, kOnnxDomain, false, float, MaxPool<1>, 1) +POOLING_KERNEL(GlobalMaxPool, kMSInternalNHWCDomain, true, float, MaxPool<1>, 1) + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/pool.h b/onnxruntime/core/providers/js/operators/pool.h new file mode 100644 index 0000000000000..5dbe5d0b8881d --- /dev/null +++ b/onnxruntime/core/providers/js/operators/pool.h @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/providers/cpu/nn/pool_base.h" + +namespace onnxruntime { +namespace js { + +#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \ + "format" : $15 ? "NHWC" : "NCHW", \ + "auto_pad" : $1, \ + "ceil_mode" : $2, \ + "count_include_pad" : $3, \ + "storage_order" : $4, \ + "dilations" : [ $5, $6 ], \ + "kernel_shape" : [ $7, $8 ], \ + "pads" : [ $9, $10, $11, $12 ], \ + "strides" : [ $13, $14 ] \ +}) + +#define POOL_ATTRIBUTES_PARAM_LIST \ + static_cast(pool_attrs_.auto_pad), \ + static_cast(pool_attrs_.ceil_mode), \ + static_cast(pool_attrs_.count_include_pad), \ + static_cast(pool_attrs_.storage_order), \ + static_cast(pool_attrs_.dilations.size() > 0 ? pool_attrs_.dilations[0] : 0), \ + static_cast(pool_attrs_.dilations.size() > 1 ? pool_attrs_.dilations[1] : 0), \ + static_cast(pool_attrs_.kernel_shape.size() > 0 ? pool_attrs_.kernel_shape[0] : 0), \ + static_cast(pool_attrs_.kernel_shape.size() > 1 ? pool_attrs_.kernel_shape[1] : 0), \ + static_cast(pool_attrs_.pads.size() > 0 ? pool_attrs_.pads[0] : 0), \ + static_cast(pool_attrs_.pads.size() > 1 ? pool_attrs_.pads[1] : 0), \ + static_cast(pool_attrs_.pads.size() > 2 ? pool_attrs_.pads[2] : 0), \ + static_cast(pool_attrs_.pads.size() > 3 ? pool_attrs_.pads[3] : 0), \ + static_cast(pool_attrs_.strides.size() > 0 ? pool_attrs_.strides[0] : 0), \ + static_cast(pool_attrs_.strides.size() > 1 ? pool_attrs_.strides[1] : 0), \ + static_cast(is_channels_last) + +#define GLOBAL_POOL_ATTRIBUTES_JS_OBJ_MAPPING ({"format" : $1 ? "NHWC" : "NCHW"}) +#define GLOBAL_POOL_ATTRIBUTES_PARAM_LIST static_cast(is_channels_last) + +template +class Pool : public JsKernel, public PoolBase { + public: + Pool(const OpKernelInfo& info) : JsKernel(info), PoolBase(info) { + if (pool_attrs_.global_pooling) { + if constexpr (PoolType::type == onnxruntime::PoolType::kAveragePool) { + JSEP_INIT_KERNEL_ATTRIBUTE(GlobalAveragePool, GLOBAL_POOL_ATTRIBUTES_JS_OBJ_MAPPING, GLOBAL_POOL_ATTRIBUTES_PARAM_LIST); + } else if constexpr (PoolType::type == onnxruntime::PoolType::kMaxPool) { + JSEP_INIT_KERNEL_ATTRIBUTE(GlobalMaxPool, GLOBAL_POOL_ATTRIBUTES_JS_OBJ_MAPPING, GLOBAL_POOL_ATTRIBUTES_PARAM_LIST); + } else { + // TODO: GlobalLpPool + } + } else { + if constexpr (PoolType::type == onnxruntime::PoolType::kAveragePool) { + JSEP_INIT_KERNEL_ATTRIBUTE(AveragePool, POOL_ATTRIBUTES_JS_OBJ_MAPPING, POOL_ATTRIBUTES_PARAM_LIST); + } else if constexpr (PoolType::type == onnxruntime::PoolType::kMaxPool) { + JSEP_INIT_KERNEL_ATTRIBUTE(MaxPool, POOL_ATTRIBUTES_JS_OBJ_MAPPING, POOL_ATTRIBUTES_PARAM_LIST); + } else { + // TODO: LpPool + } + } + } +}; + +template +class Pool, is_channels_last> final : public Pool, is_channels_last> { + public: + Pool(const OpKernelInfo& info) : Pool, is_channels_last>(info) {} +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/reshape.cc b/onnxruntime/core/providers/js/operators/reshape.cc new file mode 100644 index 0000000000000..d8959c89f3fe7 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/reshape.cc @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "reshape.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + Reshape, + kOnnxDomain, + 14, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 13, 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 5, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/reshape.h b/onnxruntime/core/providers/js/operators/reshape.h new file mode 100644 index 0000000000000..97a294163c748 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/reshape.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/framework/data_transfer_manager.h" +#include "core/providers/cpu/tensor/reshape_helper.h" + +namespace onnxruntime { +namespace js { + +class Reshape final : public JsKernel { + public: + Reshape(const OpKernelInfo& info) : JsKernel(info), + allow_zero_(info.GetAttrOrDefault("allowzero", static_cast(0)) == 1) { + } + + Status Compute(OpKernelContext* context) const override { + // Copy the second input tensor into the shape vector + const Tensor* shapeTensor = context->Input(1); + if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + if (shapeTensor->Shape().NumDimensions() != 1) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "A shape tensor must be a vector tensor, got ", shapeTensor->Shape().NumDimensions(), " dimensions"); + auto data_span = shapeTensor->template DataAsSpan(); + TensorShapeVector shape(data_span.begin(), data_span.end()); + const Tensor* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + const TensorShape& X_shape = X->Shape(); + + ReshapeHelper helper(X_shape, shape, allow_zero_); + + Tensor* Y = context->Output(0, TensorShape(shape)); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + + return Status::OK(); + } + + private: + bool allow_zero_; +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/shape_op.cc b/onnxruntime/core/providers/js/operators/shape_op.cc new file mode 100644 index 0000000000000..ec0de3c04a11e --- /dev/null +++ b/onnxruntime/core/providers/js/operators/shape_op.cc @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "core/providers/cpu/tensor/shape_op.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 1, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + // properly force CPU/GPU synch inside the kernel + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 13, 14, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + // properly force CPU/GPU synch inside the kernel + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +ONNX_OPERATOR_KERNEL_EX( + Shape, + kOnnxDomain, + 15, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + // properly force CPU/GPU synch inside the kernel + .OutputMemoryType(OrtMemTypeCPU, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/transpose.cc b/onnxruntime/core/providers/js/operators/transpose.cc new file mode 100644 index 0000000000000..6803e6e7a2a76 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/transpose.cc @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "transpose.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 1, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Transpose); + +ONNX_OPERATOR_KERNEL_EX( + Transpose, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Transpose); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h new file mode 100644 index 0000000000000..f2214438c6fd1 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/transpose.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/common/gsl.h" +#include "core/providers/cpu/tensor/transpose.h" + +namespace onnxruntime { +namespace js { + +class Transpose final : public JsKernel, public TransposeBase { + public: + Transpose(const OpKernelInfo& info) : JsKernel(info), TransposeBase(info) { + std::vector perm; + if (perm_specified_) { + perm.resize(perm_.size()); + perm[0] = gsl::narrow_cast(perm_.size()); + for (size_t i = 0; i < perm_.size(); ++i) { + perm[i] = gsl::narrow_cast(perm_[i]); + } + } + // printf("Transpose: perm_specified_ = %d, perm.size() = %d, perm[0] = %d, perm[1] = %d, perm[2] = %d, perm[3] = %d\n", + // perm_specified_, static_cast(perm.size()), perm[0], perm[1], perm[2], perm[3]); + JSEP_INIT_KERNEL_ATTRIBUTE(Transpose, ({ + "perm" : $1 ? Array.from(HEAP32.subarray($2, $2 + $1)) : [] + }), + gsl::narrow_cast(perm_specified_ ? perm_.size() : 0), + reinterpret_cast(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2); + } +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc new file mode 100644 index 0000000000000..df8c9760c1067 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + KERNEL_CLASS); + +#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + KERNEL_CLASS); + +// math + +JSEP_KERNEL_IMPL(Abs, Abs) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, float, Abs) +JSEP_ELEMENTWISE_KERNEL(Abs, 13, float, Abs) + +JSEP_KERNEL_IMPL(Neg, Neg) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Neg, 6, 12, float, Neg) +JSEP_ELEMENTWISE_KERNEL(Neg, 13, float, Neg) + +JSEP_KERNEL_IMPL(Floor, Floor) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, float, Floor) +JSEP_ELEMENTWISE_KERNEL(Floor, 13, float, Floor) + +JSEP_KERNEL_IMPL(Ceil, Ceil) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, float, Ceil) +JSEP_ELEMENTWISE_KERNEL(Ceil, 13, float, Ceil) + +JSEP_KERNEL_IMPL(Reciprocal, Reciprocal) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, float, Reciprocal) +JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, float, Reciprocal) + +JSEP_KERNEL_IMPL(Sqrt, Sqrt) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, float, Sqrt) +JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, float, Sqrt) + +JSEP_KERNEL_IMPL(Exp, Exp) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, float, Exp) +JSEP_ELEMENTWISE_KERNEL(Exp, 13, float, Exp) + +JSEP_KERNEL_IMPL(Erf, Erf) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, float, Erf) +JSEP_ELEMENTWISE_KERNEL(Erf, 13, float, Erf) + +JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, float, Sigmoid) +JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, float, Sigmoid) + +JSEP_KERNEL_IMPL(Sin, Sin) +JSEP_ELEMENTWISE_KERNEL(Sin, 7, float, Sin) + +JSEP_KERNEL_IMPL(Cos, Cos) +JSEP_ELEMENTWISE_KERNEL(Cos, 7, float, Cos) + +JSEP_KERNEL_IMPL(Tan, Tan) +JSEP_ELEMENTWISE_KERNEL(Tan, 7, float, Tan) + +JSEP_KERNEL_IMPL(Asin, Asin) +JSEP_ELEMENTWISE_KERNEL(Asin, 7, float, Asin) + +JSEP_KERNEL_IMPL(Acos, Acos) +JSEP_ELEMENTWISE_KERNEL(Acos, 7, float, Acos) + +JSEP_KERNEL_IMPL(Atan, Atan) +JSEP_ELEMENTWISE_KERNEL(Atan, 7, float, Atan) + +JSEP_KERNEL_IMPL(Sinh, Sinh) +JSEP_ELEMENTWISE_KERNEL(Sinh, 9, float, Sinh) + +JSEP_KERNEL_IMPL(Cosh, Cosh) +JSEP_ELEMENTWISE_KERNEL(Cosh, 9, float, Cosh) + +JSEP_KERNEL_IMPL(Asinh, Asinh) +JSEP_ELEMENTWISE_KERNEL(Asinh, 9, float, Asinh) + +JSEP_KERNEL_IMPL(Acosh, Acosh) +JSEP_ELEMENTWISE_KERNEL(Acosh, 9, float, Acosh) + +JSEP_KERNEL_IMPL(Atanh, Atanh) +JSEP_ELEMENTWISE_KERNEL(Atanh, 9, float, Atanh) + +// activation + +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, float, ClipV10) +JSEP_KERNEL_IMPL(Clip, Clip) +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2), + Clip); +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2), + Clip); +ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2), + Clip); + +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0) +JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu) + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/symbols.txt b/onnxruntime/core/providers/js/symbols.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 947d2a9b5abc1..261d16a4e8be5 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -46,6 +46,10 @@ #include "core/providers/nnapi/nnapi_provider_factory_creator.h" #endif +#if defined(USE_JS) +#include "core/providers/js/js_provider_factory_creator.h" +#endif + #if defined(USE_OPENVINO) #include "core/providers/openvino/openvino_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 27da418b2c924..de7c8bd6fb101 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -89,6 +89,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(AzureProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "JS") == 0) { +#if defined(USE_JS) + options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); +#else + status = create_not_supported_status(); #endif } else { ORT_UNUSED_PARAMETER(options); diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 48d1a44004f1c..47cb578f7e969 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -363,7 +363,14 @@ int OrtRun(OrtSession* session, const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count, const char** output_names, size_t output_count, ort_tensor_handle_t* outputs, OrtRunOptions* run_options) { - return CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); +#if defined(USE_JS) + EM_ASM({ Module["jsepRunPromise"] = new Promise(function(r) { Module.jsepRunPromiseResolve = r; }); }); +#endif + auto status_code = CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); +#if defined(USE_JS) + EM_ASM({ Module.jsepRunPromiseResolve($0); }, status_code); +#endif + return status_code; } char* OrtEndProfiling(ort_session_handle_t session) { diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js new file mode 100644 index 0000000000000..6c2c3522c7db2 --- /dev/null +++ b/onnxruntime/wasm/js_internal_api.js @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +// init JSEP +Module["jsepInit"] = function (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, run) { + Module.jsepBackend = backend; + Module.jsepAlloc = alloc; + Module.jsepFree = free; + Module.jsepCopy = copy; + Module.jsepCopyAsync = copyAsync; + Module.jsepCreateKernel = createKernel; + Module.jsepReleaseKernel = releaseKernel; + Module.jsepRun = run; +}; diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index ad6c7ce12be23..ea64485c55809 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -484,6 +484,7 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument( "--nnapi_min_api", type=int, help="Minimum Android API level to enable NNAPI, should be no less than 27" ) + parser.add_argument("--use_js", action="store_true", help="Build with JavaScript kernels.") parser.add_argument("--use_qnn", action="store_true", help="Build with QNN support.") parser.add_argument("--qnn_home", help="Path to QNN SDK dir.") parser.add_argument("--use_rknpu", action="store_true", help="Build with RKNPU.") @@ -946,6 +947,7 @@ def generate_build_tree( "-Donnxruntime_USE_ARMNN=" + ("ON" if args.use_armnn else "OFF"), "-Donnxruntime_ARMNN_RELU_USE_CPU=" + ("OFF" if args.armnn_relu else "ON"), "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"), + "-Donnxruntime_USE_JS=" + ("ON" if args.use_js else "OFF"), # Training related flags "-Donnxruntime_ENABLE_NVTX_PROFILE=" + ("ON" if args.enable_nvtx_profile else "OFF"), "-Donnxruntime_ENABLE_TRAINING=" + ("ON" if args.enable_training else "OFF"), diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index a60ace38724e4..1b3ec6af24109 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -19,8 +19,13 @@ parameters: displayName: 'Build static library' type: boolean default: false +- name: BuildJsep + displayName: 'Build JSEP' + type: boolean + default: true - name: ExtraBuildArgs + displayName: 'Extra build command line arguments' type: string @@ -77,6 +82,7 @@ stages: BuildConfig: 'Release' ExtraBuildArgs: '--skip_tests --enable_wasm_api_exception_catching --disable_rtti --use_extensions --cmake_extra_defines onnxruntime_WEBASSEMBLY_DEFAULT_EXTENSION_FLAGS=ON ${{ parameters.ExtraBuildArgs }}' PoolName: ${{ parameters.PoolName }} + BuildJsep: ${{ parameters.BuildJsep }} - ${{ if eq(parameters.BuildStaticLib, 'true') }}: - stage: Build_wasm_Release_static_library diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 18391e26584a2..4ec339bb0fb81 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -22,6 +22,10 @@ parameters: - name: TimeoutInMinutes default: 180 +- name: BuildJsep + type: boolean + default: false + jobs: - job: build_WASM pool: ${{ parameters.PoolName }} @@ -95,12 +99,23 @@ jobs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)\wasm_simd --enable_wasm_simd' workingDirectory: '$(Build.BinariesDirectory)' + - ${{ if eq(parameters.BuildJsep, true) }}: + - task: PythonScript@0 + displayName: 'Build (simd + JSEP)' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)\wasm_simd_jsep --enable_wasm_simd --use_js --target onnxruntime_webassembly' + workingDirectory: '$(Build.BinariesDirectory)' - ${{ if eq(parameters.SkipPublish, false) }}: - script: | copy $(Build.BinariesDirectory)\wasm\${{ parameters.BuildConfig }}\ort-wasm*.* $(Build.ArtifactStagingDirectory) copy $(Build.BinariesDirectory)\wasm_threads\${{ parameters.BuildConfig }}\ort-wasm*.* $(Build.ArtifactStagingDirectory) copy $(Build.BinariesDirectory)\wasm_simd_threads\${{ parameters.BuildConfig }}\ort-wasm*.* $(Build.ArtifactStagingDirectory) copy $(Build.BinariesDirectory)\wasm_simd\${{ parameters.BuildConfig }}\ort-wasm*.* $(Build.ArtifactStagingDirectory) + if exist $(Build.BinariesDirectory)\wasm_simd_jsep ( + copy $(Build.BinariesDirectory)\wasm_simd_jsep\${{ parameters.BuildConfig }}\ort-wasm-simd.wasm $(Build.ArtifactStagingDirectory)\ort-wasm-simd.jsep.wasm + copy $(Build.BinariesDirectory)\wasm_simd_jsep\${{ parameters.BuildConfig }}\ort-wasm-simd.js $(Build.ArtifactStagingDirectory)\ort-wasm-simd.jsep.js + ) displayName: 'Create Artifacts' - ${{ if eq(parameters.SkipPublish, false) }}: - task: PublishPipelineArtifact@0