Skip to content

Commit e6f395c

Browse files
committed
Update on "[ET-VK] Dynamic shape support in Vulkan Backend"
## Context This changeset exposes API functions to the `ComputeGraph` class that allow inputs to be resized, and for the resizing to propagate through the graph via re-calculation of output shapes. Differential Revision: [D54754546](https://our.internmc.facebook.com/intern/diff/D54754546/) [ghstack-poisoned]
2 parents 144b957 + 6fd8f10 commit e6f395c

File tree

13 files changed

+42
-869
lines changed

13 files changed

+42
-869
lines changed

backends/vulkan/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ load(":targets.bzl", "define_common_targets")
33

44
oncall("executorch")
55

6-
define_common_targets()
6+
define_common_targets(is_fbcode = True)
77

88
runtime.python_library(
99
name = "vulkan_preprocess",

backends/vulkan/targets.bzl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
1-
load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_fbcode")
2-
load("@fbsource//tools/build_defs:glob_defs.bzl", "subdir_glob")
31
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
42

5-
def get_glsl_image_format():
6-
if native.read_config("pt", "vulkan_full_precision", "0") == "0":
7-
return "rgba16f"
8-
return "rgba32f"
9-
10-
def vulkan_spv_shader_lib(name, spv_filegroup):
3+
def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False):
114
gen_aten_vulkan_spv_target = "//caffe2/tools:gen_aten_vulkan_spv_bin"
125
glslc_path = "//caffe2/fb/vulkan/dotslash:glslc"
13-
if is_fbcode():
6+
if is_fbcode:
147
gen_aten_vulkan_spv_target = "//caffe2:gen_vulkan_spv_bin"
158
glslc_path = "//caffe2/fb/vulkan/tools:glslc"
169

10+
glsl_paths = []
11+
12+
# TODO(ssjia): remove the need for subpath once subdir_glob is enabled in OSS
13+
for target, subpath in spv_filegroups.items():
14+
glsl_paths.append("$(location {})/{}".format(target, subpath))
15+
1716
genrule_cmd = [
1817
"$(exe {})".format(gen_aten_vulkan_spv_target),
19-
"--glsl-paths $(location {})".format(spv_filegroup),
20-
"--output-path $OUT --env FLOAT_IMAGE_FORMAT={}".format(get_glsl_image_format()),
18+
"--glsl-paths {}".format(" ".join(glsl_paths)),
19+
"--output-path $OUT",
2120
"--glslc-path=$(exe {})".format(glslc_path),
2221
"--tmp-dir-path=$OUT",
2322
]
@@ -49,7 +48,7 @@ def vulkan_spv_shader_lib(name, spv_filegroup):
4948
],
5049
)
5150

52-
def define_common_targets():
51+
def define_common_targets(is_fbcode = False):
5352
runtime.genrule(
5453
name = "gen_vk_delegate_schema",
5554
srcs = [
@@ -89,14 +88,17 @@ def define_common_targets():
8988

9089
runtime.filegroup(
9190
name = "vulkan_graph_runtime_shaders",
92-
srcs = subdir_glob([
93-
("runtime/graph/ops/glsl", "*"),
91+
srcs = native.glob([
92+
"runtime/graph/ops/glsl/*",
9493
]),
9594
)
9695

9796
vulkan_spv_shader_lib(
9897
name = "vulkan_graph_runtime_shaderlib",
99-
spv_filegroup = ":vulkan_graph_runtime_shaders",
98+
spv_filegroups = {
99+
":vulkan_graph_runtime_shaders": "runtime/graph/ops/glsl",
100+
},
101+
is_fbcode = is_fbcode,
100102
)
101103

102104
runtime.cxx_library(

examples/models/llama2/quantize.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -916,8 +916,9 @@ def linear_forward_8da4w(
916916
x, weight_int8, scales, zeros, out_features, group_size, precision
917917
):
918918
x = per_token_dynamic_quant(x)
919-
origin_x_size = x.size()
920-
x = x.reshape(-1, origin_x_size[-1])
919+
# TODO: verify and remove following reshape code
920+
# origin_x_size = x.size()
921+
# x = x.reshape(-1, origin_x_size[-1])
921922

922923
# TODO: better API
923924
# weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
@@ -939,8 +940,8 @@ def linear_forward_8da4w(
939940
# w_dq = w_dq.to(torch.float16)
940941
c = torch.nn.functional.linear(x, w_dq)
941942

942-
new_shape = origin_x_size[:-1] + (out_features,)
943-
c = c.reshape(new_shape)
943+
# new_shape = origin_x_size[:-1] + (out_features,)
944+
# c = c.reshape(new_shape)
944945

945946
return c
946947

@@ -1144,7 +1145,8 @@ def __init__(
11441145

11451146
def forward(self, input: torch.Tensor) -> torch.Tensor:
11461147
input = input.to(self.precision)
1147-
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
1148+
# padding is removed for perf
1149+
# input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
11481150
return linear_forward_8da4w(
11491151
input,
11501152
self.weight,
@@ -1387,6 +1389,10 @@ def make_names_and_values_dict_func(q, qparams):
13871389

13881390
def convert_for_runtime(self, model):
13891391
replace_linear_8da4w(
1390-
model, self.groupsize, self.inner_k_tiles, self.padding_allowed
1392+
model,
1393+
self.groupsize,
1394+
self.padding_allowed,
1395+
torch.int8,
1396+
self.precision,
13911397
)
13921398
return model

examples/portable/scripts/export.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import argparse
1010
import logging
1111

12+
import torch
13+
1214
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
1315

1416
from ...models import MODEL_NAME_TO_MODEL
@@ -75,4 +77,5 @@ def main() -> None:
7577

7678

7779
if __name__ == "__main__":
78-
main() # pragma: no cover
80+
with torch.no_grad():
81+
main() # pragma: no cover

examples/sdk/scripts/export_bundled_program.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,11 @@
88

99
import argparse
1010

11-
from typing import List, Union
11+
from typing import List
1212

1313
import torch
1414

15-
from executorch.exir import (
16-
ExecutorchProgram,
17-
ExecutorchProgramManager,
18-
MultiMethodExecutorchProgram,
19-
)
15+
from executorch.exir import ExecutorchProgramManager
2016
from executorch.sdk import BundledProgram
2117
from executorch.sdk.bundled_program.config import (
2218
MethodInputType,
@@ -33,11 +29,7 @@
3329

3430

3531
def save_bundled_program(
36-
executorch_program: Union[
37-
ExecutorchProgram,
38-
MultiMethodExecutorchProgram,
39-
ExecutorchProgramManager,
40-
],
32+
executorch_program: ExecutorchProgramManager,
4133
method_test_suites: List[MethodTestSuite],
4234
output_path: str,
4335
):

exir/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
_capture_legacy_do_not_use,
1111
CallSpec,
1212
capture,
13-
capture_multiple,
1413
CaptureConfig,
1514
EdgeCompileConfig,
1615
ExecutorchBackendConfig,
@@ -23,9 +22,6 @@
2322
ExecutorchProgram,
2423
ExecutorchProgramManager,
2524
ExirExportedProgram,
26-
multi_method_program_to_executorch,
27-
MultiMethodExecutorchProgram,
28-
MultiMethodExirExportedProgram,
2925
to_edge,
3026
)
3127
from executorch.exir.tracer import ExirDynamoConfig
@@ -37,7 +33,6 @@
3733
"emit_program",
3834
"EmitterOutput",
3935
"capture",
40-
"capture_multiple",
4136
"_capture_legacy_do_not_use",
4237
"CallSpec",
4338
"ExportedProgram",
@@ -49,12 +44,9 @@
4944
"EdgeProgramManager",
5045
"ExecutorchProgramManager",
5146
"edge_to_executorch_passes",
52-
"MultiMethodExirExportedProgram",
53-
"MultiMethodExecutorchProgram",
5447
"CaptureConfig",
5548
"EdgeCompileConfig",
5649
"ExecutorchBackendConfig",
5750
"Value",
58-
"multi_method_program_to_executorch",
5951
"ExirDynamoConfig",
6052
]

exir/capture/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
_capture_legacy_do_not_use,
1111
CallSpec,
1212
capture,
13-
capture_multiple,
1413
)
1514

1615
from executorch.exir.capture._config import (
@@ -23,7 +22,6 @@
2322
"CallSpec",
2423
"capture",
2524
"_capture_legacy_do_not_use",
26-
"capture_multiple",
2725
"CaptureConfig",
2826
"EdgeCompileConfig",
2927
"ExecutorchBackendConfig",

exir/capture/_capture.py

Lines changed: 2 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
from collections import namedtuple
1010
from contextlib import contextmanager
1111
from types import MethodType
12-
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union
12+
from typing import Any, Callable, cast, List, Optional, Tuple
1313

1414
import torch
1515
from executorch.exir.capture._config import CaptureConfig
1616
from executorch.exir.error import ExportError, ExportErrorType, InternalError
17-
from executorch.exir.program import ExirExportedProgram, MultiMethodExirExportedProgram
17+
from executorch.exir.program import ExirExportedProgram
1818
from executorch.exir.program._program import _transform, HackedUpExportedProgramDONOTUSE
1919
from executorch.exir.tracer import (
2020
_default_decomposition_table,
@@ -360,137 +360,6 @@ def convert_to_fake(x):
360360
return ExirExportedProgram(ep, False)
361361

362362

363-
@compatibility(is_backward_compatible=False)
364-
def capture_multiple(
365-
m: Union[torch.nn.Module, Callable[..., Any]],
366-
args: Union[Dict[str, Tuple[Value, ...]], Tuple[Value, ...]],
367-
config: Optional[CaptureConfig] = None,
368-
prim_getters: Optional[Set[str]] = None,
369-
dynamic_shapes: Optional[Union[Dict[str, Any], List[Any]]] = None,
370-
):
371-
"""
372-
capture_multiple traces either an nn.Module or just a callable with PyTorch
373-
operations inside and produce a single MultiMethodExirExportedProgram that
374-
can potentially have multiple entry points. When multiple entry points
375-
are traced, each of them is stored separately in the resulting
376-
MultiMethodExirExportedProgram without sharing state.
377-
378-
Args:
379-
m: the `nn.Module` or callable to trace.
380-
381-
args: Tracing example inputs.
382-
383-
When `m` is an nn.Module, `args` can be
384-
1) A dictionary that maps names of method to their tracing example inputs.
385-
in this case, all specified methods will be captured.
386-
2) A tuple. In this case, `forward` method of `m` will be captured. It is
387-
equivalent to passing {"forward", tuple-type-args}
388-
389-
When `m` is a non-Module callable, `args` must be a Tuple containing
390-
tracing example inputs.
391-
392-
config: A CaptureConfig object that specifies how to interpret the
393-
program being captured.
394-
395-
prim_getters: A set of primitive getter functions to capture the return values of
396-
397-
dynamic_shapes: Input dynamic shapes.
398-
399-
When `m` is an nn.Module, `dynamic_shapes` is a dictionary that maps names of method
400-
to their input dynamic shapes.
401-
402-
When `m` is a non-Module callable, `dynamic_shapes` is a list of input dynamic shapes.
403-
404-
Returns:
405-
A MultiMethodExirExportedProgram.
406-
407-
if `m` is an nn.Module, returned program would have multiple
408-
captured methods, each corresponding to one entry in args dictionary.
409-
410-
if `m` is a non-Module callable, returned program would have a single
411-
captured method named `forward`.
412-
413-
Raises:
414-
AssertionError if given method name do not reference a valid method
415-
on the given nn.Module.
416-
"""
417-
warnings.warn(
418-
"This function is now deprecated, please use `torch.export and exir.to_edge` instead.",
419-
DeprecationWarning,
420-
stacklevel=1,
421-
)
422-
# Normalize m and args.
423-
compile_specs = []
424-
prim_getter_cache: Optional[Dict[str, Any]] = None
425-
if isinstance(m, torch.nn.Module):
426-
if dynamic_shapes is not None:
427-
assert isinstance(
428-
dynamic_shapes, dict
429-
), f"Expected a dict for dynamic_shapes, got {type(dynamic_shapes)}"
430-
431-
if isinstance(args, tuple):
432-
compile_specs.append(
433-
CompileSpec(
434-
"forward",
435-
m.forward,
436-
args,
437-
(
438-
dynamic_shapes["forward"]
439-
if dynamic_shapes and "forward" in dynamic_shapes
440-
else None
441-
),
442-
)
443-
)
444-
else:
445-
assert isinstance(
446-
args, dict
447-
), f"Expected a tuple or Dict[str, tuple], got {type(args)}"
448-
for method_name, method_args in args.items():
449-
compile_specs.append(
450-
CompileSpec(
451-
method_name,
452-
getattr(m, method_name),
453-
method_args,
454-
(
455-
dynamic_shapes[method_name]
456-
if dynamic_shapes and method_name in dynamic_shapes
457-
else None
458-
),
459-
)
460-
)
461-
if prim_getters is not None:
462-
prim_getter_cache = {}
463-
for getter in prim_getters:
464-
prim_getter_cache[getter] = getattr(m, getter)()
465-
else:
466-
# Reaching here means `m` is a non-Module callable.
467-
assert isinstance(
468-
m, Callable
469-
), f"Only nn.Module or callable allowed, got {type(m)}"
470-
assert isinstance(
471-
args, tuple
472-
), f"When tracing a non-Module callable, `args` must be a tuple of tracing inputs, but got {type(args)}"
473-
assert (
474-
prim_getters is None
475-
), "Caller should not specify primitive getter functions when only providing a callable as input"
476-
if dynamic_shapes is not None:
477-
assert isinstance(
478-
dynamic_shapes, list
479-
), f"Expected a list for constraints, got {type(dynamic_shapes)}"
480-
compile_specs.append(CompileSpec("forward", m, args, dynamic_shapes))
481-
482-
method_name_to_prog = {}
483-
for compile_spec in compile_specs:
484-
method_name_to_prog[compile_spec.method_name] = capture(
485-
compile_spec.callable,
486-
compile_spec.args,
487-
config,
488-
compile_spec.dynamic_shapes,
489-
)
490-
491-
return MultiMethodExirExportedProgram(method_name_to_prog, prim_getter_cache)
492-
493-
494363
# This is to bootstrap the missing meta["val"] when 1. ph consists of scalar
495364
# 2. meta["val"] is not properly set in dispatch_trace.
496365
def _instantiate_missing_placeholder_val_with_real_inputs(gm, args):

exir/program/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
ExecutorchProgram,
1414
ExecutorchProgramManager,
1515
ExirExportedProgram,
16-
multi_method_program_to_executorch,
17-
MultiMethodExecutorchProgram,
18-
MultiMethodExirExportedProgram,
1916
to_edge,
2017
)
2118

@@ -25,9 +22,6 @@
2522
"_to_edge",
2623
"to_edge",
2724
"edge_to_executorch_passes",
28-
"MultiMethodExirExportedProgram",
29-
"MultiMethodExecutorchProgram",
30-
"multi_method_program_to_executorch",
3125
"EdgeProgramManager",
3226
"ExecutorchProgramManager",
3327
]

0 commit comments

Comments
 (0)