Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/webgpu/math/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {

// normalize axis
size_t axis = static_cast<size_t>(HandleNegativeAxis(axis_, input_rank));
bool is_transpose_required = axis < input_rank - 1;
// For opsets older than version 13, the `axis` attribute describes the axis of the inputs when coerced to 2D;
// the 0th axis most likely describes the batch size, so no transpose is required for old opset versions.
bool is_transpose_required = axis < input_rank - 1 && opset_ >= 13;

TensorShape transposed_input_shape;
Tensor transposed_input_tensor;
Expand All @@ -179,7 +181,9 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape);
}

const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1];
// For opsets older than version 13, the `axis` attribute splits the input tensor's dimensions into two parts:
// the leading part is treated as the batch size, and Softmax is computed over the trailing part.
const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : (opset_ >= 13 ? input_shape[input_rank - 1] : input_shape.SizeFromDimension(axis));
const int64_t rows = input_shape.Size() / cols;
const int64_t components = GetMaxComponents(cols);
const auto packed_cols = cols / components;
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webgpu/math/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace webgpu {
class Softmax final : public WebGpuKernel {
public:
Softmax(const OpKernelInfo& info) : WebGpuKernel{info} {
int opset_ = info.node().SinceVersion();
opset_ = info.node().SinceVersion();
int64_t axis;
Status status = info.GetAttr<int64_t>("axis", &axis);

Expand All @@ -33,6 +33,7 @@ class Softmax final : public WebGpuKernel {

private:
int64_t axis_;
int opset_;
};

class SoftmaxProgram final : public Program<SoftmaxProgram> {
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/test/providers/cpu/math/softmax_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,7 @@ TEST(SoftmaxOperator, GH15949_regression_test) {
{0.00032932f, 0.01798029f, 0.9816904f});

// disable TRT as it does not support axis=0 as used by the model
// TODO: Fix the Softmax operator of WebGPU EP.
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider});
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

} // namespace test
Expand Down
Loading