1515import unittest
1616
1717import numpy as np
18- from op_test import OpTest , convert_float_to_uint16 , get_numeric_gradient
18+ from op_test import (
19+ OpTest ,
20+ convert_float_to_uint16 ,
21+ get_device_place ,
22+ get_numeric_gradient ,
23+ is_custom_device ,
24+ )
1925from testsuite import create_op
2026
2127import paddle
@@ -162,7 +168,8 @@ def init_kernel_type(self):
162168
163169def create_test_cudnn_fp16_class (parent , grad_check = True ):
164170 @unittest .skipIf (
165- not core .is_compiled_with_cuda (), "core is not compiled with CUDA"
171+ not (core .is_compiled_with_cuda () or is_custom_device ()),
172+ "core is not compiled with CUDA" ,
166173 )
167174 class TestConv2DCUDNNFp16 (parent ):
168175 def init_kernel_type (self ):
@@ -171,19 +178,19 @@ def init_kernel_type(self):
171178
172179 def test_check_output (self ):
173180 if core .is_compiled_with_cuda ():
174- place = core . CUDAPlace ( 0 )
181+ place = get_device_place ( )
175182 if core .is_float16_supported (place ):
176183 self .check_output_with_place (place , atol = 2e-2 )
177184
178185 def test_check_grad_no_filter (self ):
179- place = core . CUDAPlace ( 0 )
186+ place = get_device_place ( )
180187 if core .is_float16_supported (place ) and grad_check :
181188 self .check_grad_with_place (
182189 place , ['Input' ], 'Output' , no_grad_set = {'Filter' }
183190 )
184191
185192 def test_check_grad_no_input (self ):
186- place = core . CUDAPlace ( 0 )
193+ place = get_device_place ( )
187194 if core .is_float16_supported (place ) and grad_check :
188195 self .check_grad_with_place (
189196 place , ['Filter' ], 'Output' , no_grad_set = {'Input' }
@@ -196,8 +203,8 @@ def test_check_grad_no_input(self):
196203
197204def create_test_cudnn_bf16_class (parent ):
198205 @unittest .skipIf (
199- not core .is_compiled_with_cuda ()
200- or not core .is_bfloat16_supported (core . CUDAPlace ( 0 )),
206+ not ( core .is_compiled_with_cuda () or is_custom_device () )
207+ or not core .is_bfloat16_supported (get_device_place ( )),
201208 "core is not compiled with CUDA and do not support bfloat16" ,
202209 )
203210 class TestConv2DCUDNNBF16 (parent ):
@@ -217,11 +224,11 @@ def init_kernel_type(self):
217224 self .dtype = np .uint16
218225
219226 def test_check_output (self ):
220- place = core . CUDAPlace ( 0 )
227+ place = get_device_place ( )
221228 self .check_output_with_place (place , atol = 1e-2 )
222229
223230 def test_check_grad_no_filter (self ):
224- place = core . CUDAPlace ( 0 )
231+ place = get_device_place ( )
225232 numeric_grads = self .get_numeric_grad (place , 'Input' )
226233 self .check_grad_with_place (
227234 place ,
@@ -232,7 +239,7 @@ def test_check_grad_no_filter(self):
232239 )
233240
234241 def test_check_grad_no_input (self ):
235- place = core . CUDAPlace ( 0 )
242+ place = get_device_place ( )
236243 numeric_grads = self .get_numeric_grad (place , 'Filter' )
237244 self .check_grad_with_place (
238245 place ,
@@ -294,20 +301,20 @@ def init_kernel_type(self):
294301 self .dtype = np .float16
295302
296303 def test_check_output (self ):
297- if core .is_compiled_with_cuda ():
298- place = core . CUDAPlace ( 0 )
304+ if core .is_compiled_with_cuda () or is_custom_device () :
305+ place = get_device_place ( )
299306 if core .is_float16_supported (place ):
300307 self .check_output_with_place (place , atol = 2e-2 )
301308
302309 def test_check_grad_no_filter (self ):
303- place = core . CUDAPlace ( 0 )
310+ place = get_device_place ( )
304311 if core .is_float16_supported (place ) and grad_check :
305312 self .check_grad_with_place (
306313 place , ['Input' ], 'Output' , no_grad_set = {'Filter' }
307314 )
308315
309316 def test_check_grad_no_input (self ):
310- place = core . CUDAPlace ( 0 )
317+ place = get_device_place ( )
311318 if core .is_float16_supported (place ) and grad_check :
312319 self .check_grad_with_place (
313320 place , ['Filter' ], 'Output' , no_grad_set = {'Input' }
@@ -491,12 +498,12 @@ def setUp(self):
491498 self .outputs = {'Output' : output }
492499
493500 def has_cuda (self ):
494- return core .is_compiled_with_cuda () and (
501+ return ( core .is_compiled_with_cuda () or is_custom_device () ) and (
495502 self .use_cudnn or self .use_cuda
496503 )
497504
498505 def test_check_output (self ):
499- place = core . CUDAPlace ( 0 ) if self .has_cuda () else core .CPUPlace ()
506+ place = get_device_place ( ) if self .has_cuda () else core .CPUPlace ()
500507 # TODO(wangzhongpu): support onednn op in dygraph mode
501508 self .check_output_with_place (
502509 place ,
@@ -510,7 +517,7 @@ def test_check_grad(self):
510517 hasattr (self , "no_need_check_grad" ) and self .no_need_check_grad
511518 ):
512519 return
513- place = core . CUDAPlace ( 0 ) if self .has_cuda () else core .CPUPlace ()
520+ place = get_device_place ( ) if self .has_cuda () else core .CPUPlace ()
514521 # TODO(wangzhongpu): support onednn op in dygraph mode
515522 self .check_grad_with_place (
516523 place ,
@@ -526,7 +533,7 @@ def test_check_grad_no_filter(self):
526533 hasattr (self , "no_need_check_grad" ) and self .no_need_check_grad
527534 ):
528535 return
529- place = core . CUDAPlace ( 0 ) if self .has_cuda () else core .CPUPlace ()
536+ place = get_device_place ( ) if self .has_cuda () else core .CPUPlace ()
530537 # TODO(wangzhongpu): support onednn op in dygraph mode
531538 self .check_grad_with_place (
532539 place ,
@@ -543,7 +550,7 @@ def test_check_grad_no_input(self):
543550 hasattr (self , "no_need_check_grad" ) and self .no_need_check_grad
544551 ):
545552 return
546- place = core . CUDAPlace ( 0 ) if self .has_cuda () else core .CPUPlace ()
553+ place = get_device_place ( ) if self .has_cuda () else core .CPUPlace ()
547554 # TODO(wangzhongpu): support onednn op in dygraph mode
548555 self .check_grad_with_place (
549556 place ,
@@ -830,7 +837,7 @@ def has_cuda(self):
830837
831838 def test_check_output (self ):
832839 # TODO(wangzhongpu): support onednn op in dygraph mode
833- place = core . CUDAPlace ( 0 ) if self .has_cuda () else core .CPUPlace ()
840+ place = get_device_place ( ) if self .has_cuda () else core .CPUPlace ()
834841 self .check_output_with_place (
835842 place ,
836843 atol = 1e-5 ,
@@ -842,7 +849,7 @@ def test_check_grad(self):
842849 # TODO(wangzhongpu): support onednn op in dygraph mode
843850 if self .dtype == np .float16 :
844851 return
845- place = core . CUDAPlace ( 0 ) if self .has_cuda () else core .CPUPlace ()
852+ place = get_device_place ( ) if self .has_cuda () else core .CPUPlace ()
846853 self .check_grad_with_place (
847854 place ,
848855 {'Input' , 'Filter' },
@@ -856,7 +863,7 @@ def test_check_grad_no_filter(self):
856863 # TODO(wangzhongpu): support onednn op in dygraph mode
857864 if self .dtype == np .float16 :
858865 return
859- place = core . CUDAPlace ( 0 ) if self .has_cuda () else core .CPUPlace ()
866+ place = get_device_place ( ) if self .has_cuda () else core .CPUPlace ()
860867 self .check_grad_with_place (
861868 place ,
862869 ['Input' ],
@@ -871,7 +878,7 @@ def test_check_grad_no_input(self):
871878 # TODO(wangzhongpu): support onednn op in dygraph mode
872879 if self .dtype == np .float16 :
873880 return
874- place = core . CUDAPlace ( 0 ) if self .has_cuda () else core .CPUPlace ()
881+ place = get_device_place ( ) if self .has_cuda () else core .CPUPlace ()
875882 self .check_grad_with_place (
876883 place ,
877884 ['Filter' ],
0 commit comments