-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfunctional.py
138 lines (105 loc) · 3.39 KB
/
functional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from __future__ import annotations
from exo import *
from exo.API_scheduling import *
"""
Cursor Example
This example introduces the concept of Cursors from the Exo 2 paper and demonstrates
how to use scheduling operators with them to manipulate loops and optimize code.
Cursors allow you to select and refer to parts of the code such as expressions,
statements, and code blocks. They also support spatial navigation within a procedure
to nearby locations.
Key concepts covered:
- Finding Cursors with pattern-matching
- Cursor navigation
- Applying scheduling primitives using cursors
- Cursor forwarding after code transformations
- Defining a new scheduling operation
"""
"""
1: Basic loop example using Exo 2
GEMV kernel: y += A * x (the product is accumulated into y)
Args:
M (size): Number of rows in matrix A
N (size): Number of columns in matrix A
A (tensor): M x N matrix stored in DRAM
x (tensor): N-dimensional vector stored in DRAM
y (tensor): M-dimensional vector stored in DRAM
"""
# GEMV-style kernel: for every row i, accumulates A[i, j] * x[j] into y[i].
# Both sizes are asserted to be multiples of 8; the schedules later in this
# file divide the i and j loops by 8 with perfect=True.
@proc
def gemv(M: size, N: size, A: f32[M, N], x: f32[N], y: f32[M]):
    assert M % 8 == 0
    assert N % 8 == 0
    for i in seq(0, M):
        for j in seq(0, N):
            # Accumulates into y[i]; y is never reset inside the kernel.
            y[i] += A[i, j] * x[j]
# Show the unscheduled kernel, then a blank separator line.
print("1: Original GEMV kernel", gemv, sep="\n")
print()
"""
2: Finding cursors
"""
# Look the `i` loop up twice: once by its loop name ...
i_loop = gemv.find_loop("i")
# ... and once by a pattern that describes the loop statement.
i_loop2 = gemv.find("for i in _: _")
# Both lookups are expected to yield equal cursors.
assert i_loop == i_loop2
print("2: i_loop points to:", i_loop, sep="\n")
print()
"""
3: Navigating with cursors
"""
# Walk inwards from the i loop: its body holds only the j loop ...
j_loop = i_loop.body()[0]
# ... and the j loop's body holds only the `y[i] += ...` statement.
C_store = j_loop.body()[0]
# Walk back outwards: the parent of the j loop
j_loop_parent = j_loop.parent()
# Going down and then up again must land on the i loop
assert i_loop == j_loop_parent
print("3: j_loop points to:", j_loop, sep="\n")
print()
"""
4: Applying scheduling primitives & Cursor forwarding
"""
# Work on a renamed copy so the schedule is reported as "gemv_scheduled"
g = rename(gemv, "gemv_scheduled")
# Split i into io/ii with a factor of 8
g = divide_loop(g, i_loop, 8, ["io", "ii"], perfect=True)
# Split j into jo/ji with a factor of 8
g = divide_loop(g, j_loop, 8, ["jo", "ji"], perfect=True)
# To reorder the ii and jo loops, lift the scope of j_loop.
# The j_loop cursor was found on the original gemv, yet it can still be
# passed to a scheduling operator on g ...
g1 = lift_scope(g, j_loop)
# ... which gives the same result as forwarding the cursor to g explicitly.
g2 = lift_scope(g, g.forward(j_loop))
# Implicit and explicit forwarding must agree
assert g1 == g2
print("4: Tiled gemv", g1, sep="\n")
print("4: g.forward(j_loop) points to:", g.forward(j_loop), sep="\n")
print()
"""
5: Defining a new scheduling operator
"""
def tile_2D(p, i_lp, j_lp, i_itrs, j_itrs, i_sz, j_sz):
    """
    Perform a 2D tiling of the i and j loops.

    Both loops are split with `divide_loop` (called with `perfect=True`),
    then `lift_scope` is applied to the loop named `j_itrs[0]`.

    Args:
        p: Procedure to be tiled
        i_lp: The i loop to divide; handed unchanged to `divide_loop`
            (the usage example in this file passes a cursor)
        j_lp: The j loop to divide; handed unchanged to `divide_loop`
            (the usage example in this file passes a cursor)
        i_itrs: New iterator names for the i loop, e.g. ["io", "ii"]
        j_itrs: New iterator names for the j loop, e.g. ["jo", "ji"];
            the loop named by `j_itrs[0]` is the one that gets lifted
        i_sz: Tile size for the i loop
        j_sz: Tile size for the j loop

    Returns:
        The procedure obtained after the two divisions and the scope lift.
    """
    p = divide_loop(p, i_lp, i_sz, i_itrs, perfect=True)
    p = divide_loop(p, j_lp, j_sz, j_itrs, perfect=True)
    # The freshly created j loop is referred to by name rather than by cursor.
    p = lift_scope(p, j_itrs[0])
    return p
# Apply tile_2D to the original gemv kernel with 8x8 tiles, reusing the
# i and j loop cursors obtained above.
final_g = tile_2D(
    gemv,
    i_lp=i_loop,
    j_lp=j_loop,
    i_itrs=["io", "ii"],
    j_itrs=["jo", "ji"],
    i_sz=8,
    j_sz=8,
)
print("5: tile_2D applied gemv:", final_g, sep="\n")
__all__ = ["final_g"]