Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/GraphConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ GraphConfig::GraphConfig() {
local_wg_size_override = {};

expect_dynamic_shapes = false;
force_resize = false;

external_adapter = nullptr;
}
Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/GraphConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ struct GraphConfig final {

// Whether or not the ComputeGraph should expect input shapes to be dynamic
bool expect_dynamic_shapes;
// Used for testing/debugging only. Forces ExecuteNode to trigger the resize
// function even if none of the inputs have been updated.
bool force_resize = false;

// Execution properties that determine specifics re: how command buffer
// submission is handled, etc. 0 means this field is not set.
Expand Down
5 changes: 3 additions & 2 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ ExecuteNode::ExecuteNode(
name_(name) {}

bool ExecuteNode::trigger_resize(ComputeGraph* graph) {
const bool any_arg_updated = was_any_arg_updated(graph);
if (resize_fn_ && any_arg_updated) {
bool any_arg_updated = was_any_arg_updated(graph);
if (resize_fn_ && (any_arg_updated || graph->graphconfig().force_resize)) {
resize_fn_(graph, args_, resize_args_);
any_arg_updated = true;
}
return any_arg_updated;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
config.set_storage_type_override(default_storage_type);
config.set_memory_layout_override(default_memory_layout);
config.force_resize = true;
graph = new ComputeGraph(config);

if (test_dtype == at::kHalf) {{
Expand Down
Loading