@@ -210,13 +210,13 @@ def grid_sample(
210210 + ws * d_e * d_n + es * d_w * d_n
211211
212212 Args:
213- x(Tensor): The input tensor, which is a 4-d tensor with shape
214- [N, C, H, W] or a 5-d tensor with shape [N, C, D, H, W],
213+ x(Tensor): The input tensor, which is a 4-D tensor with shape
214+ [N, C, H, W] or a 5-D tensor with shape [N, C, D, H, W],
215215 N is the batch size, C is the channel number,
216216 D, H and W is the feature depth, height and width.
217217 The data type is float32 or float64.
218- grid(Tensor): Input grid tensor, which is a 4-d tensor with shape [N, grid_H,
219- grid_W, 2] or a 5-d tensor with shape [N, grid_D, grid_H,
218+ grid(Tensor): Input grid tensor, which is a 4-D tensor with shape [N, grid_H,
219+ grid_W, 2] or a 5-D tensor with shape [N, grid_D, grid_H,
220220 grid_W, 3]. The data type is float32 or float64.
221221 mode(str, optional): The interpolation method which can be 'bilinear' or 'nearest'.
222222 Default: 'bilinear'.
0 commit comments