Skip to content

Commit 64911ab

Browse files
authored
[Runtime] Implemented Datatype.itemsize() (#16880)
* [Runtime] Implemented Datatype.itemsize()
1 parent d0cbb02 commit 64911ab

File tree

4 files changed

+58
-6
lines changed

4 files changed

+58
-6
lines changed

python/tvm/_ffi/runtime_ctypes.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,20 @@ def __eq__(self, other):
212212
def __ne__(self, other):
213213
return not self.__eq__(other)
214214

215+
def itemsize(self):
216+
"""Get the number of bytes of a single element of this data type. When the number of lanes
217+
is greater than 1, the itemsize is the size of the vector type.
218+
219+
Returns
220+
-------
221+
itemsize : int
222+
The number of bytes of a single element of this data type
223+
"""
224+
lanes_as_int = ctypes.c_int16(self.lanes).value
225+
if lanes_as_int < 0:
226+
raise ValueError("Cannot determine itemsize for scalable vector types")
227+
return (self.bits * self.lanes + 7) // 8
228+
215229

216230
if ml_dtypes is not None:
217231
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"

python/tvm/dlight/gpu/gemv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):
5757
def get_bytes(dtype: Union[DataType, str]) -> int:
5858
if isinstance(dtype, str):
5959
dtype = DataType(dtype)
60-
return dtype.bits * dtype.lanes // 8
60+
return dtype.itemsize()
6161

6262

6363
def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:

python/tvm/dlight/gpu/low_batch_gemv.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""A rule for low-batch GEMM / decode-GEMM using GEMV schedule."""
18-
import re
1918
from functools import reduce
2019
from typing import List, Optional, Set, Union
2120

@@ -55,10 +54,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):
5554

5655

5756
def get_bytes(dtype: Union[DataType, str]) -> int:
58-
num = re.findall(r"\d+", dtype)
59-
if len(num) != 1:
60-
raise ValueError(f"Cannot get bytes from {dtype}")
61-
return int(num[0]) // 8
57+
if isinstance(dtype, str):
58+
dtype = DataType(dtype)
59+
return dtype.itemsize()
6260

6361

6462
def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:

tests/python/ir/test_dtype.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Test data type related API"""
18+
import tvm
19+
from tvm import DataType
20+
import tvm.testing
21+
import pytest
22+
23+
24+
@pytest.mark.parametrize(
25+
"dtype_str, expected_size",
26+
[("float32", 4), ("float32x4", 16), ("e5m2_float8x4", 4), ("uint8", 1)],
27+
)
28+
def test_dtype_itemsize(dtype_str, expected_size):
29+
dtype = DataType(dtype_str)
30+
assert dtype.itemsize() == expected_size
31+
32+
33+
@pytest.mark.parametrize("dtype_str", [("int32xvscalex4")])
34+
def test_dtype_itemmize_error(dtype_str):
35+
with pytest.raises(ValueError):
36+
size = DataType(dtype_str).itemsize()
37+
38+
39+
if __name__ == "__main__":
40+
tvm.testing.main()

0 commit comments

Comments
 (0)