1818# pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
1919from typing import Tuple , Any , Callable , Optional , List , Union , Mapping
2020
21+ import numpy as np
2122import synr
2223import tvm .tir
2324from tvm .runtime import Object
2425from tvm .ir import Span , Range
2526from tvm .tir import Stmt , PrimExpr , IterVar , Var , Buffer , BufferRegion , ForKind
26- import numpy as np
2727
2828from .node import BufferSlice
2929from .utils import buffer_slice_to_region
@@ -159,16 +159,15 @@ def setup_buffer_var(
159159
160160@register
161161class AllocateConst (WithScopeHandler ):
162- """With scope handler tir.allocate (data, extents, dtype, condition)"""
162+ """With scope handler tir.allocate_const (data, extents, dtype, condition)"""
163163
164164 def __init__ (self ):
165165 def allocate_const (raw_data , dtype , shape , span = None ):
166166 list_data = []
167167 for i in raw_data :
168168 list_data .append (i .value )
169169 nd_data = tvm .nd .array (np .asarray (list_data , dtype = dtype ))
170-
171- n = tvm .tir .AllocateConst (self .buffer_var , nd_data , dtype , shape , self .body , span = span )
170+ n = tvm .tir .AllocateConst (self .buffer_var , dtype , shape , nd_data , self .body , span = span )
172171 return n
173172
174173 super ().__init__ (allocate_const , concise_scope = True , def_symbol = True )
@@ -182,15 +181,17 @@ def enter_scope(
182181 span : synr .ast .Span ,
183182 ):
184183 # define buffer vars in symbol table
185- if isinstance (node , ast .With ):
184+ if isinstance (node , synr . ast .With ):
186185 vars = WithScopeHandler .get_optional_vars (node , context )
187186 if len (vars ) != 1 :
188- context .report_error ("Unexpected number of vars" , node .span )
187+ context .report_error (f "Unexpected number of vars: 1 vs. { len ( vars ) } " , node .span )
189188 name = vars [0 ].id .name
190189 var_span = vars [0 ].id .span
191- elif isinstance (node , ast .Assign ):
192- name = node .lhs .id .name
193- var_span = node .lhs .id .span
190+ elif isinstance (node , synr .ast .Assign ):
191+ if len (node .lhs ) != 1 :
192+ context .report_error (f"Unexpected number of vars: 1 vs. { len (node .lhs )} " , node .span )
193+ name = node .lhs [0 ].id .name
194+ var_span = node .lhs [0 ].id .span
194195 else :
195196 raise Exception ("Internal Bug" )
196197
@@ -214,11 +215,7 @@ def launch_thread(env_var, extent, span):
214215 attr_key = "virtual_thread" if thread_id == "vthread" else "thread_extent"
215216 return tvm .tir .AttrStmt (
216217 IterVar (
217- (0 , extent ),
218- env_var ,
219- getattr (IterVar , "ThreadIndex" ),
220- thread_id ,
221- span = span ,
218+ (0 , extent ), env_var , getattr (IterVar , "ThreadIndex" ), thread_id , span = span ,
222219 ),
223220 attr_key ,
224221 extent ,
@@ -545,9 +542,7 @@ class Serial(ForScopeHandler):
545542
546543 def __init__ (self ):
547544 def serial (
548- begin : PrimExpr ,
549- end : PrimExpr ,
550- annotations : Optional [Mapping [str , Object ]] = None ,
545+ begin : PrimExpr , end : PrimExpr , annotations : Optional [Mapping [str , Object ]] = None ,
551546 ):
552547 self .create_loop_info (begin , end , ForKind .SERIAL , annotations = annotations )
553548
@@ -560,9 +555,7 @@ class Parallel(ForScopeHandler):
560555
561556 def __init__ (self ):
562557 def parallel (
563- begin : PrimExpr ,
564- end : PrimExpr ,
565- annotations : Optional [Mapping [str , Object ]] = None ,
558+ begin : PrimExpr , end : PrimExpr , annotations : Optional [Mapping [str , Object ]] = None ,
566559 ):
567560 self .create_loop_info (begin , end , ForKind .PARALLEL , annotations = annotations )
568561
@@ -575,9 +568,7 @@ class Vectorized(ForScopeHandler):
575568
576569 def __init__ (self ):
577570 def vectorized (
578- begin : PrimExpr ,
579- end : PrimExpr ,
580- annotations : Optional [Mapping [str , Object ]] = None ,
571+ begin : PrimExpr , end : PrimExpr , annotations : Optional [Mapping [str , Object ]] = None ,
581572 ):
582573 self .create_loop_info (begin , end , ForKind .VECTORIZED , annotations = annotations )
583574
@@ -590,9 +581,7 @@ class Unroll(ForScopeHandler):
590581
591582 def __init__ (self ):
592583 def unroll (
593- begin : PrimExpr ,
594- end : PrimExpr ,
595- annotations : Optional [Mapping [str , Object ]] = None ,
584+ begin : PrimExpr , end : PrimExpr , annotations : Optional [Mapping [str , Object ]] = None ,
596585 ):
597586 self .create_loop_info (begin , end , ForKind .UNROLLED , annotations = annotations )
598587
0 commit comments