Skip to content

Commit 818493d

Browse files
committed
- Fix unit test
1 parent 3d1c45f commit 818493d

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

include/tvm/tir/stmt.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -599,9 +599,9 @@ class AllocateConstNode : public StmtNode {
599599
/*! \brief The optional data associated to the constant.
600600
*/
601601
Optional<runtime::NDArray> data;
602-
/*! \brief If the PrimFunc containing the Stmt is added to IRModule,
603-
this is an optional index to indicate the index within
604-
"Constants" attribute, that is a Array<NDArray> of IRModule.
602+
/*!
603+
* \brief If the PrimFunc containing the Stmt is added to IRModule, this is an optional index
604+
* to indicate the index within "constants" attribute, that is a Array<NDArray> of IRModule.
605605
*/
606606
Optional<Integer> irmod_storage_idx;
607607
/*! \brief The type of the buffer. */

python/tvm/tir/stmt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ class AllocateConst(Stmt):
358358
data_or_idx : Union[NDArray, int]
359359
If an NDArray, this is the const data associated with the
360360
constant. If an integer, this is the index into the
361-
"Constants" attribute of the `IRModule` that contains the
361+
"constants" attribute of the `IRModule` that contains the
362362
`AllocateConst`.
363363
364364
body : Stmt

tests/python/unittest/test_custom_datatypes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pytest
2222
import tvm
2323
import tvm.topi.testing
24+
import tvm.testing
2425
from tvm import relay
2526
from tvm.relay.testing.layers import batch_norm_infer
2627
from tvm.target.datatype import (
@@ -560,4 +561,4 @@ def test_posites2():
560561

561562

562563
if __name__ == "__main__":
563-
pytest.main([__file__])
564+
tvm.testing.main()

tests/python/unittest/test_tir_transform_extract_constants.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import tvm
1919
from tvm import tir
2020
from tvm.script import tir as T
21+
import tvm.testing
2122

2223

2324
@tvm.script.ir_module
@@ -49,7 +50,7 @@ def constant3(a: T.handle) -> None:
4950

5051
def test_const_extraction():
5152
mod = tvm.tir.transform.ExtractPrimFuncConstants()(Module4)
52-
constants = mod.attrs["Constants"]
53+
constants = mod.attrs["constants"]
5354
assert len(constants) == 2
5455

5556
def _visit(stmt):
@@ -63,4 +64,4 @@ def _visit(stmt):
6364

6465

6566
if __name__ == "__main__":
66-
test_const_extraction()
67+
tvm.testing.main()

0 commit comments

Comments
 (0)