[ROCM][MIOpen] add support for softmax and log_softmax with MIOpen #8543

Merged 1 commit on Jul 29, 2021
52 changes: 52 additions & 0 deletions python/tvm/contrib/miopen.py
@@ -136,3 +136,55 @@ def conv2d_forward(
),
name="y",
)


def softmax(x, axis=-1):
"""Compute softmax with MIOpen

Parameters
----------
x : tvm.te.Tensor
The input tensor

axis : int
The axis to compute softmax over

Returns
-------
ret : tvm.te.Tensor
The result tensor
"""
return te.extern(
x.shape,
[x],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.miopen.softmax.forward", ins[0], outs[0], axis
),
name="y",
)


def log_softmax(x, axis=-1):
"""Compute log softmax with MIOpen

Parameters
----------
x : tvm.te.Tensor
The input tensor

axis : int
The axis to compute log softmax over

Returns
-------
ret : tvm.te.Tensor
The result tensor
"""
return te.extern(
x.shape,
[x],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.miopen.log_softmax.forward", ins[0], outs[0], axis
),
name="y",
)
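
For reference, these wrappers can be exercised directly from TE. A minimal sketch, assuming a ROCm build with MIOpen enabled (tensor names and shapes are illustrative, mirroring the tests added below):

import tvm
from tvm import te
from tvm.contrib import miopen

# Softmax over the last axis of a 2D tensor via the MIOpen extern op.
x = te.placeholder((32, 10), dtype="float32", name="x")
y = miopen.softmax(x, axis=-1)
s = te.create_schedule(y.op)  # the extern op needs no further scheduling
f = tvm.build(s, [x, y], target="rocm --host=llvm", name="softmax")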
39 changes: 39 additions & 0 deletions python/tvm/relay/op/strategy/rocm.py
@@ -20,6 +20,7 @@
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm.contrib.thrust import can_use_rocthrust
from tvm.contrib import miopen

from .generic import *
from .. import op as _op
@@ -304,3 +305,41 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
plevel=15,
)
return strategy


@softmax_strategy.register(["rocm"])
def softmax_strategy_rocm(attrs, inputs, out_type, target):
"""rocm strategy for softmax"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_softmax(topi.nn.softmax),
wrap_topi_schedule(topi.cuda.schedule_softmax),
name="softmax.rocm",
)
if "miopen" in target.libs:
strategy.add_implementation(
wrap_compute_softmax(miopen.softmax),
wrap_topi_schedule(topi.generic.schedule_extern),
name="softmax.miopen",
plevel=15,
)
return strategy


@log_softmax_strategy.register(["rocm"])
def log_softmax_strategy_rocm(attrs, inputs, out_type, target):
"""rocm strategy for log softmax"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_softmax(topi.nn.log_softmax),
wrap_topi_schedule(topi.cuda.schedule_softmax),
name="log_softmax.rocm",
)
if "miopen" in target.libs:
strategy.add_implementation(
wrap_compute_softmax(miopen.log_softmax),
wrap_topi_schedule(topi.generic.schedule_extern),
name="log_softmax.miopen",
plevel=15,
)
return strategy
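
Note that these implementations are only considered when "miopen" appears in target.libs. A minimal sketch of opting in from Relay, assuming MIOpen is available at runtime (the one-operator graph is illustrative):

import tvm
from tvm import relay

# "-libs=miopen" populates target.libs, so the plevel=15 MIOpen
# implementations above outrank the default TOPI schedules.
target = tvm.target.Target("rocm -libs=miopen")

x = relay.var("x", shape=(1, 1000), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.softmax(x))
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)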
4 changes: 4 additions & 0 deletions src/runtime/contrib/miopen/miopen_utils.cc
@@ -89,6 +89,10 @@ void ConvEntry::CleanWorkspace() {
workspace_size = 0;
}

SoftmaxEntry::SoftmaxEntry() { MIOPEN_CALL(miopenCreateTensorDescriptor(&shape_desc)); }

SoftmaxEntry::~SoftmaxEntry() { MIOPEN_CALL(miopenDestroyTensorDescriptor(shape_desc)); }

} // namespace miopen
} // namespace contrib
} // namespace tvm
7 changes: 7 additions & 0 deletions src/runtime/contrib/miopen/miopen_utils.h
@@ -62,11 +62,18 @@ struct ConvEntry {
void CleanWorkspace();
}; // ConvThreadEntry

struct SoftmaxEntry {
miopenTensorDescriptor_t shape_desc;
SoftmaxEntry();
~SoftmaxEntry();
}; // SoftmaxEntry

struct MIOpenThreadEntry {
MIOpenThreadEntry();
~MIOpenThreadEntry();
miopenHandle_t handle{nullptr};
ConvEntry conv_entry;
SoftmaxEntry softmax_entry;
runtime::DeviceAPI* rocm_api{nullptr};
static MIOpenThreadEntry* ThreadLocal();
}; // MIOpenThreadEntry
92 changes: 92 additions & 0 deletions src/runtime/contrib/miopen/softmax.cc
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/runtime/contrib/miopen/softmax.cc
* \brief Use external miopen softmax function
*/
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/registry.h>

#include "miopen_utils.h"

namespace tvm {
namespace contrib {
namespace miopen {

using namespace runtime;

void softmax_impl(TVMArgs args, TVMRetValue* ret, miopenSoftmaxAlgorithm_t alg) {
DLTensor* x = args[0];
DLTensor* y = args[1];
int axis = args[2];
int ndim = x->ndim;
int64_t* shape = x->shape;
if (axis < 0) axis += ndim;
ICHECK(axis >= 0 && axis < ndim);
// just fp32 for now
ICHECK(TypeMatch(x->dtype, kDLFloat, 32));
ICHECK(TypeMatch(y->dtype, kDLFloat, 32));

MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();

miopenSoftmaxMode_t mode;
if (axis == ndim - 1) {
int64_t N = 1;
for (int i = 0; i < ndim - 1; ++i) {
N *= shape[i];
}
mode = MIOPEN_SOFTMAX_MODE_INSTANCE;
MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->softmax_entry.shape_desc, miopenFloat,
static_cast<int>(N), static_cast<int>(shape[ndim - 1]),
1, 1));
} else {
int64_t pre_axis_dim = 1;
int64_t post_axis_dim = 1;
for (int i = 0; i < ndim; ++i) {
if (i < axis) {
pre_axis_dim *= shape[i];
} else if (i > axis) {
post_axis_dim *= shape[i];
}
}
mode = MIOPEN_SOFTMAX_MODE_CHANNEL;
MIOPEN_CALL(miopenSet4dTensorDescriptor(
entry_ptr->softmax_entry.shape_desc, miopenFloat, static_cast<int>(pre_axis_dim),
static_cast<int>(shape[axis]), static_cast<int>(post_axis_dim), 1));
}

const float alpha = 1.f;
const float beta = 0.f;
MIOPEN_CALL(miopenSoftmaxForward_V2(entry_ptr->handle, &alpha,
entry_ptr->softmax_entry.shape_desc, x->data, &beta,
entry_ptr->softmax_entry.shape_desc, y->data, alg, mode));
}

TVM_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward")
.set_body([](TVMArgs args, TVMRetValue* ret) {
softmax_impl(args, ret, MIOPEN_SOFTMAX_ACCURATE);
});

TVM_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward")
.set_body([](TVMArgs args, TVMRetValue* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); });

} // namespace miopen
} // namespace contrib
} // namespace tvm
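
To make the 4D-descriptor mapping above concrete, here is the same folding logic as a small Python sketch (miopen_softmax_desc is a hypothetical helper, for illustration only):

def miopen_softmax_desc(shape, axis):
    """Mirror the 4D-descriptor folding done in softmax_impl."""
    ndim = len(shape)
    axis = axis + ndim if axis < 0 else axis
    assert 0 <= axis < ndim
    if axis == ndim - 1:
        # MIOPEN_SOFTMAX_MODE_INSTANCE: softmax over C of an (N, C, 1, 1) view.
        n = 1
        for dim in shape[:-1]:
            n *= dim
        return "INSTANCE", (n, shape[-1], 1, 1)
    # MIOPEN_SOFTMAX_MODE_CHANNEL: softmax over C of an (N, C, H, 1) view.
    pre, post = 1, 1
    for i, dim in enumerate(shape):
        if i < axis:
            pre *= dim
        elif i > axis:
            post *= dim
    return "CHANNEL", (pre, shape[axis], post, 1)

# e.g. shape (1, 16, 256, 256) with axis=1 folds to ("CHANNEL", (1, 16, 65536, 1))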
66 changes: 63 additions & 3 deletions tests/python/contrib/test_miopen.py
@@ -19,9 +19,17 @@
from tvm import te
from tvm.contrib import miopen
import tvm.topi.testing
import numpy as np
import pytest


requires_miopen = pytest.mark.skipif(
tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True) is None,
reason="MIOpen is not enabled",
)


@tvm.testing.requires_rocm
@requires_miopen
def test_conv2d():
in_channel = 3
out_channel = 64
@@ -35,9 +43,6 @@ def test_conv2d():
dilation_w = 1

xshape = [1, in_channel, 128, 128]
if not tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True):
print("skip because miopen is not enabled...")
return
wshape = (out_channel, in_channel, filter_h, filter_w)

X = te.placeholder(xshape, name="X")
@@ -72,5 +77,60 @@ def verify():
verify()


def verify_softmax(shape, axis, dtype="float32", log_softmax=False):
miopen_op = miopen.log_softmax if log_softmax else miopen.softmax
testing_op = (
tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
)

A = te.placeholder(shape, dtype=dtype, name="A")
B = miopen_op(A, axis)
s = te.create_schedule([B.op])

dev = tvm.rocm(0)
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = testing_op(a_np)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
f = tvm.build(s, [A, B], target="rocm --host=llvm", name="softmax")
f(a, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3)


def verify_softmax_4d(shape, dtype="float32", log_softmax=False):
miopen_op = miopen.log_softmax if log_softmax else miopen.softmax
testing_op = (
tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
)

A = te.placeholder(shape, dtype=dtype, name="A")
B = miopen_op(A, axis=1)
s = te.create_schedule([B.op])

dev = tvm.rocm(0)
n, c, h, w = shape
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = testing_op(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
f = tvm.build(s, [A, B], target="rocm --host=llvm", name="softmax")
f(a, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3)


@tvm.testing.requires_rocm
@requires_miopen
def test_softmax():
verify_softmax((32, 10), -1)
verify_softmax((3, 4), -1)
verify_softmax_4d((1, 16, 256, 256))

verify_softmax((32, 10), -1, log_softmax=True)
verify_softmax((3, 4), -1, log_softmax=True)
verify_softmax_4d((1, 16, 256, 256), log_softmax=True)


if __name__ == "__main__":
    test_conv2d()
    test_softmax()