Skip to content

Commit 1d7c52f

Browse files
wereyzhliu
authored andcommitted
add docstring skip in hybrid script (#1668)
* add docstring skip in hybrid script * fix lint
1 parent 463e5c3 commit 1d7c52f

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

python/tvm/hybrid/parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import ast
44
import operator
55
import sys
6-
from .util import make_nop, halide_imm_types
6+
from .util import make_nop, halide_imm_types, is_docstring
77
from .intrin import LOOP_INTRIN, MATH_INTRIN
88
from .var_decl import determine_variable_usage
99
from ..api import thread_axis
@@ -15,7 +15,7 @@
1515

1616
def list_to_block(visit, lst):
1717
"""Convert a list of Python IR nodes to HalideIR Block"""
18-
lst = list(map(visit, lst))
18+
lst = [visit(stmt) for stmt in lst if not is_docstring(stmt)]
1919
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())]
2020
if not lst:
2121
return make_nop()

python/tvm/hybrid/util.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Internal utilities for parsing Python subset to HalideIR"""
22

3+
import ast
34
import inspect
45
import numpy
56
from .intrin import HYBRID_GLOBALS
@@ -22,6 +23,11 @@ def make_nop():
2223
return _make.Evaluate(_api.const(0, dtype='int32'))
2324

2425

26+
def is_docstring(node):
27+
"""Checks if a Python AST node is a docstring"""
28+
return isinstance(node, ast.Expr) and isinstance(node.value, ast.Str)
29+
30+
2531
def _pruned_source(func):
2632
"""Prune source code's extra leading spaces"""
2733
lines = inspect.getsource(func).split('\n')

tests/python/unittest/test_hybrid_script.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def tvm_val_2_py_val(val):
4343

4444
@script
4545
def outer_product(n, m, a, b, c):
46+
"""This is a simple outer product"""
4647
for i in range(n):
4748
for j in range(m):
4849
c[i, j] = a[i] * b[j]

0 commit comments

Comments
 (0)