Skip to content

Commit f09e61b

Browse files
author
Siyuan Feng
authored
[TIR] Extend address_of to support Buffer objects (#18068)
This commit enhances the address_of function to accept both Buffer and BufferLoad objects. When a Buffer is passed, it automatically creates a BufferLoad with zero indices to get the base address.
1 parent 9dad95d commit f09e61b

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

python/tvm/tir/op.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from . import _ffi_api
2828
from .buffer import Buffer
29-
from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var
29+
from .expr import BufferLoad, Call, CommReducer, IntImm, PrimExprWithOp, Var
3030

3131

3232
def _pack_buffer(buf, span=None):
@@ -553,13 +553,13 @@ def tvm_struct_set(arr, index, field, value):
553553
return call_intrin("int32", "tir.tvm_struct_set", arr, index, field, value)
554554

555555

556-
def address_of(buffer_load, span=None):
556+
def address_of(obj: Union[Buffer, BufferLoad], span: Optional[Span] = None) -> PrimExpr:
557557
"""Returns the address of an element in the buffer
558558
559559
Parameters
560560
----------
561-
buffer_load: BufferLoad
562-
The buffer load.
561+
obj: Union[Buffer, BufferLoad]
562+
The buffer or buffer load.
563563
564564
span : Optional[Span]
565565
The location of this operator in the source code.
@@ -569,7 +569,15 @@ def address_of(buffer_load, span=None):
569569
call : PrimExpr
570570
The call expression.
571571
"""
572-
return call_intrin("handle", "tir.address_of", buffer_load, span=span)
572+
if isinstance(obj, Buffer):
573+
574+
n_dim = len(obj.shape)
575+
buffer_load = BufferLoad(obj, [0] * n_dim)
576+
return call_intrin("handle", "tir.address_of", buffer_load, span=span)
577+
elif isinstance(obj, BufferLoad):
578+
return call_intrin("handle", "tir.address_of", obj, span=span)
579+
else:
580+
raise ValueError(f"Invalid object type: {type(obj)}")
573581

574582

575583
def lookup_param(param_name, span=None):

tests/python/tvmscript/test_tvmscript_roundtrip.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4266,5 +4266,14 @@ def test_return_none_no_trailing_type():
42664266
assert "-> None" not in script
42674267

42684268

4269+
def test_address_of_buffer():
4270+
@T.prim_func
4271+
def func(a: T.handle):
4272+
A = T.match_buffer(a, (128, 128), "float32")
4273+
T.evaluate(T.address_of(A))
4274+
4275+
assert "T.address_of(A[0, 0])" in func.script()
4276+
4277+
42694278
if __name__ == "__main__":
42704279
tvm.testing.main()

0 commit comments

Comments
 (0)