Skip to content

Commit 315d700

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK][ez] Introduce a graph config setting to force resize functions to execute
Pull Request resolved: #15158 Title says it all! A few months back, a mechanism was introduced where an `ExecuteNode` would not call an operator's resize function if none of the arguments were updated. However, this creates a blind spot during testing where the resize function of operators are not tested since the generated operator tests do not modify input sizes. To address this, add a way to force the resize function to be called during testing. ghstack-source-id: 317683546 @exported-using-ghexport Differential Revision: [D84716451](https://our.internmc.facebook.com/intern/diff/D84716451/)
1 parent b7e5cff commit 315d700

File tree

6 files changed

+24
-5
lines changed

6 files changed

+24
-5
lines changed

backends/vulkan/runtime/graph/GraphConfig.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ GraphConfig::GraphConfig() {
6565
local_wg_size_override = {};
6666

6767
expect_dynamic_shapes = false;
68+
force_resize = false;
6869

6970
external_adapter = nullptr;
7071
}

backends/vulkan/runtime/graph/GraphConfig.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ struct GraphConfig final {
3535

3636
// Whether or not the ComputeGraph should expect input shapes to be dynamic
3737
bool expect_dynamic_shapes;
38+
// Used for testing/debugging only. Forces ExecuteNode to trigger the resize
39+
// function even if none of the inputs have been updated.
40+
bool force_resize = false;
3841

3942
// Execution properties that determine specifics re: how command buffer
4043
// submission is handled, etc. 0 means this field is not set.

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ ExecuteNode::ExecuteNode(
2121
name_(name) {}
2222

2323
bool ExecuteNode::trigger_resize(ComputeGraph* graph) {
24-
const bool any_arg_updated = was_any_arg_updated(graph);
25-
if (resize_fn_ && any_arg_updated) {
24+
bool any_arg_updated = was_any_arg_updated(graph);
25+
if (resize_fn_ && (any_arg_updated || graph->graphconfig().force_resize)) {
2626
resize_fn_(graph, args_, resize_args_);
27+
any_arg_updated = true;
2728
}
2829
return any_arg_updated;
2930
}

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ void resize_pool2d_node(
5959

6060
if (is_max_pool2d) {
6161
const ValueRef indices = args.at(0).refs.at(1);
62-
graph->virtual_resize(indices, new_out_sizes);
62+
// For max_pool2d variant, indices tensor will be a 0-dim tensor - only
63+
// resize the indices tensor if this is not the case.
64+
if (graph->sizes_of(indices).size() > 0) {
65+
graph->virtual_resize(indices, new_out_sizes);
66+
}
6367
}
6468
}
6569

backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,21 @@ void resize_unsqueeze_node(
5454
const ValueRef in = args.at(1).refs.at(0);
5555
const ValueRef dims_ref = extra_args.at(0);
5656

57-
const IntListPtr dims = graph->get_int_list(dims_ref);
57+
std::vector<int64_t> dims_vec;
58+
if (graph->is_scalar_or_none(dims_ref)) {
59+
// Handle scalar case
60+
int64_t dim = graph->extract_scalar<int64_t>(dims_ref);
61+
dims_vec.push_back(dim);
62+
} else {
63+
// Handle list case
64+
const IntListPtr dims = graph->get_int_list(dims_ref);
65+
dims_vec.assign(dims->begin(), dims->end());
66+
}
5867

5968
std::vector<int64_t> out_sizes = graph->sizes_of(in);
6069

6170
// Insert singleton dimensions at the specified positions
62-
for (auto dim : *dims) {
71+
for (auto dim : dims_vec) {
6372
int64_t d = dim;
6473
if (d < 0) {
6574
d += static_cast<int64_t>(out_sizes.size()) + 1;

backends/vulkan/test/op_tests/utils/gen_correctness_vk.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
3434
std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
3535
config.set_storage_type_override(default_storage_type);
3636
config.set_memory_layout_override(default_memory_layout);
37+
config.force_resize = true;
3738
graph = new ComputeGraph(config);
3839
3940
if (test_dtype == at::kHalf) {{

0 commit comments

Comments
 (0)