Skip to content

Commit

Permalink
CudnnNormConvolution is no longer supported on NVIDIA Hopper GPUs (#4…
Browse files Browse the repository at this point in the history
…8203)

* Skip tests that use fused_ops on H100

* Add error message to FusedOps on H100
  • Loading branch information
Tom-Zheng authored Nov 22, 2022
1 parent 2ab60c3 commit df4dfda
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/operators/fused/cudnn_norm_conv.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ struct NormConvolutionArgs {
int stride,
int dilation,
int group) {
PADDLE_ENFORCE_LT(
ctx.GetComputeCapability(),
90,
phi::errors::PreconditionNotMet(
"Expect compute compatiblity to be less than 90, but got %d. "
"CUDNN FusedOps is no longer available on H100 and later "
"devices.",
ctx.GetComputeCapability()));
PADDLE_ENFORCE_EQ(
input_shape.size(),
4U,
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/fused/cudnn_norm_conv_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ TEST(CudnnNormConvFp16, K1S1) {
phi::GPUContext *ctx = static_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));

if (ctx->GetComputeCapability() < 70) {
if (ctx->GetComputeCapability() < 70 || ctx->GetComputeCapability() >= 90) {
ASSERT_THROW(test.CheckForward(1e-3, true),
paddle::platform::EnforceNotMet);
ASSERT_THROW(test.CheckBackward(1e-3, true),
Expand Down Expand Up @@ -469,7 +469,7 @@ TEST(CudnnNormConvFp16, K3S1) {
phi::GPUContext *ctx = static_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));

if (ctx->GetComputeCapability() < 70) {
if (ctx->GetComputeCapability() < 70 || ctx->GetComputeCapability() >= 90) {
ASSERT_THROW(test.CheckForward(1e-3, true),
paddle::platform::EnforceNotMet);
ASSERT_THROW(test.CheckBackward(1e-3, true),
Expand Down Expand Up @@ -499,7 +499,7 @@ TEST(CudnnNormConvFp16, K1S1O4) {
phi::GPUContext *ctx = static_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));

if (ctx->GetComputeCapability() < 70) {
if (ctx->GetComputeCapability() < 70 || ctx->GetComputeCapability() >= 90) {
ASSERT_THROW(test.CheckForward(1e-3, true),
paddle::platform::EnforceNotMet);
ASSERT_THROW(test.CheckBackward(1e-3, true),
Expand Down Expand Up @@ -529,7 +529,7 @@ TEST(CudnnNormConvFp16, K1S2O4) {
phi::GPUContext *ctx = static_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));

if (ctx->GetComputeCapability() <= 70) {
if (ctx->GetComputeCapability() <= 70 || ctx->GetComputeCapability() >= 90) {
ASSERT_THROW(test.CheckForward(1e-3, true),
paddle::platform::EnforceNotMet);
ASSERT_THROW(test.CheckBackward(1e-3), paddle::platform::EnforceNotMet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
@unittest.skipIf(
not paddle.is_compiled_with_cuda()
or paddle.get_cudnn_version() < 8000
or paddle.device.cuda.get_device_capability()[0] < 7,
or paddle.device.cuda.get_device_capability()[0] < 7
or paddle.device.cuda.get_device_capability()[0] >= 9,
"only support with cuda and cudnn version is at least 8.0 "
"and device's compute capability is at least 7.0",
"and device's compute capability is at least 7.0 and less than 9.0",
)
class TestFuseResNetUnit(unittest.TestCase):
def test_fuse_resenet_unit(self):
Expand Down

0 comments on commit df4dfda

Please sign in to comment.