
Commit 682dde6

Scott Gray authored and apark263 committed
generalized fast transpose/dimshuffle
1 parent e110534 commit 682dde6

File tree

3 files changed: +223 -1 lines changed

neon/backends/convolution.py (+181)
@@ -22,6 +22,11 @@
 from neon.backends import kernel_specs
 from neon.backends.cuda_templates import _common_round, _ew_types
 from math import ceil
+from operator import mul
+import sys
+
+if sys.version_info >= (3, 0):
+    from functools import reduce
 
 
 class KernelGroup(object):
@@ -970,3 +975,179 @@ def _get_shuffle_kernel(dtype):
     kernel = module.get_function("dimShuffle")
     kernel.prepare("PPIIIIIIIIIIIIIIII")
     return kernel
+
+
+@context_dependent_memoize
+def _get_copy_transpose_kernel(dtype, shape, axes=None):
+
+    src = range(len(shape))
+    dst = list(axes)
+
+    src_contig = src[-1]
+    dst_contig = dst[-1]
+
+    assert src_contig != dst_contig, "Inner dimension must change (for now)"
+
+    dim_params = []
+    dim_values = []
+    in_offset = []
+    out_offset = []
+    magic_params = []
+    magic_values = []
+    magic = ""
+
+    for i, s in enumerate(src):
+
+        idx = "".join(str(x) for x in src[i+1:])
+        val = reduce(mul, (shape[x] for x in src[i+1:]), 1)
+
+        if s == dst_contig:
+            in_dim_j = "dim_%s" % idx
+        elif idx:
+            in_offset.append("idx_%d*dim_%s" % (s, idx))
+        else:
+            in_offset.append("idx_%d" % s)
+
+        if idx:
+            dim_params.append("int dim_%s" % idx)
+            dim_values.append(val)
+
+    for i, d in enumerate(dst):
+
+        idx = "".join(str(x) for x in dst[i+1:])
+        val = reduce(mul, (shape[x] for x in dst[i+1:]), 1)
+
+        if d == src_contig:
+            out_dim_j = "dim_%s" % idx
+
+        if idx:
+            dim_params.append("int dim_%s" % idx)
+            dim_values.append(val)
+
+            out_offset.append("idx_%d*dim_%s" % (d, idx))
+        else:
+            out_offset.append("idx_%d" % d)
+
+    src2 = list(src)
+    src2[dst_contig:dst_contig+1] = ()
+
+    blk = compound_idx = "".join(str(x) for x in src2)
+
+    grid_shape = list(shape)
+    grid_shape[src_contig] = _ceil_div(shape[src_contig], 32)
+    grid_shape[dst_contig] = _ceil_div(shape[dst_contig], 32)
+
+    while len(src2) > 1:
+
+        idx1 = src2[0]
+        src2[0:1] = ()
+        idx2 = "".join(str(i) for i in src2)
+        div = reduce(mul, (grid_shape[i] for i in src2), 1)
+
+        magic_params.append("int magic_%s, int shift_%s, int div_%s" % (idx2, idx2, idx2))
+        magic_values.append(_magic64(div))
+        magic_values.append(div)
+
+        magic += r"""
+    int idx_{1} = div64(idx_{0}, magic_{2}, shift_{2});
+    int idx_{2} = idx_{0} - idx_{1}*div_{2};
+""".format(compound_idx, idx1, idx2)
+
+        compound_idx = idx2
+
+    params = _flatten([dim_params, magic_params])
+    values = _flatten([dim_values, magic_values])
+
+    shuffle_kernel = r"""
+__device__ __forceinline__ int div64(int value, int magic, int shift)
+{
+    // if the divisor is a power of 2 the magic will be 1 and it's just a simple right shift
+    // Otherwise multiply by magic and right shift just the high bits
+    int result;
+    asm("{\n\t"
+        ".reg .pred p;\n\t"
+        ".reg .u64 res64;\n\t"
+        ".reg .u32 lo32, hi32;\n\t"
+        "setp.ne.s32 p, %%2, 1;\n\t"
+        "mul.wide.u32 res64, %%1, %%2;\n\t"
+        "mov.b64 {lo32, hi32}, res64;\n\t"
+        "selp.u32 hi32, hi32, %%1, p;\n\t"
+        "shr.u32 %%0, hi32, %%3;\n\t"
+        "}" : "=r"(result) : "r"(value), "r"(magic), "r"(shift));
+    return result;
+}
+
+__global__ void copy_transpose(%(type)s* out, const %(type)s* in, %(params)s)
+{
+    __shared__ %(type)s tile[32][33];
+
+    int tid_x = threadIdx.x;
+    int tid_y = threadIdx.y;
+    int idx_%(blk)s = blockIdx.x;
+    int idx_%(dst)s = blockIdx.y;
+
+    %(magic)s
+
+    idx_%(src)s = (idx_%(src)s << 5) + tid_x;
+    idx_%(dst)s = (idx_%(dst)s << 5) + tid_y;
+
+    int offset = %(in_offset)s;
+
+    #pragma unroll
+    for (int j = 0; j < 32; j += 8)
+    {
+        int idx_%(dst)sj = idx_%(dst)s + j;
+        if (idx_%(dst)sj < dim_%(dst)s && idx_%(src)s < dim_%(src)s)
+            tile[tid_y + j][tid_x] = in[idx_%(dst)sj*%(in_dim_j)s + offset];
+    }
+    __syncthreads();
+
+    %(type)s val00 = tile[tid_x][tid_y + 0];
+    %(type)s val08 = tile[tid_x][tid_y + 8];
+    %(type)s val16 = tile[tid_x][tid_y + 16];
+    %(type)s val24 = tile[tid_x][tid_y + 24];
+
+    idx_%(src)s += tid_y - tid_x;
+    idx_%(dst)s += tid_x - tid_y;
+
+    bool b%(dst)s = idx_%(dst)s < dim_%(dst)s;
+
+    %(type)s* out00 = out + %(out_offset)s;
+    %(type)s* out08 = out00 + %(out_dim_j)s*8;
+    %(type)s* out16 = out08 + %(out_dim_j)s*8;
+    %(type)s* out24 = out16 + %(out_dim_j)s*8;
+
+    if (idx_%(src)s + 0 < dim_%(src)s && b%(dst)s) *out00 = val00;
+    if (idx_%(src)s + 8 < dim_%(src)s && b%(dst)s) *out08 = val08;
+    if (idx_%(src)s + 16 < dim_%(src)s && b%(dst)s) *out16 = val16;
+    if (idx_%(src)s + 24 < dim_%(src)s && b%(dst)s) *out24 = val24;
+}
+"""
+    code = shuffle_kernel % dict(
+        type=_ew_types[dtype[1:]]["type"],
+        params=", ".join(params),
+        blk=blk,
+        src=src_contig,
+        dst=dst_contig,
+        magic=magic,
+        in_offset=" + ".join(in_offset),
+        out_offset=" + ".join(out_offset),
+        in_dim_j=in_dim_j,
+        out_dim_j=out_dim_j
+    )
+    module = SourceModule(code)
+    kernel = module.get_function("copy_transpose")
+    kernel.prepare("PP" + "I"*len(values))
+
+    grid_x = grid_shape[src_contig]
+    grid_y = grid_shape[dst_contig]
+    for s in src:
+        if s not in (src_contig, dst_contig):
+            grid_x *= grid_shape[s]
+
+    return dict(
+        kernel=kernel,
+        grid=(grid_x, grid_y, 1),
+        block=(32, 8, 1),
+        args=values
+    )
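Note (not part of the commit): the generated kernel packs every dimension except the output-contiguous one into blockIdx.x, then unpacks that compound index with div64() and the precomputed (magic, shift, div) triples, since hardware integer division is slow on the GPU. The sketch below is an illustration of that magic-number division idea only; magic_u32 and the Python div64 are hypothetical stand-ins, not the committed _magic64/_magic32 helpers.

def magic_u32(nmax, d):
    # Find (magic, shift) so that n // d == (n * magic) >> (32 + shift)
    # for all 0 <= n <= nmax (classic unsigned magic-number division).
    if d & (d - 1) == 0:                     # power of two: magic == 1, plain shift
        return 1, d.bit_length() - 1
    nc = ((nmax + 1) // d) * d - 1
    for p in range(32, 64):
        if 2 ** p > nc * (d - 1 - (2 ** p - 1) % d):
            return (2 ** p + d - 1 - (2 ** p - 1) % d) // d, p - 32
    raise ValueError("no magic number for divisor %d" % d)

def div64(value, magic, shift):
    # Mirrors the PTX in the kernel: magic == 1 means the divisor was a power
    # of two and a right shift suffices; otherwise keep the high 32 bits of
    # the 64-bit product and shift those.
    if magic == 1:
        return value >> shift
    return ((value * magic) >> 32) >> shift

# Sanity check against ordinary integer division.
for d in (3, 7, 32, 96, 100):
    magic, shift = magic_u32(0x7fffffff, d)
    assert all(div64(n, magic, shift) == n // d for n in range(0, 1 << 20, 37))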

neon/backends/nervanagpu.py (+41)
@@ -2231,6 +2231,47 @@ def cublas_dot(self, A, B, C, alpha=1.0, beta=0.0):
         else:
             raise TypeError("Unsupported type for cublas gemm")
 
+    def copy_transpose(self, a, out, axes=None, repeat=1):
+        """
+        Function to perform a fast copy transpose/dimshuffle operation.
+        Works just like numpy.transpose, but requires an output tensor argument.
+        """
+        assert a.dtype == out.dtype
+        assert a.size == out.size
+        assert a.gpudata != out.gpudata
+
+        if axes is None:
+            axes = tuple(range(len(a.shape)-1, -1, -1))
+        elif type(axes) is not tuple:
+            axes = tuple(axes)
+
+        assert all(out.shape[i] == a.shape[x] for i, x in enumerate(axes))
+
+        from neon.backends.convolution import _get_copy_transpose_kernel
+
+        kernel_data = _get_copy_transpose_kernel(a.dtype.str, a.shape, axes)
+
+        # Warmup
+        if repeat > 1:
+            for r in range(max(repeat // 10, 1)):
+                kernel_data["kernel"].prepared_async_call(kernel_data["grid"], kernel_data["block"],
+                                                          self.stream, out.gpudata, a.gpudata, *kernel_data["args"])
+
+        if self.bench > 1 or repeat > 1:
+            start, end = _get_events()
+            start.record(self.stream)
+
+        for r in range(repeat):
+            kernel_data["kernel"].prepared_async_call(kernel_data["grid"], kernel_data["block"],
+                                                      self.stream, out.gpudata, a.gpudata, *kernel_data["args"])
+
+        if self.bench > 1 or repeat > 1:
+            end.record(self.stream)
+            end.synchronize()
+            msecs = end.time_since(start) / repeat
+            bandwidth = a.nbytes*2 / (msecs * 1024 * 1024)
+            print("%7.3f msecs %4.0f GBps copy_transpose" % (msecs, bandwidth))
+
     def init_mark(self):
         """
         Generate a timing mark object
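Not part of the commit: a minimal usage sketch of the new method, checked against numpy.transpose. It assumes the usual NervanaGPU tensor helpers (array, empty, get) behave as they do elsewhere in the backend.

import numpy as np
from neon.backends.nervanagpu import NervanaGPU

ng = NervanaGPU()

axes = (2, 0, 1)
a_host = np.random.rand(64, 32, 128).astype(np.float32)

a = ng.array(a_host)
out = ng.empty(tuple(a_host.shape[i] for i in axes), dtype=np.float32)

ng.copy_transpose(a, out, axes=axes)           # fast dimshuffle on the GPU
assert np.allclose(out.get(), np.transpose(a_host, axes))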

neon/data/imageloader.py (+1 -1)

@@ -111,7 +111,7 @@ def __init__(self, repo_dir, inner_size, scale_range, do_transforms=True,
 
         # View for subtracting the mean.
         # Find a shape that's fast for ew broadcast
-        image_dim = self.data.reshape((ishape[0],-1)).shape[1]
+        image_dim = self.data.reshape((ishape[0], -1)).shape[1]
         fast_dim = [i for i in range(1, 257) if image_dim % i == 0][-1]
         self.data_view = self.data.reshape((ishape[0], image_dim//fast_dim, fast_dim))
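Aside (illustration only, not in the commit): fast_dim is simply the largest divisor of image_dim that does not exceed 256, so the mean-subtraction view keeps a reasonably wide contiguous inner axis for elementwise broadcast. For a hypothetical image_dim:

image_dim = 224 * 224                                           # hypothetical value
fast_dim = [i for i in range(1, 257) if image_dim % i == 0][-1]
print(fast_dim, image_dim // fast_dim)                          # -> 256 196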
