@@ -850,37 +850,27 @@ def _get_gaussian_kernel2d(size, sigma):
850850    if  data_format  ==  "channels_first" :
851851        images  =  np .transpose (images , (0 , 2 , 3 , 1 ))
852852
853-     num_channels  =  images .shape [- 1 ]
853+     batch_size , height , width , num_channels  =  images .shape 
854+ 
854855    kernel  =  _create_gaussian_kernel (
855856        kernel_size , sigma , num_channels , input_dtype 
856857    )
857-     batch_size , height , width , _  =  images .shape 
858-     padded_images  =  np .pad (
859-         images ,
860-         (
861-             (0 , 0 ),
862-             (kernel_size [0 ] //  2 , kernel_size [0 ] //  2 ),
863-             (kernel_size [1 ] //  2 , kernel_size [1 ] //  2 ),
864-             (0 , 0 ),
865-         ),
866-         mode = "constant" ,
867-     )
868858
869-     blurred_images  =  np . zeros_like ( images ) 
870-     kernel_reshaped  =  kernel . reshape ( 
871-         ( 1 ,  kernel . shape [ 0 ],  kernel . shape [ 1 ],  num_channels ) 
872-     )
859+     pad_h  =  kernel_size [ 0 ]  //   2 
860+     pad_w  =  kernel_size [ 1 ]  //   2 
861+ 
862+     blurred_images   =   np . empty_like ( images )
873863
874864    for  b  in  range (batch_size ):
875-         image_patch   =   padded_images [ b  :  b   +   1 , :, :, :] 
876-         for   i   in   range ( height ): 
877-             for   j   in   range ( width ): 
878-                 patch   =   image_patch [ 
879-                     :,  i  :  i   +   kernel_size [ 0 ],  j  :  j   +   kernel_size [ 1 ], : 
880-                 ] 
881-                  blurred_images [b , i ,  j , : ] =  np . sum (
882-                      patch   *   kernel_reshaped ,  axis = ( 1 ,  2 ) 
883-                  )
865+         for   ch   in   range ( num_channels ): 
866+              padded   =   np . pad ( 
867+                  images [ b , :, :,  ch ], 
868+                 (( pad_h ,  pad_h ), ( pad_w ,  pad_w )), 
869+                 mode = "constant" , 
870+             ) 
871+             blurred_images [b , :, :,  ch ] =  scipy . signal . convolve2d (
872+                 padded ,  kernel [:, :,  ch ],  mode = "valid" 
873+             )
884874
885875    if  data_format  ==  "channels_first" :
886876        blurred_images  =  np .transpose (blurred_images , (0 , 3 , 1 , 2 ))
0 commit comments