diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 57aa060cd0c5..97eb3d5b2fe1 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -26,7 +26,7 @@ from . import _ffi_api from .buffer import Buffer -from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var +from .expr import BufferLoad, Call, CommReducer, IntImm, PrimExprWithOp, Var def _pack_buffer(buf, span=None): @@ -553,13 +553,13 @@ def tvm_struct_set(arr, index, field, value): return call_intrin("int32", "tir.tvm_struct_set", arr, index, field, value) -def address_of(buffer_load, span=None): +def address_of(obj: Union[Buffer, BufferLoad], span: Optional[Span] = None) -> PrimExpr: """Returns the address of an element in the buffer Parameters ---------- - buffer_load: BufferLoad - The buffer load. + obj: Union[Buffer, BufferLoad] + The buffer or buffer load. span : Optional[Span] The location of this operator in the source code. @@ -569,7 +569,15 @@ def address_of(buffer_load, span=None): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.address_of", buffer_load, span=span) + if isinstance(obj, Buffer): + + n_dim = len(obj.shape) + buffer_load = BufferLoad(obj, [0] * n_dim) + return call_intrin("handle", "tir.address_of", buffer_load, span=span) + elif isinstance(obj, BufferLoad): + return call_intrin("handle", "tir.address_of", obj, span=span) + else: + raise ValueError(f"Invalid object type: {type(obj)}") def lookup_param(param_name, span=None): diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index cebfb842bad9..af2db34415f8 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -4266,5 +4266,14 @@ def test_return_none_no_trailing_type(): assert "-> None" not in script +def test_address_of_buffer(): + @T.prim_func + def func(a: T.handle): + A = T.match_buffer(a, (128, 128), "float32") + T.evaluate(T.address_of(A)) + + assert "T.address_of(A[0, 0])" in func.script() + + if __name__ == "__main__": tvm.testing.main()