Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/api/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ api::VulkanBuffer& vTensor::buffer(
return storage_.buffer_;
}

const api::BufferBindInfo vTensor::cpu_sizes_ubo() {
const api::BufferBindInfo vTensor::sizes_ubo() {
if (!cpu_sizes_uniform_.buffer()) {
cpu_sizes_uniform_ = api::UniformParamsBuffer(
storage_.context_, api::utils::make_whcn_ivec4(sizes_));
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/api/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class vTensor final {
* shader. Note that the UBO will be created the first time this function is
* called.
*/
const api::BufferBindInfo cpu_sizes_ubo();
const api::BufferBindInfo sizes_ubo();

/*
* Get a uniform buffer object containing the tensor GPU sizes to use in a
Expand Down
8 changes: 5 additions & 3 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@ PrepackNode::PrepackNode(
const api::utils::uvec3& local_workgroup_size,
const ValueRef tref,
const ValueRef packed,
const api::ParamsBindList& params)
const api::ParamsBindList& params,
const api::SpecVarList& spec_vars)
: shader_(shader),
noop_shader_(get_noop_shader(graph, packed)),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
tref_(tref),
packed_(packed),
params_(params) {
params_(params),
spec_vars_(spec_vars) {
graph.update_descriptor_counts(shader, /*execute = */ false);
graph.update_descriptor_counts(noop_shader_, /*execute = */ false);
}
Expand Down Expand Up @@ -75,7 +77,7 @@ void PrepackNode::encode(ComputeGraph* graph) {
{
api::PipelineBarrier pipeline_barrier{};
api::DescriptorSet descriptor_set =
context->get_descriptor_set(shader_, local_workgroup_size_);
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);

uint32_t idx = 0;
bind_tensor_to_descriptor_set(
Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/graph/ops/PrepackNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class PrepackNode final {
const api::utils::uvec3& local_workgroup_size,
const ValueRef tref,
const ValueRef packed,
const api::ParamsBindList& params);
const api::ParamsBindList& params,
const api::SpecVarList& spec_vars = {});

~PrepackNode() = default;

Expand All @@ -47,6 +48,7 @@ class PrepackNode final {
const ValueRef tref_;
const ValueRef packed_;
const api::ParamsBindList params_;
const api::SpecVarList spec_vars_;

private:
api::StorageBuffer create_staging_buffer(ComputeGraph* graph);
Expand Down
48 changes: 21 additions & 27 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@

#define VEC4_T ${texel_type(DTYPE)}

#define to_tensor_idx to_tensor_idx_${PACKING}
#define to_texture_pos to_texture_pos_${PACKING}

#define op(X, Y, A) ${OPERATOR}

#include "broadcasting_utils.h"
Expand All @@ -27,59 +24,56 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other;

layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes {
ivec4 data;
}
out_sizes;
ivec4 out_sizes;
};

layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
ivec4 data;
}
in_sizes;
ivec4 in_sizes;
};

layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes {
ivec4 data;
}
other_sizes;
ivec4 other_sizes;
};

layout(set = 0, binding = 6) uniform PRECISION restrict BroadcastParams {
ivec2 data;
}
broadcast_params;
ivec2 broadcast_params;
};

layout(set = 0, binding = 7) uniform PRECISION restrict Alpha {
float data;
}
alpha;
float alpha;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 2;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec4 idx = to_tensor_idx(pos, out_sizes.data);
const ivec4 idx = to_tensor_idx(pos, out_sizes, packed_dim);

if (any(greaterThanEqual(idx, out_sizes.data))) {
if (any(greaterThanEqual(idx, out_sizes))) {
return;
}

ivec4 in_idx = broadcast_indices(idx, in_sizes.data);
ivec4 in_idx = broadcast_indices(idx, in_sizes);
VEC4_T in_texel = VEC4_T(texelFetch(
image_in,
to_texture_pos(in_idx, in_sizes.data),
to_texture_pos(in_idx, in_sizes, packed_dim),
0));

ivec4 other_idx = broadcast_indices(idx, other_sizes.data);
ivec4 other_idx = broadcast_indices(idx, other_sizes);
VEC4_T other_texel = VEC4_T(texelFetch(
image_other,
to_texture_pos(other_idx, other_sizes.data),
to_texture_pos(other_idx, other_sizes, packed_dim),
0));

// Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
if (broadcast_params.data.x > 0) {
if (broadcast_params.x > 0) {
in_texel = in_texel.xxxx;
}
if (broadcast_params.data.y > 0) {
if (broadcast_params.y > 0) {
other_texel = other_texel.xxxx;
}

imageStore(image_out, pos, VEC4_T(op(in_texel, other_texel, alpha.data)));
imageStore(image_out, pos, VEC4_T(op(in_texel, other_texel, alpha)));
}
4 changes: 0 additions & 4 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ binary_op:
DTYPE: float
PACKING: C_packed
generate_variant_forall:
PACKING:
- VALUE: C_packed
- VALUE: W_packed
- VALUE: H_packed
DTYPE:
- VALUE: half
- VALUE: float
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
#define VEC4_T ${texel_type(DTYPE)}
#define SCALAR_T ${texel_component_type(DTYPE)}

#define to_tensor_idx to_tensor_idx_${PACKING}
#define get_packed_stride get_packed_stride_${PACKING}

#include "indexing_utils.h"

$if DTYPE == "half":
Expand All @@ -31,25 +28,24 @@ layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
buffer_in;

// Corresponds to {1,4,3,9} in the example below.
layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes {
ivec4 data;
}
gpu_sizes;
layout(set = 0, binding = 2) uniform PRECISION restrict Sizes {
ivec4 sizes;
};

// Corresponds to {3,3,1,11} in the example below.
layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
ivec4 data;
}
original_sizes;
ivec4 original_sizes;
};

// Corresponds to {1,12} in the example below.
layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes {
ivec2 data;
}
padded_sizes;
ivec2 padded_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 2;

/*
* Computes special prepacking for a depthwise convolution. Each shader invocation
* calculates the input buffer location to read into the desired texel. This
Expand Down Expand Up @@ -77,26 +73,26 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
*/
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data);
const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim);

if (any(greaterThanEqual(idx, gpu_sizes.data))) {
if (any(greaterThanEqual(idx, sizes))) {
return;
}

// As in usual staging shaders, map from GPU texel position to normal CPU
// buffer indices: (9,3) -> (4,3,9)
const int base_index = to_buffer_i(idx, gpu_sizes.data);
const int base_index = to_nchw_i(idx, sizes);
const ivec4 p0 =
base_index + ivec4(0, 1, 2, 3) * get_packed_stride(gpu_sizes.data);
base_index + ivec4(0, 1, 2, 3) * get_nchw_stride(sizes, packed_dim);

// Re-map the normal CPU buffer indices to special indices, through a series
// of mappings: reshape is a no-op to the underlying indices, so we only map
// for pad and permute.
const int Np = padded_sizes.data.x;
const int N = original_sizes.data.w;
const int C = original_sizes.data.z;
const int H = original_sizes.data.y;
const int W = original_sizes.data.x;
const int Np = padded_sizes.x;
const int N = original_sizes.w;
const int C = original_sizes.z;
const int H = original_sizes.y;
const int W = original_sizes.x;

// Undo step 3 permute: (4,3,1,9) -> (3,4,1,9)
const ivec4 p1 = swap_adj_dims(p0, 4, (Np / 4), (C * H * W));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
conv2d_dw_prepack_weights:
parameter_names_with_default_values:
DTYPE: float
PACKING: C_packed
generate_variant_forall:
DTYPE:
- VALUE: half
Expand Down
55 changes: 25 additions & 30 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
#define VEC4_T ${texel_type(DTYPE)}
#define SCALAR_T ${texel_component_type(DTYPE)}

#define to_tensor_idx to_tensor_idx_${PACKING}
#define get_packed_stride get_packed_stride_${PACKING}

#include "indexing_utils.h"

$if DTYPE == "half":
Expand All @@ -26,30 +23,28 @@ layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
BUF_T data[];
}
buffer_in;
BUF_T buffer_in[];
};

// Corresponds to {1,4,9,24} in the example below.
layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes {
ivec4 data;
}
gpu_sizes;
layout(set = 0, binding = 2) uniform PRECISION restrict Sizes {
ivec4 sizes;
};

// Corresponds to {3,3,7,10} in the example below.
layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
ivec4 data;
}
original_sizes;
ivec4 original_sizes;
};

// Corresponds to {8,12} in the example below.
layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes {
ivec2 data;
}
padded_sizes;
ivec2 padded_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 2;

/*
* Computes special prepacking for a 2D convolution. Each shader invocation
* calculates the input buffer location to read into the desired texel. This
Expand Down Expand Up @@ -91,27 +86,27 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
*/
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data);
const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim);

if (any(greaterThanEqual(idx, gpu_sizes.data))) {
if (any(greaterThanEqual(idx, sizes))) {
return;
}

// As in usual staging shaders, map from GPU texel position to normal CPU
// buffer indices: (24,9) -> (4,9,24)
const int base_index = to_buffer_i(idx, gpu_sizes.data);
const int base_index = to_nchw_i(idx, sizes);
const ivec4 p0 =
base_index + ivec4(0, 1, 2, 3) * get_packed_stride(gpu_sizes.data);
base_index + ivec4(0, 1, 2, 3) * get_nchw_stride(sizes, packed_dim);

// Re-map the normal CPU buffer indices to special indices, through a series
// of mappings: reshape is a no-op to the underlying indices, so we only map
// for pad and permute.
const int Np = padded_sizes.data.y;
const int Cp = padded_sizes.data.x;
const int N = original_sizes.data.w;
const int C = original_sizes.data.z;
const int H = original_sizes.data.y;
const int W = original_sizes.data.x;
const int Np = padded_sizes.y;
const int Cp = padded_sizes.x;
const int N = original_sizes.w;
const int C = original_sizes.z;
const int H = original_sizes.y;
const int W = original_sizes.x;

// Undo step 6 premute: (4,3,3,24) -> (3,4,3,24)
// Undo step 4 permute: (12,3,2,12) -> (12,2,3,12)
Expand All @@ -130,10 +125,10 @@ void main() {
const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) |
ivec4(greaterThanEqual(n, ivec4(N)));

SCALAR_T val_x = mix(SCALAR_T(buffer_in.data[p5.x]), 0, mask.x);
SCALAR_T val_y = mix(SCALAR_T(buffer_in.data[p5.y]), 0, mask.y);
SCALAR_T val_z = mix(SCALAR_T(buffer_in.data[p5.z]), 0, mask.z);
SCALAR_T val_w = mix(SCALAR_T(buffer_in.data[p5.w]), 0, mask.w);
SCALAR_T val_x = mix(SCALAR_T(buffer_in[p5.x]), 0, mask.x);
SCALAR_T val_y = mix(SCALAR_T(buffer_in[p5.y]), 0, mask.y);
SCALAR_T val_z = mix(SCALAR_T(buffer_in[p5.z]), 0, mask.z);
SCALAR_T val_w = mix(SCALAR_T(buffer_in[p5.w]), 0, mask.w);

VEC4_T texel = VEC4_T(val_x, val_y, val_z, val_w);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
conv2d_prepack_weights:
parameter_names_with_default_values:
DTYPE: float
PACKING: C_packed
generate_variant_forall:
DTYPE:
- VALUE: half
Expand Down
Loading