Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions onnxruntime/core/providers/webgpu/tensor/expand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,38 @@ Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size");
if (Inputs()[0].var_type == ProgramVariableDataType::Boolx4) {
const auto& input_indices = shader.AddIndices("input_indices");
const auto& output_indices = shader.AddIndices("output_indices");
if (input_last_dim_divisible_by_4_) {
// The last dims of input shape and output shape are all divisible by 4.
shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n"
<< " let input_offset = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n"
<< output.SetByOffset("global_idx", input.GetByOffset("input_offset"));
} else if (output_last_dim_divisible_by_4_) {
// The last dim of output shape is divisible by 4, and the last dim of input shape is 1.
shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n"
<< " let input_offset = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n"
<< " let value = vec4<bool>(" << input.GetByOffset("input_offset / 4") << "[input_offset % 4]);\n"
<< " " << output.SetByOffset("global_idx", "value");
} else {
shader.MainFunctionBody() << " var output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n"
<< " let input_offset_0 = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n"
<< " output_indices = " << output_indices.OffsetToIndices("global_idx * 4 + 1") << ";\n"
<< " let input_offset_1 = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n"
<< " output_indices = " << output_indices.OffsetToIndices("global_idx * 4 + 2") << ";\n"
<< " let input_offset_2 = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n"
<< " output_indices = " << output_indices.OffsetToIndices("global_idx * 4 + 3") << ";\n"
<< " let input_offset_3 = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n"
<< " let value = vec4<bool>("
<< input.GetByOffset("input_offset_0 / 4") << "[input_offset_0 % 4], "
<< input.GetByOffset("input_offset_1 / 4") << "[input_offset_1 % 4], "
<< input.GetByOffset("input_offset_2 / 4") << "[input_offset_2 % 4], "
<< input.GetByOffset("input_offset_3 / 4") << "[input_offset_3 % 4]);\n"
<< output.SetByOffset("global_idx", "value");
}
return Status::OK();
}
if (input.NumComponents() != output.NumComponents()) {
const auto& output_indices = shader.AddIndices("output_indices");
shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n"
Expand All @@ -40,31 +72,41 @@ Status Expand::ComputeInternal(ComputeContext& context) const {
auto* output_tensor = context.Output(0, output_shape);

bool is_int64 = input_tensor->DataType() == DataTypeImpl::GetType<int64_t>();
const int components_i = (input_shape.IsScalar() || is_int64) ? 1 : input_shape[input_shape.NumDimensions() - 1] % 4 == 0 ? 4
: 1;
const int components_o = (output_shape.IsScalar() || is_int64) ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4
: 1;
uint32_t data_size = onnxruntime::narrow<uint32_t>(output_shape.Size() / components_o);
// Check if either input is boolean
// For boolean inputs, we need to handle them differently in the shader. This is because `bool` is not a valid type in
// storage buffer. We have to use a `u32` to represent 4 boolean values.
bool is_bool = input_tensor->DataType() == DataTypeImpl::GetType<bool>();
bool input_last_dim_divisible_by_4 = (!(input_shape.IsScalar() || is_int64)) && (input_shape[input_shape.NumDimensions() - 1] % 4 == 0);
bool output_last_dim_divisible_by_4 = (!(output_shape.IsScalar() || is_int64)) && (output_shape[output_shape.NumDimensions() - 1] % 4 == 0);
const int components_i = (is_bool || input_last_dim_divisible_by_4) ? 4 : 1;
const int components_o = (is_bool || output_last_dim_divisible_by_4) ? 4 : 1;
uint32_t data_size = onnxruntime::narrow<uint32_t>((output_shape.Size() + components_o - 1) / components_o);
if (data_size == 0) {
return Status::OK();
}
ExpandProgram program{};
program
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}})
.SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
ExpandProgram program{input_last_dim_divisible_by_4, output_last_dim_divisible_by_4};
program.SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({
{data_size},
});
if (components_i != components_o) {
if (is_bool) {
program.CacheHint(std::to_string(static_cast<int>(input_last_dim_divisible_by_4)), std::to_string(static_cast<int>(output_last_dim_divisible_by_4)))
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, components_i}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, {data_size}, components_o}})
.AddIndices(std::move(input_shape));
} else {
program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}});
}
if (is_bool || components_i != components_o) {
program.AddIndices(std::move(output_shape));
}
return context.RunProgram(program);
}

template <int StartVersion, int EndVersion>
KernelCreateInfo CreateExpandVersionedKernelInfo(bool enable_int64) {
const auto& type_constraints = GetOpTypeConstraints(enable_int64, false);
const auto& type_constraints = GetOpTypeConstraints(enable_int64, true);

KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
out = std::make_unique<Expand>(info);
Expand All @@ -85,7 +127,7 @@ KernelCreateInfo CreateExpandVersionedKernelInfo(bool enable_int64) {

template <int SinceVersion>
KernelCreateInfo CreateExpandKernelInfo(bool enable_int64) {
const auto& type_constraints = GetOpTypeConstraints(enable_int64, false);
const auto& type_constraints = GetOpTypeConstraints(enable_int64, true);

KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
out = std::make_unique<Expand>(info);
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/core/providers/webgpu/tensor/expand.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@ namespace webgpu {

class ExpandProgram final : public Program<ExpandProgram> {
public:
ExpandProgram() : Program{"Expand"} {}
ExpandProgram(const bool input_last_dim_divisible_by_4, const bool output_last_dim_divisible_by_4) : Program{"Expand"},
input_last_dim_divisible_by_4_{input_last_dim_divisible_by_4},
output_last_dim_divisible_by_4_{output_last_dim_divisible_by_4} {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32});

private:
bool input_last_dim_divisible_by_4_;
bool output_last_dim_divisible_by_4_;
};

class Expand final : public WebGpuKernel {
Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/core/providers/webgpu/tensor/flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.TypeConstraint("T", WebGpuSupportedNumberAndBoolTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

Expand All @@ -26,7 +26,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.TypeConstraint("T", WebGpuSupportedNumberAndBoolTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

Expand All @@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.TypeConstraint("T", WebGpuSupportedNumberAndBoolTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

Expand All @@ -48,7 +48,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.TypeConstraint("T", WebGpuSupportedNumberAndBoolTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

Expand All @@ -59,7 +59,7 @@ ONNX_OPERATOR_KERNEL_EX(
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.TypeConstraint("T", WebGpuSupportedNumberAndBoolTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

Expand Down
75 changes: 51 additions & 24 deletions onnxruntime/core/providers/webgpu/tensor/gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,51 @@ namespace onnxruntime {
namespace webgpu {

Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& data = shader.AddInput("data", ShaderUsage::UseIndicesTypeAlias);
const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias);

const auto& data_indices = shader.AddIndices("data_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& output_indices = shader.AddIndices("output_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
bool is_bool = Inputs()[0].var_type == ProgramVariableDataType::Boolx4;
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
<< " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
<< " var indices_indices = input_indices_indices_t(0);\n";
for (int i = 0; i < indices.Rank(); i++) {
shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n";
}
shader.MainFunctionBody() << " var idx = " << indices.GetByIndices("indices_indices") << ";\n"
<< " if (idx < 0) {\n"
<< " idx = idx + input_indices_value_t(" << data.IndicesGet("uniforms.data_shape", axis_) << ");\n"
<< " }\n"
<< " var data_indices : data_indices_t;\n";
for (int i = 0, j = 0; i < data.Rank(); i++) {
if (static_cast<uint32_t>(i) == axis_) {
shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n";
j += indices.Rank();
<< " var idx : input_indices_value_t;\n"
<< " var output_indices : output_indices_indices_t;\n"
<< " var indices_indices : input_indices_indices_t;\n"
<< " var data_indices : data_indices_indices_t;\n"
<< " var value : output_value_t;\n"
<< " var data_offset : u32;\n";
for (int comp = 0; comp < (is_bool ? 4 : 1); comp++) {
shader.MainFunctionBody() << " output_indices = " << output_indices.OffsetToIndices(is_bool ? (std::to_string(comp) + " + 4 * global_idx") : "global_idx") << ";\n";

for (int i = 0; i < indices.Rank(); i++) {
shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output_indices.IndicesGet("output_indices", axis_ + i)) << ";\n";
}

shader.MainFunctionBody() << " idx = " << indices.GetByIndices("indices_indices") << ";\n"
<< " if (idx < 0) {\n"
<< " idx = idx + input_indices_value_t(" << data_indices.IndicesGet("uniforms.data_indices_shape", axis_) << ");\n"
<< " }\n";

for (int i = 0, j = 0; i < data_indices.Rank(); i++) {
if (static_cast<uint32_t>(i) == axis_) {
shader.MainFunctionBody() << " " << data_indices.IndicesSet("data_indices", i, "u32(idx)") << ";\n";
j += indices.Rank();
} else {
shader.MainFunctionBody() << " " << data_indices.IndicesSet("data_indices", i, output_indices.IndicesGet("output_indices", j)) << ";\n";
j++;
}
}

shader.MainFunctionBody() << " data_offset = " << data_indices.IndicesToOffset("data_indices") << ";\n";
if (is_bool) {
shader.MainFunctionBody() << " value[" << comp << "] = " << data.GetByOffset("data_offset / 4") << "[data_offset % 4];\n";
} else {
shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n";
j++;
shader.MainFunctionBody() << " value = " << data.GetByOffset("data_offset") << ";\n";
}
}

shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", data.GetByIndices("data_indices"));
shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", "value");

return Status::OK();
}
Expand All @@ -47,14 +66,22 @@ Status Gather::ComputeInternal(ComputeContext& context) const {
return Status::OK();
}

bool is_bool = p.input_tensor->DataType() == DataTypeImpl::GetType<bool>();
if (is_bool) {
// Shader will pack four bools into one uint, so we consider the types of input and output as vec4<bool>.
data_size = (data_size + 3) / 4;
}

uint32_t axis = static_cast<uint32_t>(p.axis);
GatherProgram program{axis};
program
.AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank},
.AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, (is_bool ? 4 : 1)},
{p.indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
.AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank})
.AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank, {data_size}, (is_bool ? 4 : 1)})
.SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.CacheHint(std::to_string(axis))
.AddIndices(p.input_tensor->Shape())
.AddIndices(p.output_tensor->Shape())
.AddUniformVariables({{data_size}});
return context.RunProgram(program);
}
Expand All @@ -71,9 +98,9 @@ Status Gather::ComputeInternal(ComputeContext& context) const {
KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()), \
KERNEL_CLASS);

WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberTypes())
WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberTypes())
WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberTypes())
WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberAndBoolTypes())
WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberAndBoolTypes())
WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberAndBoolTypes())

} // namespace webgpu
} // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/tensor/unsqueeze.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace webgpu {

template <int StartVersion, int EndVersion>
KernelCreateInfo CreateUnsqueezeVersionedKernelInfo(bool enable_int64) {
const auto& type_constraints = GetOpTypeConstraints(enable_int64, false);
const auto& type_constraints = GetOpTypeConstraints(enable_int64, true);

KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
out = std::make_unique<Unsqueeze>(info);
Expand Down Expand Up @@ -45,7 +45,7 @@ KernelCreateInfo CreateUnsqueezeVersionedKernelInfo(bool enable_int64) {

template <int SinceVersion>
KernelCreateInfo CreateUnsqueezeKernelInfo(bool enable_int64) {
const auto& type_constraints = GetOpTypeConstraints(enable_int64, false);
const auto& type_constraints = GetOpTypeConstraints(enable_int64, true);

KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
out = std::make_unique<Unsqueeze>(info);
Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_supported_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ using SupportedFloats =
float,
MLFloat16>;

using SupportedNumberAndBoolTypes =
TypeList<
float,
MLFloat16,
int32_t,
uint32_t,
bool>;

inline const std::vector<MLDataType>& WebGpuSupportedNumberTypes() {
static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedNumberTypes>();
return supportedDataTypes;
Expand All @@ -30,6 +38,11 @@ inline const std::vector<MLDataType>& WebGpuSupportedFloatTypes() {
return supportedDataTypes;
}

inline const std::vector<MLDataType>& WebGpuSupportedNumberAndBoolTypes() {
static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedNumberAndBoolTypes>();
return supportedDataTypes;
}

inline const std::vector<MLDataType>& GetOpTypeConstraints(bool enable_int64 = false, bool enable_bool = false) {
static std::vector<MLDataType> base_types{
DataTypeImpl::GetTensorType<MLFloat16>(),
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/test/providers/cpu/nn/flatten_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ TEST_F(FlattenOpTest, Flatten_invalid_axis) {
test_.Run(OpTester::ExpectResult::kExpectFailure, "Invalid value(5) for attribute 'axis'");
}

TEST_F(FlattenOpTest, Flatten_axis2_bool) {
test_.AddAttribute<int64_t>("axis", 2L);
test_.AddInput<bool>("data", {2L, 2L, 2L, 3L}, {false, true, false, true, true, false, true, false, false, false, false, true, true, false, true, true, true, false, false, true, true, true, false, true});
test_.AddOutput<bool>("output", {4L, 6L}, {false, true, false, true, true, false, true, false, false, false, false, true, true, false, true, true, true, false, false, true, true, true, false, true});
test_.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

TEST_F(FlattenOpTest, Flatten_axis3) {
test_.AddAttribute<int64_t>("axis", 3L);
test_.AddInput<float>("data", {2L, 3L, 4L, 5L}, data0_);
Expand Down
Loading
Loading