Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from . import _ffi_api
from .buffer import Buffer
from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var
from .expr import BufferLoad, Call, CommReducer, IntImm, PrimExprWithOp, Var


def _pack_buffer(buf, span=None):
Expand Down Expand Up @@ -553,13 +553,13 @@ def tvm_struct_set(arr, index, field, value):
return call_intrin("int32", "tir.tvm_struct_set", arr, index, field, value)


def address_of(buffer_load, span=None):
def address_of(obj: Union[Buffer, BufferLoad], span: Optional[Span] = None) -> PrimExpr:
"""Returns the address of an element in the buffer
Parameters
----------
buffer_load: BufferLoad
The buffer load.
obj: Union[Buffer, BufferLoad]
The buffer or buffer load.
span : Optional[Span]
The location of this operator in the source code.
Expand All @@ -569,7 +569,15 @@ def address_of(buffer_load, span=None):
call : PrimExpr
The call expression.
"""
return call_intrin("handle", "tir.address_of", buffer_load, span=span)
if isinstance(obj, Buffer):

n_dim = len(obj.shape)
buffer_load = BufferLoad(obj, [0] * n_dim)
return call_intrin("handle", "tir.address_of", buffer_load, span=span)
elif isinstance(obj, BufferLoad):
return call_intrin("handle", "tir.address_of", obj, span=span)
else:
raise ValueError(f"Invalid object type: {type(obj)}")


def lookup_param(param_name, span=None):
Expand Down
9 changes: 9 additions & 0 deletions tests/python/tvmscript/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4266,5 +4266,14 @@ def test_return_none_no_trailing_type():
assert "-> None" not in script


def test_address_of_buffer():
@T.prim_func
def func(a: T.handle):
A = T.match_buffer(a, (128, 128), "float32")
T.evaluate(T.address_of(A))

assert "T.address_of(A[0, 0])" in func.script()


if __name__ == "__main__":
tvm.testing.main()
Loading