Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class Context final {
PipelineBarrier&,
const utils::uvec3&,
const utils::uvec3&,
const SpecVarList&,
VkFence fence_handle,
Arguments&&...);

Expand Down Expand Up @@ -494,6 +495,7 @@ inline bool Context::submit_compute_job(
PipelineBarrier& pipeline_barrier,
const utils::uvec3& global_work_group,
const utils::uvec3& local_work_group_size,
const SpecVarList& specialization,
VkFence fence_handle,
Arguments&&... arguments) {
// If any of the provided arguments does not have memory associated with it,
Expand Down Expand Up @@ -537,7 +539,7 @@ inline bool Context::submit_compute_job(

// Factor out template parameter independent code to minimize code bloat.
DescriptorSet descriptor_set =
get_descriptor_set(shader, local_work_group_size);
get_descriptor_set(shader, local_work_group_size, specialization);

detail::bind(
descriptor_set,
Expand Down
8 changes: 5 additions & 3 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ ExecuteNode::ExecuteNode(
const std::vector<ArgGroup>& args,
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
const ResizeFunction& resize_fn,
const std::vector<ValueRef>& resize_args)
const std::vector<ValueRef>& resize_args,
const api::SpecVarList& spec_vars)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
args_(args),
params_(params),
resize_fn_(resize_fn),
resize_args_(resize_args) {
resize_args_(resize_args),
spec_vars_(spec_vars) {
graph.update_descriptor_counts(shader, /*execute = */ true);
}

Expand All @@ -40,7 +42,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

api::DescriptorSet descriptor_set =
context->get_descriptor_set(shader_, local_workgroup_size_);
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class ExecuteNode final {
const std::vector<ArgGroup>& args,
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
const ResizeFunction& resize_fn = nullptr,
const std::vector<ValueRef>& resize_args = {});
const std::vector<ValueRef>& resize_args = {},
const api::SpecVarList& spec_vars = {});

~ExecuteNode() = default;

Expand All @@ -76,6 +77,7 @@ class ExecuteNode final {
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
const ResizeFunction resize_fn_;
const std::vector<ValueRef> resize_args_;
const api::SpecVarList spec_vars_;
};

} // namespace vkcompute
46 changes: 46 additions & 0 deletions backends/vulkan/test/glsl/fill_buffer.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

$PRECISION = "highp"
$DTYPE = "float"

#define PRECISION ${PRECISION}

#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#include "indexing_utils.h"

// Default layout qualifier: buffer blocks declared below use std430 packing.
layout(std430) buffer;

// Storage buffer at binding 0 that receives the generated values. Despite the
// instance name `buffer_in`, the block is declared writeonly and main() only
// stores into `data`.
layout(set = 0, binding = 0) buffer PRECISION restrict writeonly Buffer {
VEC4_T data[];
}
buffer_in;

// Uniform parameters at binding 1. main() compares `len` against 4 * i (the
// first scalar index covered by vec4 number i), so it bounds the writes in
// scalar elements rather than in vec4s.
layout(set = 0, binding = 1) uniform PRECISION restrict Params {
int len;
}
params;



// The local workgroup size is not hard-coded in the shader; each dimension is
// taken from specialization constants 0, 1 and 2 respectively.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Specialization constants 3 and 4 parameterize the fill pattern written by
// main(): each scalar element k is set to scale * k + offset. The initializers
// below (1 and 0) are the defaults used when the constants are not specialized.
layout(constant_id = 3) const float scale = 1;
layout(constant_id = 4) const float offset = 0;

// Each invocation fills one vec4 of the output buffer. Counting the buffer in
// scalars, element k receives scale * k + offset.
void main() {
  const int vec_idx = ivec3(gl_GlobalInvocationID).x;
  const int first_elem = 4 * vec_idx;

  // Invocations whose first scalar index is at or beyond params.len write
  // nothing.
  if (first_elem >= params.len) {
    return;
  }

  // Scalar indices of the four lanes covered by this invocation.
  const VEC4_T elem_indices = VEC4_T(first_elem) + VEC4_T(0, 1, 2, 3);
  buffer_in.data[vec_idx] = scale * elem_indices + offset;
}
4 changes: 4 additions & 0 deletions backends/vulkan/test/utils/test_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void record_nchw_to_image_op(
pipeline_barrier,
v_dst.virtual_extents(),
adaptive_work_group_size(v_dst.virtual_extents()),
{},
VK_NULL_HANDLE,
v_dst.image(
pipeline_barrier,
Expand All @@ -47,6 +48,7 @@ void record_image_to_nchw_op(
pipeline_barrier,
v_src.virtual_extents(),
adaptive_work_group_size(v_src.virtual_extents()),
{},
VK_NULL_HANDLE,
v_src.image(pipeline_barrier, api::PipelineStage::COMPUTE),
dst_buffer,
Expand Down Expand Up @@ -83,6 +85,7 @@ void record_conv2d_prepack_weights_op(
pipeline_barrier,
v_dst.virtual_extents(),
adaptive_work_group_size(v_dst.virtual_extents()),
{},
VK_NULL_HANDLE,
v_dst.image(
pipeline_barrier,
Expand All @@ -109,6 +112,7 @@ void record_binary_op(
pipeline_barrier,
v_dst.virtual_extents(),
adaptive_work_group_size(v_dst.virtual_extents()),
{},
VK_NULL_HANDLE,
v_dst.image(
pipeline_barrier,
Expand Down
35 changes: 35 additions & 0 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,38 @@ TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
ASSERT_TRUE(*(reinterpret_cast<const float*>(copy_data + 6)) == 5.5f);
}

// Runs the fill_buffer test shader with `scale` and `offset` supplied as
// specialization constants, then reads the buffer back and checks that element
// i holds scale * i + offset.
TEST_F(VulkanComputeAPITest, spec_var_shader_test) {
  const size_t num_elems = 16;
  api::StorageBuffer buffer(api::context(), api::kFloat, num_elems);

  float scale = 3.0f;
  float offset = 1.5f;

  // The uniform buffer and the barrier only need to live while the job is
  // being recorded, so they are confined to this scope.
  {
    api::UniformParamsBuffer params(
        api::context(), static_cast<int32_t>(num_elems));
    // The shader writes one vec4 (four floats) per invocation.
    uint32_t num_vec4 =
        api::utils::div_up(static_cast<uint32_t>(num_elems), uint32_t{4});
    api::PipelineBarrier pipeline_barrier{};

    api::context()->submit_compute_job(
        VK_KERNEL(fill_buffer),
        pipeline_barrier,
        {64, 1, 1},
        {num_vec4, 1, 1},
        // Specialization constants passed to the shader, in this order.
        {SV(scale), SV(offset)},
        VK_NULL_HANDLE,
        buffer.buffer(),
        params.buffer());
  }

  submit_to_gpu();

  // Copy the results back to host memory for verification.
  std::vector<float> host_data(num_elems);
  copy_staging_to_ptr(buffer, host_data.data(), buffer.nbytes());

  for (size_t i = 0; i < num_elems; ++i) {
    CHECK_VALUE(host_data, i, scale * i + offset);
  }
}

TEST_F(VulkanComputeAPITest, update_params_between_submit) {
api::context()->set_cmd(/*reusable = */ true);
std::vector<int64_t> sizes = {4, 4, 2};
Expand Down Expand Up @@ -126,6 +158,7 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) {
pipeline_barrier,
{4, 4, 4},
{4, 4, 4},
{},
VK_NULL_HANDLE,
a.image(
pipeline_barrier,
Expand Down Expand Up @@ -185,6 +218,7 @@ void test_storage_buffer_type(const size_t len) {
pipeline_barrier,
{64, 1, 1},
{len_div4, 1, 1},
{},
VK_NULL_HANDLE,
buffer.buffer(),
params.buffer());
Expand Down Expand Up @@ -880,6 +914,7 @@ void run_from_gpu_test(
pipeline_barrier,
vten.virtual_extents(),
{4, 4, 4},
{},
VK_NULL_HANDLE,
vten.image(
pipeline_barrier,
Expand Down