Skip to content

Commit d9560ea

Browse files
committed
Merge branch 'main' of https://github.com/tile-ai/tilelang into assume_1121
2 parents cfc429d + bf90a5f commit d9560ea

File tree

7 files changed

+203
-11
lines changed

7 files changed

+203
-11
lines changed

src/tl_templates/cuda/atomic.h

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
169169
}
170170
}
171171

172+
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 890))
172173
template <typename T1, typename T2>
173174
TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
174175
int memory_order = int(cuda::memory_order_relaxed)) {
@@ -236,14 +237,18 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
236237
}
237238
}
238239
} else {
239-
#if CUDART_VERSION >= 11080
240-
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
241-
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
242-
#else
243-
TL_NOT_IMPLEMENTED();
244-
#endif
240+
atomicAdd(reinterpret_cast<NT1 *>(address), cuda_cast<NT1>(val));
245241
}
246242
}
243+
#else
244+
// Fallback AtomicAdd, compiled in the #else branch of the surrounding
// architecture guard.
//
// Adds `val` (converted to the normalized atomic type of T1) to `ref` using
// the legacy atomicAdd intrinsic and discards the previous value.
//
// The intrinsic itself carries no ordering guarantee, so a non-relaxed
// `memory_order` is honoured conservatively: a device-scope fence is issued
// both before and after the atomic. Relaxed requests (the default) emit the
// bare atomic only.
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
                         int memory_order = int(cuda::memory_order_relaxed)) {
  using NT1 = typename normalize_atomic_type<T1>::type;
  const bool ordered = memory_order != int(cuda::memory_order_relaxed);
  if (ordered) {
    // Make earlier writes visible before the atomic is performed.
    __threadfence();
  }
  atomicAdd(reinterpret_cast<NT1 *>(&ref), cuda_cast<NT1>(val));
  if (ordered) {
    // Keep later accesses from being observed ahead of the atomic.
    __threadfence();
  }
}
251+
#endif
247252

248253
template <typename T1, typename T2>
249254
TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
@@ -643,6 +648,48 @@ AtomicAddx4Ret(float *ref, float *val,
643648
return ret_val;
644649
}
645650
}
651+
#else
652+
// Fallback two-lane float atomic add.
//
// Adds val[0] to ref[0] and val[1] to ref[1] with two independent scalar
// atomicAdd calls; the pair is NOT updated as a single atomic unit.
//
// The operands are read element-wise, so `val` only needs natural float
// alignment (no 8-byte float2 alignment requirement).
//
// A non-relaxed `memory_order` is honoured conservatively with a device-scope
// fence before the first and after the last atomic; relaxed (the default)
// emits the bare atomics only.
TL_DEVICE void AtomicAddx2(float *ref, float *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  const bool ordered = memory_order != int(cuda::memory_order_relaxed);
  const float add_x = val[0];
  const float add_y = val[1];
  if (ordered) {
    __threadfence();
  }
  atomicAdd(ref + 0, add_x);
  atomicAdd(ref + 1, add_y);
  if (ordered) {
    __threadfence();
  }
}
659+
660+
// Fallback two-lane float atomic add that returns the previous values.
//
// Adds val[0] to ref[0] and val[1] to ref[1] with two independent scalar
// atomicAdd calls and returns what each atomicAdd reported as the old value
// (ret.x for lane 0, ret.y for lane 1). The two lanes are NOT read or updated
// as a single atomic unit.
//
// The operands are read element-wise, so `val` only needs natural float
// alignment (no 8-byte float2 alignment requirement).
//
// A non-relaxed `memory_order` is honoured conservatively with a device-scope
// fence before the first and after the last atomic; relaxed (the default)
// emits the bare atomics only.
TL_DEVICE float2
AtomicAddx2Ret(float *ref, float *val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  const bool ordered = memory_order != int(cuda::memory_order_relaxed);
  const float add_x = val[0];
  const float add_y = val[1];
  if (ordered) {
    __threadfence();
  }
  float2 ret;
  ret.x = atomicAdd(ref + 0, add_x);
  ret.y = atomicAdd(ref + 1, add_y);
  if (ordered) {
    __threadfence();
  }
  return ret;
}
670+
671+
// Fallback four-lane float atomic add.
//
// Adds val[i] to ref[i] for i in 0..3 with four independent scalar atomicAdd
// calls; the four lanes are NOT updated as a single atomic unit.
//
// The operands are read element-wise, so `val` only needs natural float
// alignment (no 16-byte float4 alignment requirement).
//
// A non-relaxed `memory_order` is honoured conservatively with a device-scope
// fence before the first and after the last atomic; relaxed (the default)
// emits the bare atomics only.
TL_DEVICE void AtomicAddx4(float *ref, float *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
  const bool ordered = memory_order != int(cuda::memory_order_relaxed);
  const float add_x = val[0];
  const float add_y = val[1];
  const float add_z = val[2];
  const float add_w = val[3];
  if (ordered) {
    __threadfence();
  }
  atomicAdd(ref + 0, add_x);
  atomicAdd(ref + 1, add_y);
  atomicAdd(ref + 2, add_z);
  atomicAdd(ref + 3, add_w);
  if (ordered) {
    __threadfence();
  }
}
680+
681+
// Fallback four-lane float atomic add that returns the previous values.
//
// Adds val[i] to ref[i] for i in 0..3 with four independent scalar atomicAdd
// calls and returns what each atomicAdd reported as the old value in the
// matching component (x, y, z, w). The four lanes are NOT read or updated as
// a single atomic unit.
//
// The operands are read element-wise, so `val` only needs natural float
// alignment (no 16-byte float4 alignment requirement).
//
// A non-relaxed `memory_order` is honoured conservatively with a device-scope
// fence before the first and after the last atomic; relaxed (the default)
// emits the bare atomics only.
TL_DEVICE float4
AtomicAddx4Ret(float *ref, float *val,
               int memory_order = int(cuda::memory_order_relaxed)) {
  const bool ordered = memory_order != int(cuda::memory_order_relaxed);
  const float add_x = val[0];
  const float add_y = val[1];
  const float add_z = val[2];
  const float add_w = val[3];
  if (ordered) {
    __threadfence();
  }
  float4 ret;
  ret.x = atomicAdd(ref + 0, add_x);
  ret.y = atomicAdd(ref + 1, add_y);
  ret.z = atomicAdd(ref + 2, add_z);
  ret.w = atomicAdd(ref + 3, add_w);
  if (ordered) {
    __threadfence();
  }
  return ret;
}
646693
#endif
647694

648695
template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import tilelang
2+
import tilelang.testing
3+
import tilelang.language as T
4+
5+
6+
def test_tilelang_intimm():
    """Smoke-test integer immediates at the limits of the 32/64-bit dtypes.

    Nothing is asserted: the test passes as long as building each constant
    and applying each mask does not raise.
    """
    # Boundary literals handed straight to the dtype constructors.
    boundary_literals = (
        (T.int32, 0x7fffffff),
        (T.int32, -0x7fffffff - 1),
        (T.uint32, 0xffffffff),
        (T.int64, 0x7fffffffffffffff),
        (T.int64, -0x7fffffffffffffff - 1),
        (T.uint64, 0xffffffffffffffff),
    )
    for make_imm, literal in boundary_literals:
        make_imm(literal)

    # Values created by calling each dtype without arguments, AND-ed with the
    # widest mask of that dtype.
    i32_value = T.int32()
    i32_value & 0x7fffffff

    u32_value = T.uint32()
    u32_value & 0xffffffff

    i64_value = T.int64()
    i64_value & 0x7fffffffffffffff

    # The uint64 mask is wrapped explicitly in T.uint64 rather than passed as
    # a bare Python int.
    u64_value = T.uint64()
    u64_value & T.uint64(0xffffffffffffffff)
25+
26+
27+
if __name__ == '__main__':
28+
tilelang.testing.main()

testing/python/language/test_tilelang_language_frontend_v2.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,64 @@ def prim_call_macro():
394394
except ValueError:
395395
pass
396396

397+
try:
398+
399+
@T.macro
400+
def macro_with_var(x: T.Ref):
401+
x = 1 # noqa: F841
402+
403+
@T.prim_func
404+
def prim_call_macro():
405+
with T.Kernel(1):
406+
x = T.alloc_var(T.int32)
407+
macro_with_var(x)
408+
409+
assert 'x[0] = 1' in prim_call_macro.script()
410+
finally:
411+
pass
412+
413+
try:
414+
415+
@T.macro
416+
def macro_with_var(x: T.Ref):
417+
x = 1 # noqa: F841
418+
419+
@T.prim_func
420+
def prim_call_macro():
421+
with T.Kernel(1):
422+
x = 1
423+
macro_with_var(x)
424+
425+
raise RuntimeError("Expect to report an error, x should not be passed as T.Var")
426+
except ValueError:
427+
pass
428+
429+
430+
def test_frame_inside_macro():
    """Build a JIT kernel whose parallel loop assigns a macro's return value.

    The macro `transform` returns an expression (`x + 1`) that is consumed by
    the caller after the macro has exited. The test passes when
    `get_sample_kernel()` completes without raising.

    The function carries the `test_` prefix so that pytest collects it; under
    its previous name it was never executed by the test runner.
    """

    @tilelang.jit
    def get_sample_kernel():

        @T.macro
        def transform(x):
            return x + 1

        @T.prim_func
        def sample_kernel(
                num_blocks: T.int32,
                idx_out: T.Tensor[(32,), T.int32],
        ):
            with T.Kernel(num_blocks, threads=32) as block_idx:  # noqa: F841
                fragment = T.alloc_fragment(32, 'int32')
                T.copy(idx_out, fragment)

                for i in T.Parallel(32):
                    idx_out[i] = transform(fragment[i])

        return sample_kernel

    kernel = get_sample_kernel()  # noqa: F841


# Backward-compatible alias for the pre-rename name (not collected by pytest).
frame_inside_macro = test_frame_inside_macro
454+
397455

398456
if __name__ == '__main__':
399457
tilelang.testing.main()

tilelang/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
FragmentBuffer, # noqa: F401
2323
SharedBuffer, # noqa: F401
2424
LocalBuffer, # noqa: F401
25+
Ref, # noqa: F401
2526
)
2627
from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401
2728
from .frame import has_let_value, get_let_value # noqa: F401

tilelang/language/proxy.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""The language interface for tl programs."""
22
from __future__ import annotations
33

4-
from typing import Any, SupportsIndex, TYPE_CHECKING
4+
from typing import Any, SupportsIndex, TYPE_CHECKING, Generic, TypeVar
55
from collections.abc import Sequence
66
from typing_extensions import Self
77

@@ -263,13 +263,21 @@ class SharedBuffer(BaseTensor):
263263

264264
class LocalBuffer(BaseTensor):
265265
...
266+
267+
_T = TypeVar('_T')
268+
269+
class Ref(Generic[_T], tir.Var):
    """Static-typing view of `Ref`: a `tir.Var` that is generic over `_T`.

    This declaration is only seen by type checkers; it adds no members of
    its own.
    """
266271
else:
267272
Tensor = TensorProxy() # pylint: disable=invalid-name
268273
StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name
269274
FragmentBuffer = FragmentBufferProxy() # pylint: disable=invalid-name
270275
SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name
271276
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
272277

278+
class Ref:
    """Runtime marker class used as an annotation (e.g. `x: T.Ref`).

    The class carries no state. It is compared by identity wherever an
    annotation is inspected, so subscripting it (`Ref[SomeType]`) returns the
    class itself rather than a new alias object.
    """

    def __class_getitem__(cls, item):
        # Type checkers see `Ref` as a generic class, so `Ref[...]` must also
        # be legal when the annotation is evaluated at runtime. Returning the
        # class keeps identity checks such as `annotation is Ref` working.
        return cls
280+
273281

274282
def ptr(dtype: str | None = None,
275283
storage_scope: str = "global",

tilelang/language/v2/builder.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class MacroFrame(Frame):
8080
...
8181

8282

83+
class ExitedMacroFrame(Frame):
    """Marker frame for a macro whose body has already finished.

    It adds nothing to `Frame`; it exists only so that a finished macro's
    slot in a frame stack can be told apart from a still-active `MacroFrame`.
    """
85+
86+
8387
class BoolOpFrame(Frame):
8488
...
8589

@@ -164,8 +168,22 @@ def macro(self, name=None, annotations=None):
164168
save = self.name_inside_frame, self.arg_annotations
165169
self.name_inside_frame = {}
166170
self.arg_annotations = annotations or {}
167-
with self.with_frame(MacroFrame()):
168-
yield
171+
pos = len(self.frames)
172+
# here we add an ExitedMacroFrame to preserve the frame stack inside the macro
173+
# because macro may bind some variable, and return it
174+
#
175+
# ```py
176+
# @T.macro
177+
# def foo(x):
178+
# y = x + 1
179+
# return y
180+
# @T.prim_func
181+
# def bar():
182+
# c = foo(1) # macro generates let y = x + 1
183+
# d = c # d = c should lie inside the frame of `let y = x + 1`
184+
self.frames.append(MacroFrame())
185+
yield
186+
self.frames[pos] = ExitedMacroFrame()
169187
self.name_inside_frame, self.arg_annotations = save
170188

171189
def get(self):
@@ -335,7 +353,7 @@ def bind(self, name, value, annot=BaseBuilder.empty):
335353
assert frame is not None, f"Variable `{name}` is not defined inside any control flow."
336354
if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames:
337355
logger.warning(
338-
f'Variable `{name}` shadows another declared value, Are you forgetting to allocate it as a var?',
356+
f'Variable `{name}` is declared twice, are you looking for a T.alloc_var?',
339357
stack_info=True,
340358
stacklevel=2,
341359
)
@@ -475,7 +493,11 @@ def rval(self, name: str, value: Any) -> Any:
475493
return self.unwrap_value(value)
476494

477495
def macro_arg(self, name, value):
478-
if self.arg_annotations.get(name, None) is Var:
496+
from tilelang.language.proxy import Ref
497+
annot_value = self.arg_annotations.get(name, None)
498+
if annot_value is Var or annot_value is Ref:
499+
if annot_value is Var:
500+
logger.warning('Use `T.Var` as macro annotations is deprecated, please use `T.Ref`')
479501
is_var = isinstance(value, tvm.tir.BufferLoad) and value.buffer.scope() == 'local.var'
480502
if not is_var:
481503
raise ValueError(

tilelang/language/v2/dtypes.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,12 @@
8787
'float8_e8m0fnu': 'Float8E8M0FNU'
8888
}
8989

90+
int_ = int
91+
9092

9193
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var:
94+
if isinstance(expr, int_):
95+
return tvm.tir.const(expr, dtype=self)
9296
if self in _STR_TO_TVM_DTYPE_CALL:
9397
attr = _STR_TO_TVM_DTYPE_CALL[self]
9498
call = getattr(tb_ffi, attr, None)
@@ -151,6 +155,10 @@ class int8(dtype): ...
151155
class int16(dtype): ...
152156
class int32(dtype): ...
153157
class int64(dtype): ...
158+
class int8x2(dtype): ...
159+
class int16x2(dtype): ...
160+
class int32x2(dtype): ...
161+
class int64x2(dtype): ...
154162
class int8x4(dtype): ...
155163
class int16x4(dtype): ...
156164
class int32x4(dtype): ...
@@ -175,6 +183,10 @@ class uint8(dtype): ...
175183
class uint16(dtype): ...
176184
class uint32(dtype): ...
177185
class uint64(dtype): ...
186+
class uint8x2(dtype): ...
187+
class uint16x2(dtype): ...
188+
class uint32x2(dtype): ...
189+
class uint64x2(dtype): ...
178190
class uint8x4(dtype): ...
179191
class uint16x4(dtype): ...
180192
class uint32x4(dtype): ...
@@ -308,6 +320,10 @@ class bfloat16(dtype): ...
308320
int16 = dtype('int16')
309321
int32 = dtype('int32')
310322
int64 = dtype('int64')
323+
int8x2 = dtype('int8x2')
324+
int16x2 = dtype('int16x2')
325+
int32x2 = dtype('int32x2')
326+
int64x2 = dtype('int64x2')
311327
int8x4 = dtype('int8x4')
312328
int16x4 = dtype('int16x4')
313329
int32x4 = dtype('int32x4')
@@ -332,6 +348,10 @@ class bfloat16(dtype): ...
332348
uint16 = dtype('uint16')
333349
uint32 = dtype('uint32')
334350
uint64 = dtype('uint64')
351+
uint8x2 = dtype('uint8x2')
352+
uint16x2 = dtype('uint16x2')
353+
uint32x2 = dtype('uint32x2')
354+
uint64x2 = dtype('uint64x2')
335355
uint8x4 = dtype('uint8x4')
336356
uint16x4 = dtype('uint16x4')
337357
uint32x4 = dtype('uint32x4')
@@ -464,6 +484,10 @@ class bfloat16(dtype): ...
464484
'int16',
465485
'int32',
466486
'int64',
487+
'int8x2',
488+
'int16x2',
489+
'int32x2',
490+
'int64x2',
467491
'int8x4',
468492
'int16x4',
469493
'int32x4',
@@ -488,6 +512,10 @@ class bfloat16(dtype): ...
488512
'uint16',
489513
'uint32',
490514
'uint64',
515+
'uint8x2',
516+
'uint16x2',
517+
'uint32x2',
518+
'uint64x2',
491519
'uint8x4',
492520
'uint16x4',
493521
'uint32x4',

0 commit comments

Comments
 (0)