diff --git a/CMakeLists.txt b/CMakeLists.txt
index 08e3af884b..5ea1d8e32e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -156,6 +156,352 @@ target_link_libraries(onnxruntime-genai PRIVATE onnxruntime_extensions)
 target_link_directories(onnxruntime-genai PRIVATE ${ORT_LIB_DIR})
 target_link_libraries(onnxruntime-genai PRIVATE Threads::Threads)
 
+# Configure WebGPU support if enabled
+if(USE_WEBGPU)
+  message(STATUS "WebGPU support enabled")
+
+  # Dawn WebGPU library paths
+  set(DAWN_ROOT "C:/Users/jiajiaqin/workspace/dawn")
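+  # Dawn checkout location; a configure-time override is assumed to look like
+  #   cmake -DDAWN_ROOT=D:/src/dawn ...
+  # once this default is promoted to a cache variable, e.g.:
+  #   set(DAWN_ROOT "C:/dawn" CACHE PATH "Path to the Dawn source checkout")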
"${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_program.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_reader.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_reader_lower.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_reader_parser.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_reader_program_to_ir.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_resolver.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_sem.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_writer.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_writer_ast_printer.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_writer_ir_to_program.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_writer_raise.lib" + # "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_wgsl_writer_syntax_tree_printer.lib" + + # Tint SPIRV support + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_intrinsic.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_ir.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_reader.lib" + # "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_reader_ast_lower.lib" + # "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_reader_ast_parser.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_reader_common.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_reader_lower.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_reader_parser.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_type.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_validate.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_writer.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_writer_common.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_writer_helpers.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_writer_printer.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_spirv_writer_raise.lib" + + # Tint HLSL support + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_hlsl.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_hlsl_intrinsic.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_hlsl_ir.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_hlsl_type.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_hlsl_validate.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_hlsl_writer.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_hlsl_writer_common.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_hlsl_writer_helpers.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_hlsl_writer_printer.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_hlsl_writer_raise.lib" + + # Tint GLSL support + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_glsl.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_glsl_intrinsic.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_glsl_ir.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_glsl_validate.lib" + + # Tint MSL support + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_msl.lib" + 
"${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_msl_intrinsic.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_msl_ir.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_msl_ir_transform.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_lang_msl_type.lib" + + # Tint utility libraries + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_bytes.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_command.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_containers.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_diagnostic.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_file.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_ice.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_macros.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_math.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_memory.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_protos_ir_fuzz_proto.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_protos_ir_proto.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_rtti.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_strconv.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_symbol.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_system.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_text.lib" + "${DAWN_BUILD_DIR}/src/tint/${DAWN_LIB_SUBDIR}/tint_utils_text_generator.lib" + + # Abseil base libraries + "${DAWN_BUILD_DIR}/third_party/abseil/absl/base/${DAWN_LIB_SUBDIR}/absl_base.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/base/${DAWN_LIB_SUBDIR}/absl_log_severity.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/base/${DAWN_LIB_SUBDIR}/absl_malloc_internal.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/base/${DAWN_LIB_SUBDIR}/absl_raw_logging_internal.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/base/${DAWN_LIB_SUBDIR}/absl_spinlock_wait.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/base/${DAWN_LIB_SUBDIR}/absl_strerror.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/base/${DAWN_LIB_SUBDIR}/absl_throw_delegate.lib" + + # Abseil container libraries + "${DAWN_BUILD_DIR}/third_party/abseil/absl/container/${DAWN_LIB_SUBDIR}/absl_hashtablez_sampler.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/container/${DAWN_LIB_SUBDIR}/absl_raw_hash_set.lib" + + # Abseil hash libraries + "${DAWN_BUILD_DIR}/third_party/abseil/absl/hash/${DAWN_LIB_SUBDIR}/absl_city.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/hash/${DAWN_LIB_SUBDIR}/absl_hash.lib" + # "${DAWN_BUILD_DIR}/third_party/abseil/absl/hash/${DAWN_LIB_SUBDIR}/absl_low_level_hash.lib" + + # Abseil debugging libraries + "${DAWN_BUILD_DIR}/third_party/abseil/absl/debugging/${DAWN_LIB_SUBDIR}/absl_debugging_internal.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/debugging/${DAWN_LIB_SUBDIR}/absl_demangle_internal.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/debugging/${DAWN_LIB_SUBDIR}/absl_examine_stack.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/debugging/${DAWN_LIB_SUBDIR}/absl_failure_signal_handler.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/debugging/${DAWN_LIB_SUBDIR}/absl_leak_check.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/debugging/${DAWN_LIB_SUBDIR}/absl_stacktrace.lib" + 
"${DAWN_BUILD_DIR}/third_party/abseil/absl/debugging/${DAWN_LIB_SUBDIR}/absl_symbolize.lib" + + # Abseil string libraries + "${DAWN_BUILD_DIR}/third_party/abseil/absl/strings/${DAWN_LIB_SUBDIR}/absl_cord.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/strings/${DAWN_LIB_SUBDIR}/absl_cordz_functions.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/strings/${DAWN_LIB_SUBDIR}/absl_cordz_handle.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/strings/${DAWN_LIB_SUBDIR}/absl_cordz_info.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/strings/${DAWN_LIB_SUBDIR}/absl_cordz_sample_token.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/strings/${DAWN_LIB_SUBDIR}/absl_cord_internal.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/strings/${DAWN_LIB_SUBDIR}/absl_strings.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/strings/${DAWN_LIB_SUBDIR}/absl_strings_internal.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/strings/${DAWN_LIB_SUBDIR}/absl_string_view.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/strings/${DAWN_LIB_SUBDIR}/absl_str_format_internal.lib" + + # Abseil synchronization libraries + "${DAWN_BUILD_DIR}/third_party/abseil/absl/synchronization/${DAWN_LIB_SUBDIR}/absl_graphcycles_internal.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/synchronization/${DAWN_LIB_SUBDIR}/absl_kernel_timeout_internal.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/synchronization/${DAWN_LIB_SUBDIR}/absl_synchronization.lib" + + # Abseil time libraries + "${DAWN_BUILD_DIR}/third_party/abseil/absl/time/${DAWN_LIB_SUBDIR}/absl_civil_time.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/time/${DAWN_LIB_SUBDIR}/absl_time.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/time/${DAWN_LIB_SUBDIR}/absl_time_zone.lib" + + # Abseil status libraries + "${DAWN_BUILD_DIR}/third_party/abseil/absl/status/${DAWN_LIB_SUBDIR}/absl_status.lib" + "${DAWN_BUILD_DIR}/third_party/abseil/absl/status/${DAWN_LIB_SUBDIR}/absl_statusor.lib" + + # Abseil numeric libraries + "${DAWN_BUILD_DIR}/third_party/abseil/absl/numeric/${DAWN_LIB_SUBDIR}/absl_int128.lib" + + # Abseil types libraries + # "${DAWN_BUILD_DIR}/third_party/abseil/absl/types/${DAWN_LIB_SUBDIR}/absl_bad_any_cast_impl.lib" + # "${DAWN_BUILD_DIR}/third_party/abseil/absl/types/${DAWN_LIB_SUBDIR}/absl_bad_optional_access.lib" + # "${DAWN_BUILD_DIR}/third_party/abseil/absl/types/${DAWN_LIB_SUBDIR}/absl_bad_variant_access.lib" + + # SPIRV-Tools libraries + "${DAWN_BUILD_DIR}/third_party/spirv-tools/source/${DAWN_LIB_SUBDIR}/SPIRV-Tools.lib" + "${DAWN_BUILD_DIR}/third_party/spirv-tools/source/opt/${DAWN_LIB_SUBDIR}/SPIRV-Tools-opt.lib" + + # DXC libraries (DirectX Shader Compiler) + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/dxcompiler.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/dxclib.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/dxcvalidator.lib" + + # DXC LLVM core libraries + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMCore.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMSupport.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMAnalysis.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMBitReader.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMBitWriter.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMTransformUtils.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMInstCombine.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMScalarOpts.lib" + 
"${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMipo.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMVectorize.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMTarget.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMPasses.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMOption.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMProfileData.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMLinker.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMIRReader.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMAsmParser.lib" + + # DXC DXIL libraries + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMDXIL.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMDxilContainer.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMDxilValidation.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMDxilHash.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMDxilCompression.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMDxilRootSignature.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMDxcSupport.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMDxcBindingTable.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMHLSL.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/LLVMMSSupport.lib" + + # DXC Clang libraries + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangBasic.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangLex.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangParse.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangAST.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangSema.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangAnalysis.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangEdit.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangRewrite.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangDriver.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangCodeGen.lib" + "${DAWN_BUILD_DIR}/third_party/dxc/${DAWN_LIB_SUBDIR}/lib/clangFrontend.lib" + ) + + # Link all Dawn/Tint/Abseil libraries + target_link_libraries(onnxruntime-genai PRIVATE ${DAWN_LIBRARIES}) + + # Link Windows system libraries required by Dawn + if(WIN32) + target_link_libraries(onnxruntime-genai PRIVATE + user32.lib + dxguid.lib + kernel32.lib + gdi32.lib + winspool.lib + shell32.lib + ole32.lib + oleaut32.lib + uuid.lib + comdlg32.lib + advapi32.lib + onecore.lib + ) + endif() + + target_compile_definitions(onnxruntime-genai PRIVATE USE_WEBGPU) + list(LENGTH DAWN_LIBRARIES NUM_LIBRARIES) + message(STATUS "WebGPU support configured with ${NUM_LIBRARIES} Dawn/Tint/Abseil libraries") + +else() + # USE_WEBGPU=OFF: Use ONNX Runtime's internal Dawn (shared mode) + # We need Dawn headers AND dawn_proc.lib to resolve wgpu* function symbols + message(STATUS "WebGPU support enabled (using ONNX Runtime's internal Dawn)") + + # Dawn WebGPU library paths + set(DAWN_ROOT "C:/Users/jiajiaqin/workspace/dawn") + + # Set build directory based on CMAKE_BUILD_TYPE + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(DAWN_BUILD_DIR "${DAWN_ROOT}/out/Debug") + set(DAWN_LIB_SUBDIR "Debug") + message(STATUS "Using Dawn Debug libraries") + else() 
+    set(DAWN_BUILD_DIR "${DAWN_ROOT}/out/Release")
+    set(DAWN_LIB_SUBDIR "Release")
+    message(STATUS "Using Dawn Release libraries")
+  endif()
+
+  # Check if Dawn is available
+  if(NOT EXISTS "${DAWN_ROOT}/include")
+    message(FATAL_ERROR "Dawn header directory not found at ${DAWN_ROOT}/include. Please ensure Dawn source is available.")
+  endif()
+  if(NOT EXISTS "${DAWN_BUILD_DIR}")
+    message(FATAL_ERROR "Dawn build directory not found at ${DAWN_BUILD_DIR}. Please build Dawn first.")
+  endif()
+
+  # Dawn include directories
+  target_include_directories(onnxruntime-genai PRIVATE
+    "${DAWN_ROOT}/include"
+    "${DAWN_BUILD_DIR}/gen/include"
+  )
+
+  # Link dawn_proc.lib and dawn_common.lib
+  # - dawn_proc.lib: provides the wgpu* C function implementations
+  # - dawn_common.lib: provides Dawn logging utilities required by dawn_proc
+  # The DawnProcTable from ONNX Runtime will redirect these to ONNX Runtime's Dawn at runtime
+  set(DAWN_PROC_LIB "${DAWN_BUILD_DIR}/src/dawn/${DAWN_LIB_SUBDIR}/dawn_proc.lib")
+  set(DAWN_COMMON_LIB "${DAWN_BUILD_DIR}/src/dawn/common/${DAWN_LIB_SUBDIR}/dawn_common.lib")
+
+  if(NOT EXISTS "${DAWN_PROC_LIB}")
+    message(FATAL_ERROR "dawn_proc.lib not found at ${DAWN_PROC_LIB}")
+  endif()
+  if(NOT EXISTS "${DAWN_COMMON_LIB}")
+    message(FATAL_ERROR "dawn_common.lib not found at ${DAWN_COMMON_LIB}")
+  endif()
+
+  target_link_libraries(onnxruntime-genai PRIVATE
+    "${DAWN_PROC_LIB}"
+    "${DAWN_COMMON_LIB}"
+  )
+
+  # On Windows, also explicitly link onnxruntime.lib to resolve OrtWebGpuGet* symbols
+  # (ONNXRUNTIME_LIB is set to "onnxruntime.dll" which can't be used for linking)
+  if(WIN32)
+    set(ONNXRUNTIME_IMPORT_LIB "${ORT_LIB_DIR}/onnxruntime.lib")
+    if(NOT EXISTS "${ONNXRUNTIME_IMPORT_LIB}")
+      message(FATAL_ERROR "onnxruntime.lib not found at ${ONNXRUNTIME_IMPORT_LIB}")
+    endif()
+    target_link_libraries(onnxruntime-genai PRIVATE "${ONNXRUNTIME_IMPORT_LIB}")
+    message(STATUS "Linked onnxruntime.lib for OrtWebGpuGet* symbols: ${ONNXRUNTIME_IMPORT_LIB}")
+  endif()
+
+  message(STATUS "WebGPU configured with dawn_proc.lib and dawn_common.lib (using ONNX Runtime's internal Dawn)")
+endif()
+
 # The genai library itself is always embedded in the shared library
 list(APPEND ortgenai_embed_libs "$")
diff --git a/build.py b/build.py
index f4e0841fbe..e148fd3e18 100644
--- a/build.py
+++ b/build.py
@@ -132,6 +132,8 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
     parser.add_argument("--use_dml", action="store_true", help="Whether to use DML. Default is to not use DML.")
 
+    parser.add_argument("--use_webgpu", action="store_true", help="Whether to use WebGPU. Default is to not use WebGPU.")
+
     parser.add_argument("--use_guidance", action="store_true", help="Whether to add guidance support. Default is False.")
 
     # The following options are mutually exclusive (cross compiling options such as android, ios, etc.)
@@ -527,6 +529,7 @@ def update(args: argparse.Namespace, env: dict[str, str]):
         f"-DUSE_TRT_RTX={'ON' if args.use_trt_rtx else 'OFF'}",
         f"-DUSE_ROCM={'ON' if args.use_rocm else 'OFF'}",
         f"-DUSE_DML={'ON' if args.use_dml else 'OFF'}",
+        f"-DUSE_WEBGPU={'ON' if args.use_webgpu else 'OFF'}",
         f"-DENABLE_JAVA={'ON' if args.build_java else 'OFF'}",
         f"-DBUILD_WHEEL={build_wheel}",
         f"-DUSE_GUIDANCE={'ON' if args.use_guidance else 'OFF'}",
diff --git a/cmake/options.cmake b/cmake/options.cmake
index 5071f733c6..7a5c6d9709 100644
--- a/cmake/options.cmake
+++ b/cmake/options.cmake
@@ -5,6 +5,7 @@ option(USE_CUDA "Build with CUDA support" ON)
 option(USE_TRT_RTX "Build with TensorRT-RTX support" OFF)
 option(USE_ROCM "Build with ROCm support" ON)
 option(USE_DML "Build with DML support" OFF)
+option(USE_WEBGPU "Build with WebGPU support" OFF)
 option(USE_WINML "Build with WinML support" OFF)
 option(USE_GUIDANCE "Build with guidance support" OFF)
diff --git a/src/config.cpp b/src/config.cpp
index e3e115b8e8..74cc66e2ba 100644
--- a/src/config.cpp
+++ b/src/config.cpp
@@ -935,6 +935,13 @@ bool IsGraphCaptureEnabled(const Config::SessionOptions& session_options) {
     }
   } else if (provider_options->name == "DML") {
     return true;
+  } else if (provider_options->name == "WebGPU") {
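+    // Illustrative genai_config.json entry that takes this branch (a sketch,
+    // assuming the usual session_options layout):
+    //   "provider_options": [ { "WebGPU": { "enableGraphCapture": "1" } } ]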
+    for (const auto& value : provider_options->options) {
+      if (value.first == "enableGraphCapture" && value.second == "1") {
+        return true;
+      }
+    }
+    return false;
   } else if (provider_options->name == "NvTensorRtRtx") {
     for (const auto& value : provider_options->options) {
       if (value.first == "enable_cuda_graph" && value.second == "1") {
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 0eab6e0d6d..dbb96d55e8 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -20,6 +20,7 @@
 #include "marian.h"
 #include "decoder_only_pipeline.h"
 #include "../dml/interface.h"
+#include "../webgpu/interface.h"
 
 #if defined(_WIN32)
 #include <Windows.h>
@@ -577,9 +578,18 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
         opt_it != provider_options.options.end() && opt_it->second == "1") {
       p_device = GetDeviceInterface(DeviceType::QNN);
     }
-  } else if (provider_options.name == "WebGPU")
+  } else if (provider_options.name == "WebGPU") {
     p_device = GetDeviceInterface(DeviceType::WEBGPU);
-  else if (provider_options.name == "OpenVINO")
+    // Convert provider options to unordered_map for SetWebGPUProvider
+    std::unordered_map<std::string, std::string> webgpu_options;
+    for (const auto& option : provider_options.options) {
+      webgpu_options[option.first] = option.second;
+    }
+
+    // Use the new SetWebGPUProvider function for enhanced provider setup
+    SetWebGPUProvider(session_options, webgpu_options);
+    continue;  // Skip the generic AppendExecutionProvider below
+  } else if (provider_options.name == "OpenVINO")
     p_device = GetDeviceInterface(DeviceType::OpenVINO);
   else if (provider_options.name == "VitisAI") {
     session_options.AddConfigEntry("session.inter_op.allow_spinning", "0");
@@ -871,7 +881,9 @@ Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
   EnsureDeviceOrtInit(*p_device_, *config_);
 
   // Only CUDA, TRT-RTX and DML does every input on the device
-  if (p_device_->GetType() == DeviceType::CUDA || p_device_->GetType() == DeviceType::DML || p_device_->GetType() == DeviceType::NvTensorRtRtx)
+  // For WebGPU, use device memory only if graph capture is enabled, otherwise use CPU
+  if (p_device_->GetType() == DeviceType::CUDA || p_device_->GetType() == DeviceType::DML || p_device_->GetType() == DeviceType::NvTensorRtRtx ||
+      (p_device_->GetType() == DeviceType::WEBGPU && IsGraphCaptureEnabled(config_->model.decoder.session_options)))
     p_device_inputs_ = p_device_;
   else
     p_device_inputs_ = GetDeviceInterface(DeviceType::CPU);
diff --git a/src/webgpu/interface.cpp b/src/webgpu/interface.cpp
index 3b1fd9c256..37795d5b89 100644
--- a/src/webgpu/interface.cpp
+++ b/src/webgpu/interface.cpp
@@ -5,38 +5,158 @@
 #include "../search.h"
 #include "interface.h"
 
+#ifdef USE_WEBGPU
+// USE_WEBGPU path: Build our own Dawn (legacy mode, will be deprecated)
+#include "webgpu_update_mask_kernel.h"
+#include "webgpu_update_position_ids_kernel.h"
+#include "webgpu_cast_kernel.h"
+
+#include <dawn/dawn_proc.h>
+#include <dawn/native/DawnNative.h>
+#include <webgpu/webgpu_cpp.h>
+#else
+// Preferred path: Use WebGPU EP's Dawn instance
+// This enables graph capture and other WebGPU features without building Dawn twice
+
+// Forward declare the C API functions to avoid header conflicts
+extern "C" {
+const void* OrtWebGpuGetDawnProcTable(int context_id);
+void* OrtWebGpuGetInstance(int context_id);
+void* OrtWebGpuGetDevice(int context_id);
+}
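+// NOTE: these entry points are assumed to be exported by the ONNX Runtime
+// WebGPU EP build in use; the returned pointers are borrowed (the EP keeps
+// ownership), e.g.:
+//   void* dev = OrtWebGpuGetDevice(0);  // valid only after the EP is created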
+
+#include "webgpu_update_mask_kernel.h"
+#include "webgpu_update_position_ids_kernel.h"
+#include "webgpu_cast_kernel.h"
+
+#include <dawn/dawn_proc.h>
+#include <webgpu/webgpu_cpp.h>
+#endif
+
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <vector>
+
 namespace Generators {
 namespace WebGPU {
 
-static Ort::Allocator* ort_allocator_{};
+Ort::Allocator* ort_allocator_{};
+
+// Global Dawn objects shared across all WebGPU operations
+// USE_WEBGPU=ON: We own these objects completely (created by us, destroyed by us)
+// USE_WEBGPU=OFF: We hold BORROWED references to WebGPU EP's objects
+// We use MoveToCHandle() in destructor to release WITHOUT decrementing refcount
+wgpu::Device webgpu_device_;
+wgpu::Queue webgpu_queue_;
+wgpu::Instance webgpu_instance_;
+
 const char* device_label = "WebGPU";
 
 struct WebGPUMemory final : DeviceBuffer {
   WebGPUMemory(size_t size) : owned_{true} {
     size_in_bytes_ = size;
-    p_cpu_ = p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
+    p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
   }
 
   WebGPUMemory(void* p, size_t size) : owned_{false} {
     size_in_bytes_ = size;
-    p_cpu_ = p_device_ = static_cast<uint8_t*>(p);
+    p_device_ = static_cast<uint8_t*>(p);
   }
 
   ~WebGPUMemory() override {
     if (owned_)
       ort_allocator_->Free(p_device_);
+    if (p_cpu_)
+      free(p_cpu_);
   }
 
   const char* GetType() const override { return device_label; }
-  void AllocateCpu() override { throw std::runtime_error("CPU can't access WebGPU memory"); }
-  void CopyDeviceToCpu() override { throw std::runtime_error("CPU can't access WebGPU memory"); }
-  void CopyCpuToDevice() override { throw std::runtime_error("CPU can't access WebGPU memory"); }
+
+  void AllocateCpu() override {
+    if (!p_cpu_)
+      p_cpu_ = static_cast<uint8_t*>(malloc(size_in_bytes_));
+  }
+
+  void CopyDeviceToCpu() override {
+#ifdef USE_WEBGPU
+    // USE_WEBGPU=ON: webgpu_device_ points to our own Dawn device
+    if (!webgpu_device_) {
+      throw std::runtime_error("WebGPU device not initialized");
+    }
+#else
+    // USE_WEBGPU=OFF: webgpu_device_ points to WebGPU EP's Dawn device
+    if (!webgpu_device_) {
+      throw std::runtime_error("WebGPU device not initialized. Ensure WebGPU EP is enabled and InitOrt() was called.");
+    }
+#endif
+    AllocateCpu();
+    WGPUBuffer src_buf = reinterpret_cast<WGPUBuffer>(p_device_);
+    wgpu::BufferDescriptor desc{};
+    desc.size = size_in_bytes_;
+    desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead;
+
+    auto staging_buffer = webgpu_device_.CreateBuffer(&desc);
+    auto command_encoder = webgpu_device_.CreateCommandEncoder();
+    command_encoder.CopyBufferToBuffer(src_buf, 0, staging_buffer, 0, size_in_bytes_);
+    wgpu::CommandBuffer command_buffer1 = command_encoder.Finish();
+    webgpu_queue_.Submit(1, &command_buffer1);
+
+    webgpu_instance_.WaitAny(
+        staging_buffer.MapAsync(
+            wgpu::MapMode::Read,
+            0, size_in_bytes_,
+            wgpu::CallbackMode::WaitAnyOnly,
+            [&](wgpu::MapAsyncStatus status, wgpu::StringView message) {
+              if (status == wgpu::MapAsyncStatus::Success) {
+                // Get a pointer to the mapped data
+                const void* mapped_data = staging_buffer.GetConstMappedRange(0, size_in_bytes_);
+                // Copy the data to our vector
+                memcpy(p_cpu_, mapped_data, size_in_bytes_);
+              } else {
+                std::cerr << "Failed to map staging buffer: " << static_cast<int>(status) << std::endl;
+              }
+            }),
+        UINT64_MAX);
+    staging_buffer.Unmap();
+    staging_buffer.Destroy();
+  }
+
+  void CopyCpuToDevice() override {
+    if (!webgpu_queue_) {
+      throw std::runtime_error("WebGPU queue not initialized");
+    }
+    WGPUBuffer dst_buf = reinterpret_cast<WGPUBuffer>(p_device_);
+    webgpu_queue_.WriteBuffer(dst_buf, 0, p_cpu_, size_in_bytes_);
+  }
 
   void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
-    throw std::runtime_error("CPU can't access WebGPU memory");
+    if (!webgpu_device_) {
+      throw std::runtime_error("WebGPU device not initialized");
+    }
+    if (source.GetType() == device_label) {
+      // WebGPU buffer-to-buffer copy
+      WGPUBuffer src_buf = reinterpret_cast<WGPUBuffer>(source.p_device_);
+      WGPUBuffer dst_buf = reinterpret_cast<WGPUBuffer>(p_device_);
+
+      auto command_encoder = webgpu_device_.CreateCommandEncoder();
+      command_encoder.CopyBufferToBuffer(src_buf, begin_source, dst_buf, begin_dest, size_in_bytes);
+      wgpu::CommandBuffer command_buffer = command_encoder.Finish();
+      webgpu_queue_.Submit(1, &command_buffer);
+    } else {
+      CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes);
+    }
   }
 
   void Zero() override {
-    throw std::runtime_error("Zeroing not implemented for WebGPU memory");
+    if (!webgpu_queue_) {
+      throw std::runtime_error("WebGPU queue not initialized");
+    }
+    // Clear buffer by writing zeros
+    WGPUBuffer dst_buf = reinterpret_cast<WGPUBuffer>(p_device_);
+    std::vector<uint8_t> zero_data(size_in_bytes_, 0);
+    webgpu_queue_.WriteBuffer(dst_buf, 0, zero_data.data(), size_in_bytes_);
   }
 
   bool owned_;
@@ -44,6 +164,33 @@ struct WebGPUMemory final : DeviceBuffer {
 
 struct InterfaceImpl : DeviceInterface {
   InterfaceImpl() {
+#ifdef USE_WEBGPU
+    // USE_WEBGPU=ON: Initialize our own Dawn (legacy mode)
+    InitializeDawn();
+#else
+    // USE_WEBGPU=OFF: Will retrieve Dawn from WebGPU EP later
+    webgpu_ep_context_id_ = 0;  // Default WebGPU EP context
+    dawn_initialized_from_ep_ = false;
+#endif
+  }
+
+  ~InterfaceImpl() override {
+#ifndef USE_WEBGPU
+    // USE_WEBGPU=OFF: Release borrowed references from WebGPU EP
+    // CRITICAL: Call MoveToCHandle() to prevent destructors from calling Release()
+    // during static destruction (which happens AFTER ORT's WebGPU EP cleanup)
+    if (webgpu_device_) {
+      webgpu_device_.MoveToCHandle();
+    }
+    if (webgpu_queue_) {
+      webgpu_queue_.MoveToCHandle();
+    }
+    if (webgpu_instance_) {
+      webgpu_instance_.MoveToCHandle();
+    }
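+    // MoveToCHandle() detaches the C++ wrapper and returns the raw handle
+    // without calling the wgpu*Release function, so the EP-owned refcount is
+    // left untouched, e.g.:
+    //   WGPUDevice raw = webgpu_device_.MoveToCHandle();  // wrapper now empty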
+#else
+    // USE_WEBGPU=ON: We own the objects, let their destructors clean up normally
+#endif
   }
 
   DeviceType GetType() const override { return DeviceType::WEBGPU; }
 
@@ -51,8 +198,22 @@ struct InterfaceImpl : DeviceInterface {
   void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
     assert(!ort_allocator_);
    ort_allocator_ = &allocator;
+#ifdef USE_WEBGPU
+    // USE_WEBGPU=ON: Global objects already initialized in InitializeDawn()
+    // Nothing additional needed here
+#else
+    // USE_WEBGPU=OFF: Retrieve Dawn from WebGPU EP and initialize global objects
+    if (!dawn_initialized_from_ep_) {
+      InitializeDawnFromWebGpuEP();
+    }
+#endif
   }
 
+#ifdef USE_WEBGPU
+  wgpu::Device GetDevice() const { return webgpu_device_; }
+  wgpu::Instance GetInstance() const { return webgpu_instance_; }
+#endif
+
   Ort::Allocator& GetAllocator() override {
     return *ort_allocator_;
   }
 
@@ -68,14 +229,394 @@ struct InterfaceImpl : DeviceInterface {
   std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) override { return std::make_unique<GreedySearch_Cpu>(params); }
   std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); }
 
-  void Synchronize() override {}  // Nothing to do?
-};
+  // WebGPU-accelerated operations (available in both USE_WEBGPU=ON and OFF modes)
+  bool UpdateAttentionMask(void* next_mask_data, void* mask_data, int batch_beam_size, int new_kv_length, int total_length, int max_length, bool update_only, ONNXTensorElementDataType type) override {
+    if (!webgpu_device_) {
+      // Fall back to CPU implementation if WebGPU context is not initialized
+      return false;
+    }
+
+    // Only support static mask updates (update_only = true)
+    if (!update_only) {
+      return false;  // Fall back to CPU for dynamic mask handling
+    }
+
+    try {
+      if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
+        webgpu_update_mask_kernel_int32_.UpdateMask(
+            webgpu_device_, webgpu_queue_,
+            static_cast<int32_t*>(next_mask_data), static_cast<int32_t*>(mask_data),
+            batch_beam_size, new_kv_length, total_length, max_length, update_only);
+      } else if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
+        webgpu_update_mask_kernel_int64_.UpdateMask(
+            webgpu_device_, webgpu_queue_,
+            static_cast<int64_t*>(next_mask_data), static_cast<int64_t*>(mask_data),
+            batch_beam_size, new_kv_length, total_length, max_length, update_only);
+      } else {
+        return false;  // Unsupported data type
+      }
+      return true;
+    } catch (const std::exception&) {
+      return false;  // Fall back to CPU on any error
+    }
+  }
+
+  bool Cast(void* input_data, void* output_data, ONNXTensorElementDataType input_type, ONNXTensorElementDataType output_type, size_t element_count) override {
+    if (!webgpu_device_) {
+      return false;  // Fall back to CPU if WebGPU context is not initialized
+    }
+
+    if (input_type == output_type) {
+      throw std::runtime_error("Cast - input and output types are the same");
+    }
+
+    try {
+      // Handle supported type conversions
+      if (input_type == Ort::TypeToTensorType<int32_t> && output_type == Ort::TypeToTensorType<int64_t>) {
+        return webgpu_cast_kernel_.CastInt32ToInt64(
+            webgpu_device_, webgpu_queue_,
+            input_data,
+            output_data,
+            element_count);
+      } else if (input_type == Ort::TypeToTensorType<Ort::Float16_t> && output_type == Ort::TypeToTensorType<float>) {
+        return webgpu_cast_kernel_.CastFloat16ToFloat32(
+            webgpu_device_, webgpu_queue_,
+            input_data,
+            output_data,
+            element_count);
+      } else {
+        // Unsupported type conversion, fall back to CPU
+        return false;
+      }
+    } catch (const std::exception&) {
+      return false;  // Fall back to CPU on any error
+    }
+  }
+
+  bool UpdatePositionIds(void* position_ids, int batch_beam_size, int total_length, int new_kv_length, ONNXTensorElementDataType type) override {
+    if (!webgpu_device_) {
+      return false;  // Fall back to CPU if WebGPU context is not initialized
+    }
+
+    // Only support batch_beam_size == 1 for graph capture (continuous decoding)
+    if (batch_beam_size != 1) {
+      return false;  // Fall back to CPU for batch_size > 1
+    }
+
+    try {
+      if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
+        webgpu_update_position_ids_kernel_int32_.UpdatePositionIds(
+            webgpu_device_, webgpu_queue_,
+            static_cast<int32_t*>(position_ids),
+            batch_beam_size, total_length, new_kv_length);
+      } else if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
+        webgpu_update_position_ids_kernel_int64_.UpdatePositionIds(
+            webgpu_device_, webgpu_queue_,
+            static_cast<int64_t*>(position_ids),
+            batch_beam_size, total_length, new_kv_length);
+      } else {
+        return false;  // Unsupported data type
+      }
+      return true;
+    } catch (const std::exception&) {
+      return false;  // Fall back to CPU on any error
+    }
+  }
+
+  void Synchronize() override {}  // Nothing to do
+
+ private:
+#ifdef USE_WEBGPU
+  // USE_WEBGPU=ON: No additional members needed (global objects used directly)
+#else
+  // USE_WEBGPU=OFF: Track initialization state
+  int webgpu_ep_context_id_;
+  bool dawn_initialized_from_ep_;
+
+  void InitializeDawnFromWebGpuEP() {
+    // Retrieve Dawn objects from WebGPU EP
+    const void* dawn_proc_table_ptr = OrtWebGpuGetDawnProcTable(webgpu_ep_context_id_);
+    void* instance_ptr = OrtWebGpuGetInstance(webgpu_ep_context_id_);
+    void* device_ptr = OrtWebGpuGetDevice(webgpu_ep_context_id_);
+
+    if (!device_ptr || !instance_ptr) {
+      throw std::runtime_error(
+          "Failed to retrieve Dawn objects from WebGPU EP. "
+          "Make sure WebGPU EP is initialized first by calling SetWebGPUProvider.");
+    }
+
+    // Initialize Dawn proc table if available
+    if (dawn_proc_table_ptr) {
+      const DawnProcTable* dawn_procs = reinterpret_cast<const DawnProcTable*>(dawn_proc_table_ptr);
+      dawnProcSetProcs(dawn_procs);
+      std::cout << "Initialized Dawn proc table from WebGPU EP" << std::endl;
+    }
+
+    // Wrap WebGPU EP's objects in C++ wrappers WITHOUT incrementing reference count
+    // We're borrowing these references - WebGPU EP owns them
+    // In destructor, we'll use MoveToCHandle() to release wrapper without calling Release()
+    // This ensures reference count stays unchanged: WebGPU EP is the sole owner
+    webgpu_device_ = wgpu::Device(static_cast<WGPUDevice>(device_ptr));
+    webgpu_instance_ = wgpu::Instance(static_cast<WGPUInstance>(instance_ptr));
+    webgpu_queue_ = webgpu_device_.GetQueue();
+
+    dawn_initialized_from_ep_ = true;
+
+    std::cout << "Successfully initialized from WebGPU EP context " << webgpu_ep_context_id_ << std::endl;
+    std::cout << "  Device: " << device_ptr << std::endl;
+    std::cout << "  Instance: " << instance_ptr << std::endl;
+    std::cout << "  Using borrowed references (WebGPU EP owns lifecycle)" << std::endl;
+    std::cout << "  Graph capture and WebGPU operations are now enabled" << std::endl;
+  }
+#endif
+  // WebGPU kernels (available in both modes)
+  WebGPUUpdateMaskKernel<int32_t> webgpu_update_mask_kernel_int32_;
+  WebGPUUpdateMaskKernel<int64_t> webgpu_update_mask_kernel_int64_;
+  WebGPUUpdatePositionIdsKernel<int32_t> webgpu_update_position_ids_kernel_int32_;
+  WebGPUUpdatePositionIdsKernel<int64_t> webgpu_update_position_ids_kernel_int64_;
+  CastKernel webgpu_cast_kernel_;
+
+#ifdef USE_WEBGPU
+  std::vector<wgpu::FeatureName> GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const {
+    std::vector<wgpu::FeatureName> required_features;
+    constexpr wgpu::FeatureName features[]{
+        wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses,
+        wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix,
+        wgpu::FeatureName::TimestampQuery,
+        wgpu::FeatureName::ShaderF16,
+        wgpu::FeatureName::Subgroups,
+        wgpu::FeatureName::BufferMapExtendedUsages,
+    };
+    for (auto feature : features) {
+      if (adapter.HasFeature(feature)) {
+        required_features.push_back(feature);
+      }
+    }
+    return required_features;
+  }
+
+  wgpu::Limits GetRequiredLimits(const wgpu::Adapter& adapter) const {
+    wgpu::Limits required_limits{};
+    wgpu::Limits adapter_limits;
+    adapter.GetLimits(&adapter_limits);
+
+    required_limits.maxBindGroups = adapter_limits.maxBindGroups;
+    required_limits.maxComputeWorkgroupStorageSize = adapter_limits.maxComputeWorkgroupStorageSize;
+    required_limits.maxComputeWorkgroupsPerDimension = adapter_limits.maxComputeWorkgroupsPerDimension;
+    required_limits.maxStorageBufferBindingSize = adapter_limits.maxStorageBufferBindingSize;
+    required_limits.maxBufferSize = adapter_limits.maxBufferSize;
+    required_limits.maxComputeInvocationsPerWorkgroup = adapter_limits.maxComputeInvocationsPerWorkgroup;
+    required_limits.maxComputeWorkgroupSizeX = adapter_limits.maxComputeWorkgroupSizeX;
+    required_limits.maxComputeWorkgroupSizeY = adapter_limits.maxComputeWorkgroupSizeY;
+    required_limits.maxComputeWorkgroupSizeZ = adapter_limits.maxComputeWorkgroupSizeZ;
+
+    return required_limits;
+  }
+
+  bool InitializeDawn() {
+    // Set up Dawn instance
+    dawnProcSetProcs(&dawn::native::GetProcs());
+
+    // Create Dawn instance
+    wgpu::InstanceFeatureName required_instance_features[] = {wgpu::InstanceFeatureName::TimedWaitAny};
+    wgpu::InstanceDescriptor instance_desc{};
+    instance_desc.requiredFeatures = required_instance_features;
+    instance_desc.requiredFeatureCount = sizeof(required_instance_features) / sizeof(required_instance_features[0]);
+    webgpu_instance_ = wgpu::CreateInstance(&instance_desc);
+
+    wgpu::RequestAdapterOptions adapter_options = {};
+    adapter_options.backendType = wgpu::BackendType::D3D12;
+    adapter_options.powerPreference = wgpu::PowerPreference::HighPerformance;
+
+    std::vector<const char*> enableToggleNames;
+    // enableToggleNames.push_back("dump_shaders");
+    enableToggleNames.push_back("use_dxc");
+    enableToggleNames.push_back("allow_unsafe_apis");
+    // enableToggleNames.push_back("disable_symbol_renaming");
+    // enableToggleNames.push_back("disable_workgroup_init");
+    // enableToggleNames.push_back("disable_robustness");
+    // enableToggleNames.push_back("emit_hlsl_debug_symbols");
+    wgpu::DawnTogglesDescriptor toggles = {};
+    toggles.enabledToggles = enableToggleNames.data();
+    toggles.enabledToggleCount = enableToggleNames.size();
+    adapter_options.nextInChain = &toggles;
+
+    // Synchronously create the adapter
+    wgpu::Adapter w_adapter;
+    webgpu_instance_.WaitAny(
+        webgpu_instance_.RequestAdapter(
+            &adapter_options, wgpu::CallbackMode::WaitAnyOnly,
+            [&](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message) {
+              if (status != wgpu::RequestAdapterStatus::Success) {
+                return;
+              }
+              w_adapter = std::move(adapter);
+            }),
+        UINT64_MAX);
+    if (w_adapter == nullptr) {
+      return false;
+    }
+
+    wgpu::AdapterInfo info;
+    w_adapter.GetInfo(&info);
+
+    // Create device from adapter
+    wgpu::DeviceDescriptor device_desc = {};
+    device_desc.SetDeviceLostCallback(
+        wgpu::CallbackMode::AllowSpontaneous,
+        [](const wgpu::Device&, wgpu::DeviceLostReason reason, wgpu::StringView message) {
+          const char* reasonName = "";
+          switch (reason) {
+            case wgpu::DeviceLostReason::Unknown:
+              reasonName = "Unknown";
+              break;
+            case wgpu::DeviceLostReason::Destroyed:
+              reasonName = "Destroyed";
+              break;
+            case wgpu::DeviceLostReason::FailedCreation:
+              reasonName = "FailedCreation";
+              break;
+            default:
+              // DAWN_UNREACHABLE();
+              break;
+          }
+          std::cerr << "Device lost because of " << reasonName << ": " << message.data;
+        });
+    device_desc.SetUncapturedErrorCallback(
+        [](const wgpu::Device&, wgpu::ErrorType type, wgpu::StringView message) {
+          const char* errorTypeName = "";
+          switch (type) {
+            case wgpu::ErrorType::Validation:
+              errorTypeName = "Validation";
+              break;
+            case wgpu::ErrorType::OutOfMemory:
+              errorTypeName = "Out of memory";
+              break;
+            case wgpu::ErrorType::Unknown:
+              errorTypeName = "Unknown";
+              break;
+            default:
+              // DAWN_UNREACHABLE();
+              break;
+          }
+          std::cerr << errorTypeName << " error: " << message.data;
+        });
+
+    // Configure device toggles
+    std::vector<const char*> enabledDeviceToggles = {
+        "skip_validation",  // only use "skip_validation" when ValidationMode is set to "Disabled"
+        "disable_robustness",
+        "d3d_disable_ieee_strictness"};
+
+    std::vector<const char*> disabledDeviceToggles = {
+        "lazy_clear_resource_on_first_use",
+        "timestamp_quantization"};
+
+    wgpu::DawnTogglesDescriptor deviceToggles = {};
+    deviceToggles.enabledToggles = enabledDeviceToggles.data();
+    deviceToggles.enabledToggleCount = enabledDeviceToggles.size();
+    deviceToggles.disabledToggles = disabledDeviceToggles.data();
+    deviceToggles.disabledToggleCount = disabledDeviceToggles.size();
+    device_desc.nextInChain = &deviceToggles;
+
+    std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(w_adapter);
+    if (required_features.size() > 0) {
+      device_desc.requiredFeatures = required_features.data();
+      device_desc.requiredFeatureCount = required_features.size();
+    }
+
+    wgpu::Limits required_limits = GetRequiredLimits(w_adapter);
+    device_desc.requiredLimits = &required_limits;
+
+    // Synchronously create the device
+    webgpu_instance_.WaitAny(
+        w_adapter.RequestDevice(
+            &device_desc, wgpu::CallbackMode::WaitAnyOnly,
+            [&](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
+              if (status != wgpu::RequestDeviceStatus::Success) {
+                return;
+              }
+              webgpu_device_ = std::move(device);
+              webgpu_queue_ = webgpu_device_.GetQueue();
+            }),
+        UINT64_MAX);
+    if (webgpu_device_ == nullptr) {
+      return false;
+    }
+
+    // Route Dawn's internal log messages to stderr
+    webgpu_device_.SetLoggingCallback([](wgpu::LoggingType type, wgpu::StringView message) {
+      std::cerr << message.data;
+    });
+
+    std::cout << "WebGPU device initialized successfully!" << std::endl;
+    return true;
+  }
+#endif
+};
 }  // namespace WebGPU
 
+std::unique_ptr<DeviceInterface> g_webgpu_device;
+
+void InitWebGPUInterface() {
+  if (!g_webgpu_device)
+    g_webgpu_device = std::make_unique<WebGPU::InterfaceImpl>();
+}
+
+void SetWebGPUProvider(OrtSessionOptions& session_options, const std::unordered_map<std::string, std::string>& provider_options) {
+  std::vector<const char*> keys, values;
+  std::vector<std::string> key_strings, value_strings;
+
+  // Add all existing provider options
+  for (const auto& option : provider_options) {
+    key_strings.push_back(option.first);
+    value_strings.push_back(option.second);
+  }
+
+#ifdef USE_WEBGPU
+  // USE_WEBGPU=ON: Pass our own Dawn objects to WebGPU EP
+  auto* webgpu_interface = static_cast<WebGPU::InterfaceImpl*>(g_webgpu_device.get());
+  if (!webgpu_interface) {
+    throw std::runtime_error("WebGPU interface not initialized");
+  }
+
+  key_strings.push_back("dawnProcTable");
+  value_strings.push_back(std::to_string(reinterpret_cast<size_t>(&dawn::native::GetProcs())));
+
+  key_strings.push_back("webgpuInstance");
+  value_strings.push_back(std::to_string(reinterpret_cast<size_t>(webgpu_interface->GetInstance().Get())));
+
+  key_strings.push_back("webgpuDevice");
+  value_strings.push_back(std::to_string(reinterpret_cast<size_t>(webgpu_interface->GetDevice().Get())));
+
+  key_strings.push_back("deviceId");
+  value_strings.push_back("1");
+#else
+  // USE_WEBGPU=OFF: Let WebGPU EP initialize its own Dawn
+  // We'll retrieve it later in InitOrt() via the exposed APIs
+  std::cout << "Using WebGPU EP's internal Dawn (shared mode)" << std::endl;
+#endif
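+
+  // The pointer-valued options above travel as decimal strings; the EP side
+  // is assumed to decode them along the lines of:
+  //   auto device = reinterpret_cast<WGPUDevice>(std::stoull(value));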
+ +#include "webgpu_cast_kernel.h" + +#include +#include + +namespace Generators { +namespace WebGPU { + +namespace { + +const char* kInt32ToInt64Shader = R"( +struct Constants { + element_count: u32, +} + +@group(0) @binding(0) var input_data: array; +@group(0) @binding(1) var output_data: array; // i64 as two i32s +@group(0) @binding(2) var constants: Constants; + +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let index = global_id.x; + if (index >= constants.element_count) { + return; + } + + let input_val = input_data[index]; + let output_index = index * 2u; + + // Convert int32 to int64 by sign extension + output_data[output_index] = input_val; // Low 32 bits + output_data[output_index + 1u] = select(0i, -1i, input_val < 0); // High 32 bits (sign extension) +} +)"; + +const char* kFloat16ToFloat32Shader = R"( +enable f16; + +struct Constants { + element_count: u32, +} + +@group(0) @binding(0) var input_data: array; +@group(0) @binding(1) var output_data: array; +@group(0) @binding(2) var constants: Constants; + +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let index = global_id.x; + if (index >= constants.element_count) { + return; + } + + // Convert f16 to f32 directly + output_data[index] = f32(input_data[index]); +} +)"; + +} // namespace + +void CastKernel::CreateInt32ToInt64Pipeline(wgpu::Device device) { + // Create shader module + wgpu::ShaderModuleWGSLDescriptor wgsl_desc{}; + wgsl_desc.code = kInt32ToInt64Shader; + + wgpu::ShaderModuleDescriptor shader_desc{}; + shader_desc.nextInChain = &wgsl_desc; + shader_desc.label = "Int32ToInt64 Cast Shader"; + + auto shader_module = device.CreateShaderModule(&shader_desc); + + // Create compute pipeline + wgpu::ComputePipelineDescriptor pipeline_desc{}; + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; + pipeline_desc.label = "Int32ToInt64 Cast Pipeline"; + + int32_to_int64_pipeline_ = device.CreateComputePipeline(&pipeline_desc); + + // Create constants buffer + wgpu::BufferDescriptor constants_desc{}; + constants_desc.size = 16; // Align to 16 bytes for uniform buffer + constants_desc.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst; + constants_desc.label = "Cast Constants Buffer"; + constants_buffer_ = device.CreateBuffer(&constants_desc); +} + +void CastKernel::CreateFloat16ToFloat32Pipeline(wgpu::Device device) { + // Create shader module + wgpu::ShaderModuleWGSLDescriptor wgsl_desc{}; + wgsl_desc.code = kFloat16ToFloat32Shader; + + wgpu::ShaderModuleDescriptor shader_desc{}; + shader_desc.nextInChain = &wgsl_desc; + shader_desc.label = "Float16ToFloat32 Cast Shader"; + + auto shader_module = device.CreateShaderModule(&shader_desc); + + // Create compute pipeline + wgpu::ComputePipelineDescriptor pipeline_desc{}; + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; + pipeline_desc.label = "Float16ToFloat32 Cast Pipeline"; + + float16_to_float32_pipeline_ = device.CreateComputePipeline(&pipeline_desc); +} + +bool CastKernel::CastInt32ToInt64(wgpu::Device device, wgpu::Queue queue, void* input_data, void* output_data, size_t element_count) { + try { + // Ensure pipeline exists for this device + if (!int32_to_int64_pipeline_ || int32_to_int64_pipeline_.Get() == nullptr) { + CreateInt32ToInt64Pipeline(device); + } + WGPUBuffer input_raw = reinterpret_cast(input_data); + WGPUBuffer output_raw = reinterpret_cast(output_data); + wgpu::Buffer 
+
+const char* kFloat16ToFloat32Shader = R"(
+enable f16;
+
+struct Constants {
+  element_count: u32,
+}
+
+@group(0) @binding(0) var<storage, read> input_data: array<f16>;
+@group(0) @binding(1) var<storage, read_write> output_data: array<f32>;
+@group(0) @binding(2) var<uniform> constants: Constants;
+
+@compute @workgroup_size(256)
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+  let index = global_id.x;
+  if (index >= constants.element_count) {
+    return;
+  }
+
+  // Convert f16 to f32 directly
+  output_data[index] = f32(input_data[index]);
+}
+)";
+
+}  // namespace
+
+void CastKernel::CreateInt32ToInt64Pipeline(wgpu::Device device) {
+  // Create shader module
+  wgpu::ShaderModuleWGSLDescriptor wgsl_desc{};
+  wgsl_desc.code = kInt32ToInt64Shader;
+
+  wgpu::ShaderModuleDescriptor shader_desc{};
+  shader_desc.nextInChain = &wgsl_desc;
+  shader_desc.label = "Int32ToInt64 Cast Shader";
+
+  auto shader_module = device.CreateShaderModule(&shader_desc);
+
+  // Create compute pipeline
+  wgpu::ComputePipelineDescriptor pipeline_desc{};
+  pipeline_desc.compute.module = shader_module;
+  pipeline_desc.compute.entryPoint = "main";
+  pipeline_desc.label = "Int32ToInt64 Cast Pipeline";
+
+  int32_to_int64_pipeline_ = device.CreateComputePipeline(&pipeline_desc);
+
+  // Create constants buffer
+  wgpu::BufferDescriptor constants_desc{};
+  constants_desc.size = 16;  // Align to 16 bytes for uniform buffer
+  constants_desc.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst;
+  constants_desc.label = "Cast Constants Buffer";
+  constants_buffer_ = device.CreateBuffer(&constants_desc);
+}
+
+void CastKernel::CreateFloat16ToFloat32Pipeline(wgpu::Device device) {
+  // Create shader module
+  wgpu::ShaderModuleWGSLDescriptor wgsl_desc{};
+  wgsl_desc.code = kFloat16ToFloat32Shader;
+
+  wgpu::ShaderModuleDescriptor shader_desc{};
+  shader_desc.nextInChain = &wgsl_desc;
+  shader_desc.label = "Float16ToFloat32 Cast Shader";
+
+  auto shader_module = device.CreateShaderModule(&shader_desc);
+
+  // Create compute pipeline
+  wgpu::ComputePipelineDescriptor pipeline_desc{};
+  pipeline_desc.compute.module = shader_module;
+  pipeline_desc.compute.entryPoint = "main";
+  pipeline_desc.label = "Float16ToFloat32 Cast Pipeline";
+
+  float16_to_float32_pipeline_ = device.CreateComputePipeline(&pipeline_desc);
+}
+
+bool CastKernel::CastInt32ToInt64(wgpu::Device device, wgpu::Queue queue, void* input_data, void* output_data, size_t element_count) {
+  try {
+    // Ensure pipeline exists for this device
+    if (!int32_to_int64_pipeline_ || int32_to_int64_pipeline_.Get() == nullptr) {
+      CreateInt32ToInt64Pipeline(device);
+    }
+    WGPUBuffer input_raw = reinterpret_cast<WGPUBuffer>(input_data);
+    WGPUBuffer output_raw = reinterpret_cast<WGPUBuffer>(output_data);
+    wgpu::Buffer input_buffer(input_raw);
+    wgpu::Buffer output_buffer(output_raw);
+
+    // Update constants
+    uint32_t constants_data = static_cast<uint32_t>(element_count);
+    queue.WriteBuffer(constants_buffer_, 0, &constants_data, sizeof(constants_data));
+
+    // Create bind group using the pipeline's bind group layout
+    // Note: We don't cache this because the buffers change each call
+    wgpu::BindGroup bind_group;
+    {
+      // Get bind group layout from the pipeline (required for default pipeline layout)
+      auto bind_group_layout = int32_to_int64_pipeline_.GetBindGroupLayout(0);
+
+      // Create bind group
+      std::vector<wgpu::BindGroupEntry> bind_entries(3);
+
+      bind_entries[0].binding = 0;
+      bind_entries[0].buffer = input_buffer;
+      bind_entries[0].size = wgpu::kWholeSize;
+
+      bind_entries[1].binding = 1;
+      bind_entries[1].buffer = output_buffer;
+      bind_entries[1].size = wgpu::kWholeSize;
+
+      bind_entries[2].binding = 2;
+      bind_entries[2].buffer = constants_buffer_;
+      bind_entries[2].size = sizeof(uint32_t);
+
+      wgpu::BindGroupDescriptor bind_group_desc{};
+      bind_group_desc.layout = bind_group_layout;
+      bind_group_desc.entryCount = bind_entries.size();
+      bind_group_desc.entries = bind_entries.data();
+      bind_group = device.CreateBindGroup(&bind_group_desc);
+    }
+
+    // Dispatch compute
+    auto encoder = device.CreateCommandEncoder();
+    auto compute_pass = encoder.BeginComputePass();
+
+    compute_pass.SetPipeline(int32_to_int64_pipeline_);
+    compute_pass.SetBindGroup(0, bind_group);
+
+    uint32_t workgroups = (static_cast<uint32_t>(element_count) + 255) / 256;
+    compute_pass.DispatchWorkgroups(workgroups);
+    compute_pass.End();
+
+    auto command_buffer = encoder.Finish();
+    queue.Submit(1, &command_buffer);
+
+    return true;
+  } catch (const std::exception& e) {
+    std::cerr << "CastKernel::CastInt32ToInt64 error: " << e.what() << std::endl;
+    return false;
+  }
+}
+
+bool CastKernel::CastFloat16ToFloat32(wgpu::Device device, wgpu::Queue queue, void* input_data, void* output_data, size_t element_count) {
+  try {
+    // Ensure pipeline exists for this device
+    if (!float16_to_float32_pipeline_ || float16_to_float32_pipeline_.Get() == nullptr) {
+      CreateFloat16ToFloat32Pipeline(device);
+    }
+    WGPUBuffer input_raw = reinterpret_cast<WGPUBuffer>(input_data);
+    WGPUBuffer output_raw = reinterpret_cast<WGPUBuffer>(output_data);
+    wgpu::Buffer input_buffer(input_raw);
+    wgpu::Buffer output_buffer(output_raw);
+
+    // Update constants
+    uint32_t constants_data = static_cast<uint32_t>(element_count);
+    queue.WriteBuffer(constants_buffer_, 0, &constants_data, sizeof(constants_data));
+
+    // Create bind group using the pipeline's bind group layout
+    // Note: We don't cache this because the buffers change each call
+    wgpu::BindGroup bind_group;
+    {
+      // Get bind group layout from the pipeline (required for default pipeline layout)
+      auto bind_group_layout = float16_to_float32_pipeline_.GetBindGroupLayout(0);
+
+      // Create bind group
+      std::vector<wgpu::BindGroupEntry> bind_entries(3);
+
+      bind_entries[0].binding = 0;
+      bind_entries[0].buffer = input_buffer;
+      bind_entries[0].size = wgpu::kWholeSize;
+
+      bind_entries[1].binding = 1;
+      bind_entries[1].buffer = output_buffer;
+      bind_entries[1].size = wgpu::kWholeSize;
+
+      bind_entries[2].binding = 2;
+      bind_entries[2].buffer = constants_buffer_;
+      bind_entries[2].size = sizeof(uint32_t);
+
+      wgpu::BindGroupDescriptor bind_group_desc{};
+      bind_group_desc.layout = bind_group_layout;
+      bind_group_desc.entryCount = bind_entries.size();
+      bind_group_desc.entries = bind_entries.data();
+      bind_group = device.CreateBindGroup(&bind_group_desc);
+    }
+
+    // Dispatch compute
+    auto encoder = device.CreateCommandEncoder();
+    auto compute_pass = encoder.BeginComputePass();
+
+    compute_pass.SetPipeline(float16_to_float32_pipeline_);
+    compute_pass.SetBindGroup(0, bind_group);
+
+    uint32_t workgroups = (static_cast<uint32_t>(element_count) + 255) / 256;
+    compute_pass.DispatchWorkgroups(workgroups);
+    compute_pass.End();
+
+    auto command_buffer = encoder.Finish();
+    queue.Submit(1, &command_buffer);
+
+    return true;
+  } catch (const std::exception& e) {
+    std::cerr << "CastKernel::CastFloat16ToFloat32 error: " << e.what() << std::endl;
+    return false;
+  }
+}
+
+}  // namespace WebGPU
+}  // namespace Generators
diff --git a/src/webgpu/webgpu_cast_kernel.h b/src/webgpu/webgpu_cast_kernel.h
new file mode 100644
index 0000000000..20e7611011
--- /dev/null
+++ b/src/webgpu/webgpu_cast_kernel.h
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include <cstddef>
+#include <webgpu/webgpu_cpp.h>
+
+namespace Generators {
+namespace WebGPU {
+
+// WebGPU kernel for casting between different data types
+// Available in both USE_WEBGPU=ON and OFF modes
+class CastKernel {
+ public:
+  CastKernel() = default;
+  ~CastKernel() = default;
+
+  // Cast int32 to int64 (most common case for position ids)
+  bool CastInt32ToInt64(wgpu::Device device, wgpu::Queue queue, void* input_data, void* output_data, size_t element_count);
+
+  // Cast float16 to float32
+  bool CastFloat16ToFloat32(wgpu::Device device, wgpu::Queue queue, void* input_data, void* output_data, size_t element_count);
+
+ private:
+  // Cached resources for int32 to int64 conversion
+  wgpu::ComputePipeline int32_to_int64_pipeline_;
+
+  // Cached resources for float16 to float32 conversion
+  wgpu::ComputePipeline float16_to_float32_pipeline_;
+
+  wgpu::Buffer constants_buffer_;
+
+  void CreateInt32ToInt64Pipeline(wgpu::Device device);
+  void CreateFloat16ToFloat32Pipeline(wgpu::Device device);
+};
+
+}  // namespace WebGPU
+}  // namespace Generators
diff --git a/src/webgpu/webgpu_update_mask_kernel.cpp b/src/webgpu/webgpu_update_mask_kernel.cpp
new file mode 100644
index 0000000000..c9aeabbf07
--- /dev/null
+++ b/src/webgpu/webgpu_update_mask_kernel.cpp
@@ -0,0 +1,225 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#include "webgpu_update_mask_kernel.h" + +#include +#include +#include +#include +#include +#include + +namespace { + +// WGSL compute shader for updating attention masks (static mask handling only) +const char* kUpdateMaskShaderI64 = R"( +struct Constants { + batch_beam_size: u32, + new_kv_length: u32, + total_length: u32, + max_length: u32, +} + +@group(0) @binding(0) var mask: array>; +@group(0) @binding(1) var constants: Constants; + +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let index = global_id.x; + let total_elements = constants.batch_beam_size * constants.new_kv_length; + + if (index >= total_elements) { + return; + } + + let batch_id = index / constants.new_kv_length; + let seq_id = (index % constants.new_kv_length) + 1u; + + // Update mask_data[batch_id * max_length + total_length - seq_id] = 1 + let mask_index = batch_id * constants.max_length + constants.total_length - seq_id; + mask[mask_index] = vec2(1u, 0u); +} +)"; + +const char* kUpdateMaskShaderI32 = R"( +struct Constants { + batch_beam_size: u32, + new_kv_length: u32, + total_length: u32, + max_length: u32, +} + +@group(0) @binding(0) var mask: array; +@group(0) @binding(1) var constants: Constants; + +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let index = global_id.x; + let total_elements = constants.batch_beam_size * constants.new_kv_length; + + if (index >= total_elements) { + return; + } + + let batch_id = index / constants.new_kv_length; + let seq_id = (index % constants.new_kv_length) + 1u; + + // Update mask_data[batch_id * max_length + total_length - seq_id] = 1 + let mask_index = batch_id * constants.max_length + constants.total_length - seq_id; + mask[mask_index] = i32(1); +} +)"; + +std::string GetShaderForType(bool is_int64) { + if (is_int64) { + return kUpdateMaskShaderI64; + } + return kUpdateMaskShaderI32; +} + +} // namespace + +template +void WebGPUUpdateMaskKernel::InitializePipeline(wgpu::Device device) { + if (initialized_) { + return; + } + + std::string shader_source = GetShaderForType(std::is_same_v); + + wgpu::ShaderModuleWGSLDescriptor wgsl_desc; + wgsl_desc.code = shader_source.c_str(); + + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &wgsl_desc; + shader_desc.label = "UpdateMaskShader"; + + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + // Create bind group layout for static shader: mask buffer (read_write) + constants buffer (uniform) + wgpu::BindGroupLayoutEntry bind_group_layout_entries[2]; + + // mask buffer (read_write) + bind_group_layout_entries[0].binding = 0; + bind_group_layout_entries[0].visibility = wgpu::ShaderStage::Compute; + bind_group_layout_entries[0].buffer.type = wgpu::BufferBindingType::Storage; + bind_group_layout_entries[0].buffer.hasDynamicOffset = false; + + // constants buffer (uniform) + bind_group_layout_entries[1].binding = 1; + bind_group_layout_entries[1].visibility = wgpu::ShaderStage::Compute; + bind_group_layout_entries[1].buffer.type = wgpu::BufferBindingType::Uniform; + bind_group_layout_entries[1].buffer.hasDynamicOffset = false; + + wgpu::BindGroupLayoutDescriptor bind_group_layout_desc; + bind_group_layout_desc.entryCount = 2; + bind_group_layout_desc.entries = bind_group_layout_entries; + + wgpu::BindGroupLayout bind_group_layout = device.CreateBindGroupLayout(&bind_group_layout_desc); + + // Create pipeline layout + wgpu::PipelineLayoutDescriptor pipeline_layout_desc; + 
+template <typename T>
+void WebGPUUpdateMaskKernel<T>::InitializePipeline(wgpu::Device device) {
+  if (initialized_) {
+    return;
+  }
+
+  std::string shader_source = GetShaderForType(std::is_same_v<T, int64_t>);
+
+  wgpu::ShaderModuleWGSLDescriptor wgsl_desc;
+  wgsl_desc.code = shader_source.c_str();
+
+  wgpu::ShaderModuleDescriptor shader_desc;
+  shader_desc.nextInChain = &wgsl_desc;
+  shader_desc.label = "UpdateMaskShader";
+
+  wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
+
+  // Create bind group layout for the static shader: mask buffer (read_write) + constants buffer (uniform)
+  wgpu::BindGroupLayoutEntry bind_group_layout_entries[2];
+
+  // mask buffer (read_write)
+  bind_group_layout_entries[0].binding = 0;
+  bind_group_layout_entries[0].visibility = wgpu::ShaderStage::Compute;
+  bind_group_layout_entries[0].buffer.type = wgpu::BufferBindingType::Storage;
+  bind_group_layout_entries[0].buffer.hasDynamicOffset = false;
+
+  // constants buffer (uniform)
+  bind_group_layout_entries[1].binding = 1;
+  bind_group_layout_entries[1].visibility = wgpu::ShaderStage::Compute;
+  bind_group_layout_entries[1].buffer.type = wgpu::BufferBindingType::Uniform;
+  bind_group_layout_entries[1].buffer.hasDynamicOffset = false;
+
+  wgpu::BindGroupLayoutDescriptor bind_group_layout_desc;
+  bind_group_layout_desc.entryCount = 2;
+  bind_group_layout_desc.entries = bind_group_layout_entries;
+
+  wgpu::BindGroupLayout bind_group_layout = device.CreateBindGroupLayout(&bind_group_layout_desc);
+
+  // Create pipeline layout
+  wgpu::PipelineLayoutDescriptor pipeline_layout_desc;
+  pipeline_layout_desc.bindGroupLayoutCount = 1;
+  pipeline_layout_desc.bindGroupLayouts = &bind_group_layout;
+
+  wgpu::PipelineLayout pipeline_layout = device.CreatePipelineLayout(&pipeline_layout_desc);
+
+  // Create compute pipeline
+  wgpu::ComputePipelineDescriptor pipeline_desc;
+  pipeline_desc.layout = pipeline_layout;
+  pipeline_desc.compute.module = shader_module;
+  pipeline_desc.compute.entryPoint = "main";
+
+  pipeline_ = device.CreateComputePipeline(&pipeline_desc);
+
+  // Create and cache the constants buffer
+  wgpu::BufferDescriptor constants_buffer_desc;
+  constants_buffer_desc.size = sizeof(Constants);
+  constants_buffer_desc.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst;
+  constants_buffer_ = device.CreateBuffer(&constants_buffer_desc);
+
+  initialized_ = true;
+}
+
+template <typename T>
+void WebGPUUpdateMaskKernel<T>::UpdateMask(
+    wgpu::Device device,
+    wgpu::Queue queue,
+    T* next_mask_data,
+    T* mask_data,
+    int batch_beam_size,
+    int new_kv_length,
+    int total_length,
+    int max_length,
+    bool update_only) {
+  // Only handle static mask updates (update_only = true)
+  if (!update_only) {
+    throw std::runtime_error("Dynamic mask handling not supported in WebGPU implementation");
+  }
+
+  // Initialize the pipeline and cached resources
+  InitializePipeline(device);
+
+  // Fill in the constants and upload them to the cached buffer
+  Constants constants;
+  constants.batch_beam_size = static_cast<uint32_t>(batch_beam_size);
+  constants.new_kv_length = static_cast<uint32_t>(new_kv_length);
+  constants.total_length = static_cast<uint32_t>(total_length);
+  constants.max_length = static_cast<uint32_t>(max_length);
+
+  queue.WriteBuffer(constants_buffer_, 0, &constants, sizeof(Constants));
+
+  // Create the cached bind group if not already created
+  if (!bind_group_initialized_) {
+    // Get the existing mask buffer (already on GPU).
+    // Wrap the existing WGPUBuffer handle in a wgpu::Buffer without taking ownership.
+    WGPUBuffer raw_mask_buffer = reinterpret_cast<WGPUBuffer>(mask_data);
+    wgpu::Buffer mask_buffer = wgpu::Buffer(raw_mask_buffer);
+
+    // Create a bind group for this specific mask buffer and cache it
+    wgpu::BindGroupEntry bind_group_entries[2];
+    bind_group_entries[0].binding = 0;
+    bind_group_entries[0].buffer = mask_buffer;
+    bind_group_entries[0].size = mask_buffer.GetSize();
+
+    bind_group_entries[1].binding = 1;
+    bind_group_entries[1].buffer = constants_buffer_;
+    bind_group_entries[1].size = sizeof(Constants);
+
+    wgpu::BindGroupDescriptor bind_group_desc;
+    bind_group_desc.layout = pipeline_.GetBindGroupLayout(0);
+    bind_group_desc.entryCount = 2;
+    bind_group_desc.entries = bind_group_entries;
+
+    bind_group_ = device.CreateBindGroup(&bind_group_desc);
+    bind_group_initialized_ = true;
+  }
+
+  uint32_t workgroup_count = (batch_beam_size * new_kv_length + 255) / 256;
+
+  // Create command encoder and dispatch
+  wgpu::CommandEncoderDescriptor encoder_desc;
+  wgpu::CommandEncoder encoder = device.CreateCommandEncoder(&encoder_desc);
+
+  wgpu::ComputePassDescriptor compute_pass_desc;
+  wgpu::ComputePassEncoder compute_pass = encoder.BeginComputePass(&compute_pass_desc);
+
+  compute_pass.SetPipeline(pipeline_);
+  compute_pass.SetBindGroup(0, bind_group_);
+  compute_pass.DispatchWorkgroups(workgroup_count);
+  compute_pass.End();
+
+  wgpu::CommandBuffer command_buffer = encoder.Finish();
+  queue.Submit(1, &command_buffer);
+}
+
+template <typename T>
+std::string WebGPUUpdateMaskKernel<T>::GetShaderSource() {
+  return GetShaderForType(std::is_same_v<T, int64_t>);
+}
+
+// Explicit template instantiations
+template class WebGPUUpdateMaskKernel<int32_t>;
+template class WebGPUUpdateMaskKernel<int64_t>;
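To make the index arithmetic concrete, one dispatch is equivalent to this CPU loop (a reference sketch, not part of the patch). For example, with batch_beam_size = 2, new_kv_length = 1, total_length = 8, max_length = 64, it sets mask_data[7] and mask_data[71].

    // CPU reference for the shader's static-mask update:
    //   mask[batch_id * max_length + total_length - seq_id] = 1
    for (int batch_id = 0; batch_id < batch_beam_size; ++batch_id) {
      for (int seq_id = 1; seq_id <= new_kv_length; ++seq_id) {
        mask_data[batch_id * max_length + total_length - seq_id] = 1;
      }
    }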
diff --git a/src/webgpu/webgpu_update_mask_kernel.h b/src/webgpu/webgpu_update_mask_kernel.h
new file mode 100644
index 0000000000..0b440e72ef
--- /dev/null
+++ b/src/webgpu/webgpu_update_mask_kernel.h
@@ -0,0 +1,47 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include <webgpu/webgpu_cpp.h>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include "../models/onnxruntime_api.h"
+
+// WebGPU kernel for updating attention masks efficiently on GPU.
+// Available in both USE_WEBGPU=ON and OFF modes.
+template <typename T>
+class WebGPUUpdateMaskKernel {
+ public:
+  WebGPUUpdateMaskKernel() = default;
+  ~WebGPUUpdateMaskKernel() = default;
+
+  // Update attention mask using a WebGPU compute shader
+  void UpdateMask(
+      wgpu::Device device,
+      wgpu::Queue queue,
+      T* next_mask_data,
+      T* mask_data,
+      int batch_beam_size,
+      int new_kv_length,
+      int total_length,
+      int max_length,
+      bool update_only);
+
+ private:
+  struct Constants {
+    uint32_t batch_beam_size;
+    uint32_t new_kv_length;
+    uint32_t total_length;
+    uint32_t max_length;
+  };
+
+  wgpu::ComputePipeline pipeline_;
+  wgpu::Buffer constants_buffer_;
+  wgpu::BindGroup bind_group_;
+  bool initialized_ = false;
+  bool bind_group_initialized_ = false;
+
+  void InitializePipeline(wgpu::Device device);
+  std::string GetShaderSource();
+};
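A sketch of the expected call pattern for this kernel. The device/queue handles and the buffer pointer are assumptions, and note that next_mask_data is never referenced on the static update_only path shown in the implementation, so passing nullptr here is an illustrative guess:

    WebGPUUpdateMaskKernel<int64_t> mask_kernel;
    mask_kernel.UpdateMask(device, queue,
                           /*next_mask_data=*/nullptr,  // unused by the static path
                           /*mask_data=*/reinterpret_cast<int64_t*>(mask_wgpu_buffer),
                           /*batch_beam_size=*/1, /*new_kv_length=*/1,
                           /*total_length=*/total_length, /*max_length=*/max_length,
                           /*update_only=*/true);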
diff --git a/src/webgpu/webgpu_update_position_ids_kernel.cpp b/src/webgpu/webgpu_update_position_ids_kernel.cpp
new file mode 100644
index 0000000000..21ca277406
--- /dev/null
+++ b/src/webgpu/webgpu_update_position_ids_kernel.cpp
@@ -0,0 +1,204 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "webgpu_update_position_ids_kernel.h"
+
+#include <cstdint>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+
+namespace {
+
+// WGSL compute shader for updating position IDs (continuous decoding only)
+const char* kUpdatePositionIdsShaderI64 = R"(
+struct Constants {
+  total_length: u32,
+  new_kv_length: u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> positions: array<vec2<u32>>;
+@group(0) @binding(1) var<uniform> constants: Constants;
+
+@compute @workgroup_size(256)
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+  let index = global_id.x;
+  if (index >= constants.new_kv_length) {
+    return;
+  }
+
+  // Calculate position IDs for continuous decoding: positions[i] = i + total_length - new_kv_length
+  let pos_value = index + constants.total_length - constants.new_kv_length;
+  positions[index] = vec2<u32>(pos_value, select(0u, 0xFFFFFFFFu, i32(pos_value) < 0));
+}
+)";
+
+const char* kUpdatePositionIdsShaderI32 = R"(
+struct Constants {
+  total_length: u32,
+  new_kv_length: u32,
+}
+
+@group(0) @binding(0) var<storage, read_write> positions: array<i32>;
+@group(0) @binding(1) var<uniform> constants: Constants;
+
+@compute @workgroup_size(256)
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+  let index = global_id.x;
+  if (index >= constants.new_kv_length) {
+    return;
+  }
+
+  // Calculate position IDs for continuous decoding: positions[i] = i + total_length - new_kv_length
+  positions[index] = i32(index + constants.total_length - constants.new_kv_length);
+}
+)";
+
+std::string GetShaderForType(bool is_int64) {
+  if (is_int64) {
+    return kUpdatePositionIdsShaderI64;
+  }
+  return kUpdatePositionIdsShaderI32;
+}
+
+}  // namespace
+
+template <typename T>
+void WebGPUUpdatePositionIdsKernel<T>::InitializePipeline(wgpu::Device device) {
+  if (initialized_) {
+    return;
+  }
+
+  std::string shader_source = GetShaderForType(std::is_same_v<T, int64_t>);
+
+  wgpu::ShaderModuleWGSLDescriptor wgsl_desc;
+  wgsl_desc.code = shader_source.c_str();
+
+  wgpu::ShaderModuleDescriptor shader_desc;
+  shader_desc.nextInChain = &wgsl_desc;
+  shader_desc.label = "UpdatePositionIdsShader";
+
+  wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
+
+  // Create bind group layout: positions buffer (read_write) + constants buffer (uniform)
+  wgpu::BindGroupLayoutEntry bind_group_layout_entries[2];
+
+  // positions buffer (read_write)
+  bind_group_layout_entries[0].binding = 0;
+  bind_group_layout_entries[0].visibility = wgpu::ShaderStage::Compute;
+  bind_group_layout_entries[0].buffer.type = wgpu::BufferBindingType::Storage;
+  bind_group_layout_entries[0].buffer.hasDynamicOffset = false;
+
+  // constants buffer (uniform)
+  bind_group_layout_entries[1].binding = 1;
+  bind_group_layout_entries[1].visibility = wgpu::ShaderStage::Compute;
+  bind_group_layout_entries[1].buffer.type = wgpu::BufferBindingType::Uniform;
+  bind_group_layout_entries[1].buffer.hasDynamicOffset = false;
+
+  wgpu::BindGroupLayoutDescriptor bind_group_layout_desc;
+  bind_group_layout_desc.entryCount = 2;
+  bind_group_layout_desc.entries = bind_group_layout_entries;
+
+  wgpu::BindGroupLayout bind_group_layout = device.CreateBindGroupLayout(&bind_group_layout_desc);
+
+  // Create pipeline layout
+  wgpu::PipelineLayoutDescriptor pipeline_layout_desc;
+  pipeline_layout_desc.bindGroupLayoutCount = 1;
+  pipeline_layout_desc.bindGroupLayouts = &bind_group_layout;
+
+  wgpu::PipelineLayout pipeline_layout = device.CreatePipelineLayout(&pipeline_layout_desc);
+
+  // Create compute pipeline
+  wgpu::ComputePipelineDescriptor pipeline_desc;
+  pipeline_desc.layout = pipeline_layout;
+  pipeline_desc.compute.module = shader_module;
+  pipeline_desc.compute.entryPoint = "main";
+
+  pipeline_ = device.CreateComputePipeline(&pipeline_desc);
+
+  // Create and cache the constants buffer
+  wgpu::BufferDescriptor constants_buffer_desc;
+  constants_buffer_desc.size = sizeof(Constants);
+  constants_buffer_desc.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst;
+  constants_buffer_ = device.CreateBuffer(&constants_buffer_desc);
+
+  initialized_ = true;
+}
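Both kernels dispatch ceil(n / 256) workgroups to cover n elements at the fixed @workgroup_size(256); the shader's bounds check discards the excess invocations. The integer form used throughout the patch:

    // Ceiling division: e.g. n = 300 yields (300 + 255) / 256 = 2 workgroups
    // (512 invocations), and the shader's early-return guard skips the extra 212.
    uint32_t workgroup_count = (n + 255) / 256;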
+template <typename T>
+void WebGPUUpdatePositionIdsKernel<T>::UpdatePositionIds(
+    wgpu::Device device,
+    wgpu::Queue queue,
+    T* position_ids,
+    int batch_beam_size,
+    int total_length,
+    int new_kv_length) {
+  // Only support batch_beam_size == 1 for graph capture (continuous decoding)
+  if (batch_beam_size != 1) {
+    throw std::runtime_error("WebGPU UpdatePositionIds only supports batch_beam_size == 1");
+  }
+
+  // Initialize the pipeline and cached resources
+  InitializePipeline(device);
+
+  // Fill in the constants and upload them to the cached buffer
+  Constants constants;
+  constants.total_length = static_cast<uint32_t>(total_length);
+  constants.new_kv_length = static_cast<uint32_t>(new_kv_length);
+
+  queue.WriteBuffer(constants_buffer_, 0, &constants, sizeof(Constants));
+
+  // Create the cached bind group if not already created
+  if (!bind_group_initialized_) {
+    // Get the existing position IDs buffer (already on GPU).
+    // Wrap the existing WGPUBuffer handle in a wgpu::Buffer without taking ownership.
+    WGPUBuffer raw_positions_buffer = reinterpret_cast<WGPUBuffer>(position_ids);
+    wgpu::Buffer positions_buffer = wgpu::Buffer(raw_positions_buffer);
+
+    // Create a bind group for this specific positions buffer and cache it
+    wgpu::BindGroupEntry bind_group_entries[2];
+    bind_group_entries[0].binding = 0;
+    bind_group_entries[0].buffer = positions_buffer;
+    bind_group_entries[0].size = positions_buffer.GetSize();
+
+    bind_group_entries[1].binding = 1;
+    bind_group_entries[1].buffer = constants_buffer_;
+    bind_group_entries[1].size = sizeof(Constants);
+
+    wgpu::BindGroupDescriptor bind_group_desc;
+    bind_group_desc.layout = pipeline_.GetBindGroupLayout(0);
+    bind_group_desc.entryCount = 2;
+    bind_group_desc.entries = bind_group_entries;
+
+    bind_group_ = device.CreateBindGroup(&bind_group_desc);
+    bind_group_initialized_ = true;
+  }
+
+  uint32_t workgroup_count = (new_kv_length + 255) / 256;
+
+  // Create command encoder and dispatch
+  wgpu::CommandEncoderDescriptor encoder_desc;
+  wgpu::CommandEncoder encoder = device.CreateCommandEncoder(&encoder_desc);
+
+  wgpu::ComputePassDescriptor compute_pass_desc;
+  wgpu::ComputePassEncoder compute_pass = encoder.BeginComputePass(&compute_pass_desc);
+
+  compute_pass.SetPipeline(pipeline_);
+  compute_pass.SetBindGroup(0, bind_group_);
+  compute_pass.DispatchWorkgroups(workgroup_count);
+  compute_pass.End();
+
+  wgpu::CommandBuffer command_buffer = encoder.Finish();
+  queue.Submit(1, &command_buffer);
+}
+
+template <typename T>
+std::string WebGPUUpdatePositionIdsKernel<T>::GetShaderSource() {
+  return GetShaderForType(std::is_same_v<T, int64_t>);
+}
+
+// Explicit template instantiations
+template class WebGPUUpdatePositionIdsKernel<int32_t>;
+template class WebGPUUpdatePositionIdsKernel<int64_t>;
diff --git a/src/webgpu/webgpu_update_position_ids_kernel.h b/src/webgpu/webgpu_update_position_ids_kernel.h
new file mode 100644
index 0000000000..2cb05e5e9a
--- /dev/null
+++ b/src/webgpu/webgpu_update_position_ids_kernel.h
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include <webgpu/webgpu_cpp.h>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include "../models/onnxruntime_api.h"
+
+// WebGPU kernel for updating position IDs efficiently on GPU.
+// Available in both USE_WEBGPU=ON and OFF modes.
+template <typename T>
+class WebGPUUpdatePositionIdsKernel {
+ public:
+  WebGPUUpdatePositionIdsKernel() = default;
+  ~WebGPUUpdatePositionIdsKernel() = default;
+
+  // Update position IDs using a WebGPU compute shader (continuous decoding only)
+  void UpdatePositionIds(
+      wgpu::Device device,
+      wgpu::Queue queue,
+      T* position_ids,
+      int batch_beam_size,
+      int total_length,
+      int new_kv_length);
+
+ private:
+  struct Constants {
+    uint32_t total_length;
+    uint32_t new_kv_length;
+  };
+
+  wgpu::ComputePipeline pipeline_;
+  wgpu::Buffer constants_buffer_;
+  wgpu::BindGroup bind_group_;
+  bool initialized_ = false;
+  bool bind_group_initialized_ = false;
+
+  void InitializePipeline(wgpu::Device device);
+  std::string GetShaderSource();
+};
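Finally, a sketch of how the position-ids kernel would be invoked for one decode step (names are illustrative assumptions; the kernel itself enforces batch_beam_size == 1):

    WebGPUUpdatePositionIdsKernel<int64_t> pos_kernel;
    pos_kernel.UpdatePositionIds(device, queue,
                                 reinterpret_cast<int64_t*>(position_ids_wgpu_buffer),
                                 /*batch_beam_size=*/1,
                                 /*total_length=*/total_length,
                                 /*new_kv_length=*/1);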