|
22 | 22 | from neon.backends import kernel_specs
|
23 | 23 | from neon.backends.cuda_templates import _common_round, _ew_types
|
24 | 24 | from math import ceil
|
| 25 | +from operator import mul |
| 26 | +import sys |
| 27 | + |
| 28 | +if sys.version_info >= (3, 0): |
| 29 | + from functools import reduce |
25 | 30 |
|
26 | 31 |
|
27 | 32 | class KernelGroup(object):
|
@@ -970,3 +975,179 @@ def _get_shuffle_kernel(dtype):
|
970 | 975 | kernel = module.get_function("dimShuffle")
|
971 | 976 | kernel.prepare("PPIIIIIIIIIIIIIIII")
|
972 | 977 | return kernel
|
| 978 | + |
| 979 | + |
@context_dependent_memoize
def _get_copy_transpose_kernel(dtype, shape, axes=None):
    """
    Build a CUDA kernel that copies a tensor while permuting its axes.

    The kernel source is generated specifically for the given ``shape`` and
    ``axes``: every stride and every divisor needed to unpack the block index
    becomes a named integer kernel parameter. The generated kernel stages the
    data through a 32x33 ``__shared__`` tile and is launched with 32x8
    thread blocks.

    Arguments:
        dtype: Key used to pick the element C type via
            ``_ew_types[dtype[1:]]["type"]`` (the first character is
            dropped; presumably a numpy dtype string such as "<f4" --
            TODO confirm against callers).
        shape: Sequence of ints giving the source tensor dimensions.
        axes: Permutation of ``range(len(shape))`` giving, for each output
            axis, the source axis it is taken from. ``None`` (the default)
            reverses the axis order, as ``numpy.transpose`` does.

    Returns:
        dict: ``kernel`` (the prepared kernel function), ``grid`` and
        ``block`` (launch dimensions) and ``args`` (the integer arguments to
        pass after the ``out`` and ``in`` pointers).

    Raises:
        AssertionError: if the innermost axis is the same in the source and
            destination layouts, which this kernel does not support.

    NOTE(review): parameter names are formed by concatenating axis numbers
    without a separator (e.g. ``dim_12``), so names could become ambiguous
    for tensors with more than 10 axes.
    """
    # Materialise as a list so slicing behaves the same on Python 2 and 3.
    src = list(range(len(shape)))
    # Previously ``list(axes)`` raised TypeError for the default axes=None;
    # fall back to a full reversal of the axes instead.
    dst = src[::-1] if axes is None else list(axes)

    # Innermost (fastest varying) axis of the source and destination layouts.
    src_contig = src[-1]
    dst_contig = dst[-1]

    assert src_contig != dst_contig, "Inner dimension must change (for now)"

    dim_params = []
    dim_values = []
    in_offset = []
    out_offset = []
    magic_params = []
    magic_values = []
    magic = ""

    # Source strides: the stride of axis s is the product of the sizes of
    # all axes that follow it in the source layout.
    for i, s in enumerate(src):

        idx = "".join(str(x) for x in src[i + 1:])
        val = reduce(mul, (shape[x] for x in src[i + 1:]), 1)

        if s == dst_contig:
            # This axis is walked by the tile-load loop in the kernel, so it
            # is applied there (per j) instead of being folded into offset.
            in_dim_j = "dim_%s" % idx
        elif idx:
            in_offset.append("idx_%d*dim_%s" % (s, idx))
        else:
            # Innermost source axis: stride 1, no multiplier needed.
            in_offset.append("idx_%d" % s)

        if idx:
            dim_params.append("int dim_%s" % idx)
            dim_values.append(val)

    # Destination strides, computed the same way over the permuted order.
    for i, d in enumerate(dst):

        idx = "".join(str(x) for x in dst[i + 1:])
        val = reduce(mul, (shape[x] for x in dst[i + 1:]), 1)

        if d == src_contig:
            # Stride used to step between the four stores each thread issues.
            out_dim_j = "dim_%s" % idx

        if idx:
            dim_params.append("int dim_%s" % idx)
            dim_values.append(val)

            out_offset.append("idx_%d*dim_%s" % (d, idx))
        else:
            out_offset.append("idx_%d" % d)

    # Every axis except dst_contig is packed into blockIdx.x; dst_contig is
    # carried by blockIdx.y.
    src2 = list(src)
    src2[dst_contig:dst_contig + 1] = ()

    blk = compound_idx = "".join(str(x) for x in src2)

    # The two tiled axes are covered in steps of 32 elements per block.
    grid_shape = list(shape)
    grid_shape[src_contig] = _ceil_div(shape[src_contig], 32)
    grid_shape[dst_contig] = _ceil_div(shape[dst_contig], 32)

    # Peel the leading axis off the compound block index one at a time,
    # emitting a division (quotient) and a remainder for each step.
    while len(src2) > 1:

        idx1 = src2[0]
        src2[0:1] = ()
        idx2 = "".join(str(i) for i in src2)
        div = reduce(mul, (grid_shape[i] for i in src2), 1)

        magic_params.append("int magic_%s, int shift_%s, int div_%s" % (idx2, idx2, idx2))
        # Three ints are declared per step, so _magic64 presumably yields a
        # (magic, shift) pair that _flatten expands below -- TODO confirm.
        magic_values.append(_magic64(div))
        magic_values.append(div)

        magic += r"""
    int idx_{1} = div64(idx_{0}, magic_{2}, shift_{2});
    int idx_{2} = idx_{0} - idx_{1}*div_{2};
""".format(compound_idx, idx1, idx2)

        compound_idx = idx2

    params = _flatten([dim_params, magic_params])
    values = _flatten([dim_values, magic_values])

    # NOTE: this template goes through Python %-formatting, so literal
    # percent signs in the inline PTX are written as %%.
    shuffle_kernel = r"""
__device__ __forceinline__ int div64(int value, int magic, int shift)
{
    // if the divisor is a power of 2 the magic will be 1 and it's just a simple right shift
    // Otherwise multiply by magic and right shift just the high bits
    int result;
    asm("{\n\t"
        ".reg .pred p;\n\t"
        ".reg .u64 res64;\n\t"
        ".reg .u32 lo32, hi32;\n\t"
        "setp.ne.s32 p, %%2, 1;\n\t"
        "mul.wide.u32 res64, %%1, %%2;\n\t"
        "mov.b64 {lo32, hi32}, res64;\n\t"
        "selp.u32 hi32, hi32, %%1, p;\n\t"
        "shr.u32 %%0, hi32, %%3;\n\t"
        "}" : "=r"(result) : "r"(value), "r"(magic), "r"(shift));
    return result;
}

__global__ void copy_transpose(%(type)s* out, const %(type)s* in, %(params)s)
{
    __shared__ %(type)s tile[32][33];

    int tid_x = threadIdx.x;
    int tid_y = threadIdx.y;
    int idx_%(blk)s = blockIdx.x;
    int idx_%(dst)s = blockIdx.y;

    %(magic)s

    idx_%(src)s = (idx_%(src)s << 5) + tid_x;
    idx_%(dst)s = (idx_%(dst)s << 5) + tid_y;

    int offset = %(in_offset)s;

    #pragma unroll
    for (int j = 0; j < 32; j += 8)
    {
        int idx_%(dst)sj = idx_%(dst)s + j;
        if (idx_%(dst)sj < dim_%(dst)s && idx_%(src)s < dim_%(src)s)
            tile[tid_y + j][tid_x] = in[idx_%(dst)sj*%(in_dim_j)s + offset];
    }
    __syncthreads();

    %(type)s val00 = tile[tid_x][tid_y + 0];
    %(type)s val08 = tile[tid_x][tid_y + 8];
    %(type)s val16 = tile[tid_x][tid_y + 16];
    %(type)s val24 = tile[tid_x][tid_y + 24];

    idx_%(src)s += tid_y - tid_x;
    idx_%(dst)s += tid_x - tid_y;

    bool b%(dst)s = idx_%(dst)s < dim_%(dst)s;

    %(type)s* out00 = out + %(out_offset)s;
    %(type)s* out08 = out00 + %(out_dim_j)s*8;
    %(type)s* out16 = out08 + %(out_dim_j)s*8;
    %(type)s* out24 = out16 + %(out_dim_j)s*8;

    if (idx_%(src)s + 0 < dim_%(src)s && b%(dst)s) *out00 = val00;
    if (idx_%(src)s + 8 < dim_%(src)s && b%(dst)s) *out08 = val08;
    if (idx_%(src)s + 16 < dim_%(src)s && b%(dst)s) *out16 = val16;
    if (idx_%(src)s + 24 < dim_%(src)s && b%(dst)s) *out24 = val24;
}
"""
    code = shuffle_kernel % dict(
        type=_ew_types[dtype[1:]]["type"],
        params=", ".join(params),
        blk=blk,
        src=src_contig,
        dst=dst_contig,
        magic=magic,
        in_offset=" + ".join(in_offset),
        out_offset=" + ".join(out_offset),
        in_dim_j=in_dim_j,
        out_dim_j=out_dim_j
    )
    module = SourceModule(code)
    kernel = module.get_function("copy_transpose")
    # Two pointers (out, in) followed by one unsigned int per generated value.
    kernel.prepare("PP" + "I" * len(values))

    # grid.x enumerates every axis except dst_contig (in 32-wide tiles along
    # src_contig); grid.y enumerates the 32-wide tiles along dst_contig.
    grid_x = grid_shape[src_contig]
    grid_y = grid_shape[dst_contig]
    for s in src:
        if s not in (src_contig, dst_contig):
            grid_x *= grid_shape[s]

    return dict(
        kernel=kernel,
        grid=(grid_x, grid_y, 1),
        block=(32, 8, 1),
        args=values
    )
0 commit comments