Skip to content

Commit 8274d14

Browse files
authored
[Relax] Implement operators to inspect DLTensor::strides and offset (#16721)
* [TIR] LowerTVMBuiltin may use device_type from PrimFunc annotation If an allocation occurs within a host function, it may not have a device/host split. * lint fix * [Relax] Implement operators to inspect DLTensor::strides and offset A follow-up PR to #16563. This PR implements similar operators to inspect the runtime values of `DLTensor::strides` and `DLTensor::byte_offset`. In addition, while the element offset is not explicitly present in the `DLTensor` struct, a Relax operator is implemented to infer it from the `byte_offset` and `data_type` fields, for use when interacting with the TIR `BufferNode::elem_offset` field.
1 parent 016b512 commit 8274d14

File tree

9 files changed

+727
-170
lines changed

9 files changed

+727
-170
lines changed

python/tvm/relax/expr.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,33 @@ def shape(self) -> "_DLTensorShapeProxy":
280280
self._check_for_tensor_struct_info()
281281
return _DLTensorShapeProxy(self)
282282

283+
@property
284+
def strides(self) -> "_DLTensorStrideProxy":
285+
"""Returns a proxy object for accessing DLTensor::strides"""
286+
self._check_for_tensor_struct_info()
287+
return _DLTensorStrideProxy(self)
288+
289+
@property
290+
def byte_offset(self) -> "Expr":
291+
"""Returns a proxy object for accessing DLTensor::byte_offset"""
292+
self._check_for_tensor_struct_info()
293+
op = tvm.ir.Op.get("relax.inspect.tensor_byte_offset")
294+
return tvm.relax.Call(op, [self])
295+
296+
@property
297+
def elem_offset(self) -> "Expr":
298+
"""Returns a proxy object for accessing a DLTensor's elem_offset
299+
300+
This parameter is not stored in the DLTensor, but is instead
301+
derived from the DLTensor's byte offset and datatype. This is
302+
exposed in Relax for ease of use, and for translation into the
303+
`tir::BufferNode::elem_offset` field when interacting with TIR
304+
buffers.
305+
"""
306+
self._check_for_tensor_struct_info()
307+
op = tvm.ir.Op.get("relax.inspect.tensor_elem_offset")
308+
return tvm.relax.Call(op, [self])
309+
283310

284311
class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric):
285312
"""A proxy object for unpacking DLDatatype from DLTensor
@@ -431,6 +458,76 @@ def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
431458
return tvm.relax.Call(op, [self.tensor, axis])
432459

433460

461+
class _DLTensorStrideProxy(tvm.runtime.ObjectGeneric):
    """A proxy object for unpacking the strides from DLTensor

    Exposes accessors for the `DLTensor::strides` field.  Accessing
    these fields will produce `relax.Call` expressions, representing
    the field's runtime value.  If the datatype of the tensor is known
    at compile-time, the `relax.Call` will be normalized into a
    `relax.PrimValue`, with no runtime cost.

    Parameters
    ----------
    tensor: relax.Expr

        The relax tensor (or a variable referring to a relax tensor),
        whose runtime strides is being inspected.
    """

    def __init__(self, tensor):
        self.tensor = tensor

    def asobject(self):
        """Raise an informative error when used where an Expr is required

        This method is called when `_DLTensorStrideProxy` is used in a
        context that requires a `relax.Expr`.  This usage is not
        supported, and raising an error here can provide suggested
        fixes that are not present in the default error message from
        `tvm.runtime.convert_to_object`.
        """
        raise TypeError(
            f"{self.tensor}.strides cannot be converted to a relax expression, "
            f"and should be used as a proxy object to access the runtime strides of the DLTensor. "
            f"The DLTensor::ndim field can be accessed as len({self.tensor}), "
            f"and the DLTensor::strides array can be accessed as {self.tensor}.strides[i]"
        )

    def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
        """Returns the stride of a tensor axis

        (The original docstring said "extent", copy-pasted from the
        shape proxy; this accessor returns the *stride* of the axis.)

        Parameters
        ----------
        axis: Union[int, PrimExpr, Expr]

            The tensor axis whose stride should be returned.  For ease
            of use, any python integers or TIR expressions are
            converted to `relax.Expr`.

        Returns
        -------
        stride: Expr

            The stride of the tensor's axis.
        """

        # Normalize python ints / TIR PrimExprs into a relax PrimValue.
        if not isinstance(axis, tvm.relax.Expr):
            axis = tvm.relax.PrimValue(axis)

        # If struct info is already available, validate it eagerly so
        # the user gets a targeted error instead of a later normalization failure.
        if axis.struct_info_ is not None and not isinstance(
            axis.struct_info_, tvm.relax.PrimStructInfo
        ):
            raise TypeError(
                f"The index used to access {self.tensor}.strides "
                f'must have struct info R.Prim("int64"), '
                f"but index {axis} had struct info {axis.struct_info_}."
            )

        op = tvm.ir.Op.get("relax.inspect.tensor_stride_i")
        return tvm.relax.Call(op, [self.tensor, axis])
529+
530+
434531
@tvm._ffi.register_object("relax.expr.Call")
435532
class Call(ExprWithOp):
436533
"""Function call node in Relax.

python/tvm/relax/transform/legalize_ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from . import grad
2424
from . import image
2525
from . import index
26+
from . import inspect_op
2627
from . import linear_algebra
2728
from . import manipulate
2829
from . import nn
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name
18+
"""Legalization functions for DLTensor inspection."""
19+
20+
import enum
21+
22+
from tvm.script import tir as T
23+
24+
from ...block_builder import BlockBuilder
25+
from ...expr import Call, Expr
26+
from .common import register_legalize
27+
28+
29+
class TVMStructFieldKind(enum.IntEnum):
    """Equivalent to tvm::tir::builtin::TVMStructFieldKind

    This does not use `enum.auto()` to define the values, because
    `enum.auto()` starts from 1, and this must match the C++
    definition which starts from 0.

    The values are passed as the third argument of `T.tvm_struct_get`
    to select which DLTensor/TVMValue field to read, so they must stay
    numerically in sync with the C++ enum.
    """

    kArrAddr = 0
    # DLTensor fields
    kArrData = 1
    kArrShape = 2
    kArrStrides = 3
    kArrNDim = 4
    # DLDataType sub-fields (code/bits/lanes)
    kArrTypeCode = 5
    kArrTypeBits = 6
    kArrTypeLanes = 7
    kArrByteOffset = 8
    # DLDevice sub-fields
    kArrDeviceId = 9
    kArrDeviceType = 10
    kArrKindBound_ = 11
    # TVMValue fields
    kTVMValueContent = 12
    kTVMValueKindBound_ = 13
51+
52+
53+
@register_legalize("relax.inspect.tensor_stride_i")
def _tensor_stride_i(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize `relax.inspect.tensor_stride_i` into a host-side TIR PrimFunc

    The generated PrimFunc reads `DLTensor::strides[axis]` via
    `T.tvm_struct_get`.  When the strides pointer is NULL (a compact,
    row-major tensor per the DLPack convention), the stride is instead
    computed as the product of the shape extents of all later axes.

    Parameters
    ----------
    bb: BlockBuilder
        Block builder into which the PrimFunc is added.
    call: Call
        The `relax.inspect.tensor_stride_i` call being legalized;
        args are (tensor, axis).

    Returns
    -------
    Expr
        A call to the generated PrimFunc with the original arguments.
    """

    @T.prim_func(private=True)
    def _get_tensor_stride_i(dlpack_handle: T.handle, axis: T.int64) -> T.int64:
        # Host-only function: runs on CPU against the raw DLTensor struct.
        T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)})
        assert T.int64(0) <= axis, "Specified axis may not be negative"
        ndim: T.int32 = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrNDim), "int32"
        )
        assert axis < T.Cast(
            "int64", ndim
        ), "Specified axis may not be larger than the tensor's dimensionality"
        stride_ptr: T.handle("int64") = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrStrides), "handle"
        )

        if T.isnullptr(stride_ptr):
            # NULL strides => compact layout; derive the stride from
            # the shape: stride[axis] = prod(shape[axis+1:]).
            shape_ptr: T.handle("int64") = T.tvm_struct_get(
                dlpack_handle, 0, int(TVMStructFieldKind.kArrShape), "handle"
            )
            shape = T.decl_buffer(ndim, "int64", data=shape_ptr)

            # Scalar (zero-dim) accumulator buffer for the running product.
            product = T.decl_buffer([], "int64")
            product[()] = 1

            # TODO(Lunderberg): Add a TIR lowering pass to allow
            # ranges to start somewhere other than zero.  This loop
            # could then iterate on `range(axis+1, ndim)`.
            for dim_offset in range(ndim - (axis + 1)):
                dim = dim_offset + (axis + 1)
                product[()] = product[()] * shape[dim]

            return product[()]
        else:
            # Explicit strides present; read the requested entry directly.
            strides = T.decl_buffer(ndim, "int64", data=stride_ptr)
            stride: T.int64 = strides[axis]
            return stride

    gvar = bb.add_func(_get_tensor_stride_i, "_get_tensor_stride_i")
    return Call(gvar, call.args)
93+
94+
95+
@register_legalize("relax.inspect.tensor_byte_offset")
def _tensor_byte_offset(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize `relax.inspect.tensor_byte_offset` into a TIR PrimFunc

    The generated host-side PrimFunc reads `DLTensor::byte_offset`
    via `T.tvm_struct_get` and returns it.  Note the uint64 field is
    returned through an int64-typed function result.

    Parameters
    ----------
    bb: BlockBuilder
        Block builder into which the PrimFunc is added.
    call: Call
        The call being legalized; args are (tensor,).

    Returns
    -------
    Expr
        A call to the generated PrimFunc with the original arguments.
    """

    @T.prim_func(private=True)
    def _get_tensor_byte_offset(dlpack_handle: T.handle) -> T.int64:
        # Host-only function: runs on CPU against the raw DLTensor struct.
        T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)})
        byte_offset: T.uint64 = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64"
        )
        return byte_offset

    gvar = bb.add_func(_get_tensor_byte_offset, "_get_tensor_byte_offset")
    return Call(gvar, call.args)
107+
108+
109+
@register_legalize("relax.inspect.tensor_elem_offset")
def _tensor_elem_offset(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize `relax.inspect.tensor_elem_offset` into a TIR PrimFunc

    The element offset is not stored in the DLTensor; the generated
    host-side PrimFunc derives it as
    `byte_offset // ceil(bits * lanes / 8)`, i.e. the byte offset
    divided by the (rounded-up) size in bytes of one element.

    Parameters
    ----------
    bb: BlockBuilder
        Block builder into which the PrimFunc is added.
    call: Call
        The call being legalized; args are (tensor,).

    Returns
    -------
    Expr
        A call to the generated PrimFunc with the original arguments.
    """

    @T.prim_func(private=True)
    def _get_tensor_elem_offset(dlpack_handle: T.handle) -> T.int64:
        # Host-only function: runs on CPU against the raw DLTensor struct.
        T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)})
        byte_offset: T.uint64 = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64"
        )
        # DLDataType sub-fields: bits per scalar lane, and lane count.
        scalar_bits: T.uint8 = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeBits), "uint8"
        )
        lanes: T.uint16 = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeLanes), "uint16"
        )
        # ceildiv handles sub-byte dtypes (e.g. bits*lanes not a multiple of 8).
        bytes_per_element = T.ceildiv(scalar_bits.astype("uint64") * lanes.astype("uint64"), 8)
        elem_offset = byte_offset // bytes_per_element
        return elem_offset

    gvar = bb.add_func(_get_tensor_elem_offset, "_get_tensor_elem_offset")
    return Call(gvar, call.args)

0 commit comments

Comments
 (0)