Commit db18f3e
[TIR][Target] Support device call compilation
This PR introduces support for device call compilation in TVM by enhancing the BindTarget pass to properly handle functions called from both host and device contexts. The key improvement is the ability to automatically create host-specific duplicates of functions that are called from both host and device code, ensuring proper target binding for heterogeneous compilation.
- **Function Classification**: Analyzes call patterns to identify functions called from host vs device contexts
- **Smart Target Binding**: Automatically binds the appropriate target based on calling context (a Python sketch of this decision logic follows the list):
  - Functions called only from host → host target
  - Functions called only from device → device target
  - Functions called from both → device target + host duplicate
- **Call Site Updates**: Updates call sites in externally exposed functions to use appropriate duplicates
- Improved device function extraction and kernel generation
- Better handling of error propagation for different device types
- Enhanced buffer declaration and parameter management
- Support for `__device__` function calls in CUDA kernels
- Proper function signature generation for device functions
- Enhanced calling convention handling
- Updated build pipeline to handle device call compilation
- Improved target-specific compilation logic
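As a rough mental model of the classification and binding rules above, the decision logic can be sketched in Python as follows. This is purely illustrative (the actual pass is a TIR transform implemented in C++), and the `_host` suffix reflects the duplicate naming described in this PR:

```python
from typing import Dict


def bind_targets(
    func_name: str,
    called_from_host: bool,
    called_from_device: bool,
    device_target: str,
    host_target: str,
) -> Dict[str, str]:
    """Sketch: map each function (and any host duplicate) to its bound target."""
    if called_from_host and called_from_device:
        # Called from both contexts: keep a device version and add a host duplicate.
        return {func_name: device_target, f"{func_name}_host": host_target}
    if called_from_device:
        # Device-only callees are bound to the device target.
        return {func_name: device_target}
    # Host-only callees are bound to the host target.
    return {func_name: host_target}
```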
The following example demonstrates how the BindTarget pass handles functions called from both host and device contexts:
```python
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(a: T.int32, b: T.int32) -> T.int32:
        return a + b

    @T.prim_func
    def main(
        A: T.Buffer((128, 128), "int32"),
        B: T.Buffer((128, 128), "int32"),
        C: T.Buffer((128, 128), "int32"),
    ):
        T.func_attr({"global_symbol": "main"})
        length: T.int32 = Module.add(64, 64)  # Call from host
        for bx in T.thread_binding(length, "blockIdx.x"):
            for tx in T.thread_binding(length, "threadIdx.x"):
                C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
```
After applying `BindTarget(cuda, host="llvm")`, the pass automatically:
1. Creates a device version of `add` with CUDA target
2. Creates a host duplicate `add_host` with LLVM target
3. Updates the main function to call `add_host` from host context and `add` from device context
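The transformed module then looks roughly like the sketch below. This is an illustration of the intended result rather than the pass's literal printed output; the `target` function attribute and the `add_host` name follow the description above:

```python
@I.ir_module
class AfterBindTarget:
    @T.prim_func(private=True)
    def add(a: T.int32, b: T.int32) -> T.int32:
        # Reachable from device code, so it is bound to the CUDA target.
        T.func_attr({"target": T.target("cuda")})
        return a + b

    @T.prim_func(private=True)
    def add_host(a: T.int32, b: T.int32) -> T.int32:
        # Host-side duplicate bound to the LLVM host target.
        T.func_attr({"target": T.target("llvm")})
        return a + b

    # `main` keeps its body: the host-context call now resolves to `add_host`,
    # while the call inside the thread bindings still resolves to `add`.
```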
This enables seamless compilation of mixed host/device code while maintaining proper target-specific optimizations and code generation.
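A minimal end-to-end sketch, assuming the `Module` from the example above and the standard `tvm.target.Target`, `tvm.tir.transform.BindTarget`, and `tvm.build` APIs; the `BindTarget(cuda, host="llvm")` shorthand above is expressed here by attaching the host to the device target, and the exact entry point in this PR may differ:

```python
import tvm

# Device target with an attached host, mirroring BindTarget(cuda, host="llvm").
cuda = tvm.target.Target("cuda", host="llvm")

# Bind targets for every function in the module from the example above.
bound = tvm.tir.transform.BindTarget(cuda)(Module)

# The bound module can then flow through the usual build pipeline; the generated
# CUDA source is expected to contain a __device__ helper for `add`.
lib = tvm.build(bound, target=cuda)
print(lib.imported_modules[0].get_source())
```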
- **Automatic Target Binding**: No manual target annotation required for most use cases
- **Heterogeneous Compilation**: Proper support for functions called from multiple contexts
- **Code Reuse**: Shared functions can be called from both host and device without manual duplication
- **Performance**: Maintains target-specific optimizations for each context
- **Developer Experience**: Simplifies writing mixed host/device code
The implementation is backward compatible and integrates seamlessly with existing TVM compilation pipelines.
File tree

10 files changed: +594 / -70 lines

- python/tvm/tir
- src
  - target
    - opt
    - source
  - tir/transforms
- tests/python
  - codegen
  - tir-transform