|  | 
|  | 1 | +/* | 
|  | 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. | 
|  | 3 | + * All rights reserved. | 
|  | 4 | + * | 
|  | 5 | + * This source code is licensed under the BSD-style license found in the | 
|  | 6 | + * LICENSE file in the root directory of this source tree. | 
|  | 7 | + */ | 
|  | 8 | + | 
|  | 9 | +#version 450 core | 
|  | 10 | +// clang-format off | 
|  | 11 | +#define PRECISION ${PRECISION} | 
|  | 12 | +// clang-format on | 
|  | 13 | + | 
|  | 14 | +#include "indexing_utils.h" | 
|  | 15 | + | 
// Default block layout for every buffer block declared in this shader.
layout(std430) buffer;

// Destination image (write-only). Its format qualifier and image type are
// substituted by the shader template from DTYPE and NDIM.
// clang-format off
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
// clang-format on

// Source elements (read-only), exposed as a flat runtime-sized array.
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
  ${T[DTYPE]} data[];
} buffer_in;

// Extents used by main() for the coordinate bounds check and for the stride
// between the four lanes of a texel.
layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes {
  ivec4 data;
} gpu_sizes;

// Extents used by main() to compute the linear buffer index and to decide
// which lanes along z hold real data.
layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes {
  ivec4 data;
} cpu_sizes;

// Workgroup dimensions come from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
|  | 37 | + | 
/*
 * Gathers up to four elements from buffer_in and writes them as one texel of
 * image_out at this invocation's position.
 *
 * The four lanes of the texel correspond to z indices coord.z + 0..3. A lane
 * whose z index is not below cpu_sizes.data.z has no source element; it is
 * written as zero and buffer_in is never read for it.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data);

  // Invocations outside the extents in gpu_sizes have nothing to write.
  if (any(greaterThanEqual(coord, gpu_sizes.data))) {
    return;
  }

  const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data);
  // Consecutive lanes are one x*y plane apart in the buffer.
  // NOTE(review): the plane size is taken from gpu_sizes while base_index is
  // derived from cpu_sizes; this assumes both agree on x and y - confirm.
  const ivec4 buf_indices =
      base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y);

  // Decide per lane whether a source element exists *before* touching the
  // buffer. Reading first and masking afterwards can index past the end of
  // buffer_in on the last group along z, and multiplying by a zero mask does
  // not clear NaN/Inf garbage (and only type-checks for float texels).
  const bvec4 lane_valid =
      lessThan(ivec4(coord.z) + ivec4(0, 1, 2, 3), ivec4(cpu_sizes.data.z));

  ${T[DTYPE]} val_x = ${T[DTYPE]}(0);
  ${T[DTYPE]} val_y = ${T[DTYPE]}(0);
  ${T[DTYPE]} val_z = ${T[DTYPE]}(0);
  ${T[DTYPE]} val_w = ${T[DTYPE]}(0);

  if (lane_valid.x) {
    val_x = buffer_in.data[buf_indices.x];
  }
  if (lane_valid.y) {
    val_y = buffer_in.data[buf_indices.y];
  }
  if (lane_valid.z) {
    val_z = buffer_in.data[buf_indices.z];
  }
  if (lane_valid.w) {
    val_w = buffer_in.data[buf_indices.w];
  }

  const ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w);

  imageStore(image_out, ${GET_POS[NDIM]("pos")}, texel);
}
0 commit comments