
Commit d5ce386

[FFI][FEAT] AutoDLPack to enable external tensor args.
This PR introduces the autodlpack feature to the tvm ffi. When an ffi Function takes a Tensor argument that conforms to DLPack, it automatically imports the value into an NDArray and passes it as the argument. This allows a compiled function to take torch.Tensor as an input argument directly, without an extra set of conversion calls. When a function returns an NDArray, the return value still needs to be converted back via torch.from_dlpack. However, a common use case is destination passing, where all inputs and outputs are pre-allocated and passed into the function. AutoDLPack effectively enables zero-overhead support for a wide range of python arrays. We also added a benchmark script to measure the overall ffi overhead. One thing to note is that there are still contiguity and alignment requirements imposed by the underlying DSL compiler; for now these are controlled by a single global value. So x.contiguous() is still needed before passing the argument if a transpose or other view-producing op was performed.
1 parent bcb68b1 commit d5ce386
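
A minimal sketch of the destination-passing pattern this enables, using the testing.nop function exercised by the benchmark below (a compiled kernel would follow the same calling convention; buffer sizes here are illustrative):

import torch
from tvm import ffi as tvm_ffi

f = tvm_ffi.get_global_func("testing.nop")  # stands in for a compiled function
x = torch.arange(256)
y = torch.arange(256)
z = torch.empty(256, dtype=torch.int64)     # pre-allocated destination
f(x, y, z)  # tensors are auto-imported via DLPack; no explicit conversion calls
# views produced by transpose etc. still need x.contiguous() before the call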

File tree

5 files changed: +376 -0 lines changed

ffi/scripts/benchmark_dlpack.py

Lines changed: 339 additions & 0 deletions
@@ -0,0 +1,339 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This script benchmarks the calling overhead of different python FFI
APIs, exercised through the DLPack protocol.

Specifically, we would like to understand the overall overhead of
python/C++ API calls. The general goal is to map out the design space
and get a sense of what the possible operations cost.

We pick a function f(x, y, z) where x, y, z are length-1 tensors.
The benchmark runs in eager mode, so it shows what is possible there.
It is orthogonal to other optimizations; for example, cudagraph can
eliminate these overheads completely. The goal is only to get a sense
of what is possible under eager mode.

Summary of some takeaways:

- numpy.add takes roughly 0.36 us per call, which indicates what can
  be achieved in a python environment.
- torch.add on gpu takes about 3.7 us per call, giving an idea of the
  budget we need to reach in eager mode.
"""
import time

import numpy as np
import torch

from tvm import ffi as tvm_ffi


def print_speed(name, speed):
    print(f"{name:<40} {speed} sec/call")


def print_error(name, error):
    print(f"{name:<40} {error}")


def baseline_torch_add(repeat):
    """Run torch.add with one-element tensors."""

    def run_bench(device):
        x = torch.arange(1, device=device)
        y = torch.arange(1, device=device)
        z = torch.arange(1, device=device)

        torch.add(x, y, out=z)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.time()
        for i in range(repeat):
            torch.add(x, y, out=z)
        # note: we deliberately do not call torch.cuda.synchronize() here
        # because we want to see the overhead of the call itself.
        end = time.time()
        print_speed(f"torch.add[{device}]", (end - start) / repeat)

    # rough takeaway: add on cuda takes roughly 3e-6 sec/call
    run_bench("cpu")
    run_bench("cuda")


def baseline_numpy_add(repeat):
    """Run numpy.add with one-element arrays."""
    x = np.arange(1)
    y = np.arange(1)
    z = np.arange(1)

    np.add(x, y, out=z)
    start = time.time()
    for i in range(repeat):
        np.add(x, y, out=z)
    end = time.time()
    speed = (end - start) / repeat
    print_speed("numpy.add", speed)


def baseline_cupy_add(repeat):
    """Run cupy.add with one-element arrays."""
    try:
        import cupy
    except ImportError:
        # skip if cupy is not installed
        return
    x = cupy.arange(1)
    y = cupy.arange(1)
    z = cupy.arange(1)

    cupy.add(x, y, out=z)
    start = time.time()
    for i in range(repeat):
        cupy.add(x, y, out=z)
    end = time.time()
    speed = (end - start) / repeat
    print_speed("cupy.add", speed)


def tvm_ffi_nop(repeat):
    """Overhead of a tvm FFI python call via a NOP.

    testing.nop is defined in c++ and does nothing.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    x = tvm_ffi.from_dlpack(torch.arange(1))
    y = tvm_ffi.from_dlpack(torch.arange(1))
    z = tvm_ffi.from_dlpack(torch.arange(1))
    nop(x, y, z)
    start = time.time()
    for i in range(repeat):
        nop(x, y, z)
    end = time.time()
    print_speed("tvm.ffi.nop", (end - start) / repeat)


def bench_ffi_nop_from_dlpack(name, x, y, z, repeat):
    """Run dlpack conversion + tvm.ffi.nop.

    Measures the overhead of running from_dlpack on each argument and
    then invoking the function.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    tx = tvm_ffi.from_dlpack(x)
    ty = tvm_ffi.from_dlpack(y)
    tz = tvm_ffi.from_dlpack(z)
    nop(tx, ty, tz)

    start = time.time()
    for i in range(repeat):
        tx = tvm_ffi.from_dlpack(x)
        ty = tvm_ffi.from_dlpack(y)
        tz = tvm_ffi.from_dlpack(z)
        nop(tx, ty, tz)
    end = time.time()
    print_speed(name, (end - start) / repeat)


def tvm_ffi_nop_from_torch_dlpack(repeat):
    """Measure per-argument dlpack conversion + tvm.ffi.nop with torch tensors."""
    x = torch.arange(1)
    y = torch.arange(1)
    z = torch.arange(1)
    bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(torch)", x, y, z, repeat)


def tvm_ffi_nop_from_numpy_dlpack(repeat):
    """Measure per-argument dlpack conversion + tvm.ffi.nop with numpy arrays."""
    x = np.arange(1)
    y = np.arange(1)
    z = np.arange(1)
    bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(numpy)", x, y, z, repeat)


def tvm_ffi_self_dlpack_nop(repeat):
    """Measure per-argument dlpack conversion + tvm.ffi.nop with tvm NDArrays."""
    x = tvm_ffi.from_dlpack(torch.arange(1))
    y = tvm_ffi.from_dlpack(torch.arange(1))
    z = tvm_ffi.from_dlpack(torch.arange(1))
    bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(tvm)", x, y, z, repeat)


def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat):
    """
    Measures the overhead of running dlpack conversion for each argument
    and then invoking, but via the legacy torch.utils.dlpack.to_dlpack API.

    This helps to gauge the implementation overhead on the torch side.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    x = torch.arange(1)
    y = torch.arange(1)
    z = torch.arange(1)

    tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x))
    ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y))
    tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z))
    nop(tx, ty, tz)

    start = time.time()
    for i in range(repeat):
        tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y))
        tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z))
        nop(tx, ty, tz)
    end = time.time()
    speed = (end - start) / repeat
    print_speed("tvm.ffi.nop+from_dlpack(torch.utils)", speed)


def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
    """
    Measures the overhead of autodlpack: the arguments are passed
    directly and converted automatically by the FFI.
    """
    nop = tvm_ffi.get_global_func("testing.nop")
    nop(x, y, z)
    start = time.time()
    for i in range(repeat):
        nop(x, y, z)
    end = time.time()
    speed = (end - start) / repeat
    print_speed(name, speed)


def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
    """
    Measures the overhead of autodlpack when directly passing
    torch.Tensor as inputs.
    """
    # use larger arrays to ensure the alignment requirement is met
    x = torch.arange(256, device=device)
    y = torch.arange(256, device=device)
    z = torch.arange(256, device=device)
    bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)


def tvm_ffi_nop_autodlpack_from_numpy(repeat):
    """
    Measures the overhead of autodlpack when directly passing
    numpy.ndarray as inputs.
    """
    # use larger arrays to ensure the alignment requirement is met
    x = np.arange(256)
    y = np.arange(256)
    z = np.arange(256)
    bench_tvm_ffi_nop_autodlpack("tvm.ffi.nop.autodlpack(numpy)", x, y, z, repeat)


def bench_to_dlpack(x, name, repeat):
    """Measures the overhead of x.__dlpack__()."""
    x.__dlpack__()
    start = time.time()
    for i in range(repeat):
        x.__dlpack__()
    end = time.time()
    speed = (end - start) / repeat
    print_speed(name, speed)


def bench_to_dlpack_versioned(x, name, repeat, max_version=(1, 1)):
    """
    Measures the overhead of x.__dlpack__ with the latest DLPack 1.1
    versioned protocol.
    """
    try:
        x.__dlpack__(max_version=max_version)
        start = time.time()
        for i in range(repeat):
            x.__dlpack__(max_version=max_version)
        end = time.time()
        speed = (end - start) / repeat
        print_speed(name, speed)
    except Exception as e:
        print_error(name, e)


def bench_torch_utils_to_dlpack(repeat):
    """
    Measures the overhead of torch.utils.dlpack.to_dlpack.
    """
    x = torch.arange(1)
    torch.utils.dlpack.to_dlpack(x)
    start = time.time()
    for i in range(repeat):
        torch.utils.dlpack.to_dlpack(x)
    end = time.time()
    speed = (end - start) / repeat
    print_speed("torch.utils.dlpack.to_dlpack", speed)


def main():
    repeat = 1000
    print("-----------------------------")
    print("Benchmark f(x, y, z) overhead")
    print("-----------------------------")
    baseline_numpy_add(repeat)
    baseline_torch_add(repeat)
    baseline_cupy_add(repeat)
    tvm_ffi_nop(repeat)
    tvm_ffi_nop_from_torch_dlpack(repeat)
    tvm_ffi_nop_from_numpy_dlpack(repeat)
    tvm_ffi_self_dlpack_nop(repeat)
    tvm_ffi_nop_from_torch_utils_to_dlpack(repeat)
    tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu")
    tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda")
    tvm_ffi_nop_autodlpack_from_numpy(repeat)
    print("-------------------------------")
    print("Benchmark x.__dlpack__ overhead")
    print("-------------------------------")
    bench_torch_utils_to_dlpack(repeat)
    bench_to_dlpack(torch.arange(1), "torch.__dlpack__", repeat)
    bench_to_dlpack(np.arange(1), "numpy.__dlpack__", repeat)
    bench_to_dlpack(tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__", repeat)
    print("---------------------------------------------------")
    print("Benchmark x.__dlpack__(max_version=(1,1)) overhead")
    print("---------------------------------------------------")
    bench_to_dlpack_versioned(torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat)
    bench_to_dlpack_versioned(np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat)
    bench_to_dlpack_versioned(
        tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat
    )


if __name__ == "__main__":
    main()
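
To reproduce the numbers, the script can be run directly once torch, numpy, and the tvm ffi module are built and installed (the cupy baseline is skipped automatically when cupy is absent):

python ffi/scripts/benchmark_dlpack.py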

python/tvm/ffi/convert.py

Lines changed: 5 additions & 0 deletions
@@ -54,6 +54,11 @@ def convert(value: Any) -> Any:
         return core._convert_to_ffi_func(value)
     elif value is None:
         return None
+    elif hasattr(value, "__dlpack__"):
+        return core.from_dlpack(
+            value,
+            required_alignment=core.__dlpack_auto_import_required_alignment__,
+        )
     elif isinstance(value, Exception):
         return core._convert_to_ffi_error(value)
     else:
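
As an illustration, with this branch any object exposing __dlpack__ now converts directly; a minimal sketch, assuming convert is reachable as tvm_ffi.convert from the public namespace:

import numpy as np
from tvm import ffi as tvm_ffi

nd = tvm_ffi.convert(np.arange(256))  # taken through the new __dlpack__ branch
# nd is a tvm NDArray sharing memory with the numpy array, not a copy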

python/tvm/ffi/cython/function.pxi

Lines changed: 5 additions & 0 deletions
@@ -71,6 +71,11 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except
         elif isinstance(arg, Object):
             out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
             out[i].v_ptr = (<Object>arg).chandle
+        elif hasattr(arg, "__dlpack__"):
+            arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__)
+            out[i].type_index = kTVMFFINDArray
+            out[i].v_ptr = (<NDArray>arg).chandle
+            temp_args.append(arg)
         elif isinstance(arg, PyNativeObject):
             arg = arg.__tvm_ffi_object__
             out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
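
Because the converted NDArray is appended to temp_args, it stays alive for the duration of the C call. One consequence, sketched below under the assumption that the alignment requirement is met, is that DLPack-capable arrays from different libraries can be mixed in a single call:

import numpy as np
import torch
from tvm import ffi as tvm_ffi

nop = tvm_ffi.get_global_func("testing.nop")
# numpy, torch, and tvm arrays all go through the same auto-import path
nop(np.arange(256), torch.arange(256), tvm_ffi.from_dlpack(torch.arange(256)))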

python/tvm/ffi/cython/ndarray.pxi

Lines changed: 2 additions & 0 deletions
@@ -16,8 +16,10 @@
 # under the License.
 
 __dlpack_version__ = (1, 1)
+__dlpack_auto_import_required_alignment__ = 8
 _CLASS_NDARRAY = None
 
+
 def _set_class_ndarray(cls):
     global _CLASS_NDARRAY
     _CLASS_NDARRAY = cls
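
The new constant is the default alignment in bytes that auto-imported tensors must satisfy; per the commit message it is a single global value for now. A quick illustrative check of whether a given array qualifies (not part of the API):

import numpy as np

x = np.arange(256)
# 8 mirrors __dlpack_auto_import_required_alignment__
print(x.ctypes.data % 8 == 0)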
