3232from tvm ._ffi .base import TVMError
3333from tvm .ir import GlobalVar
3434from tvm .ir .function import BaseFunc
35+ from tvm .tir import buffer
3536from tvm .tir .function import PrimFunc
3637from . import _ffi_api
3738from . import tir
@@ -154,10 +155,10 @@ class TVMScriptParser(Transformer):
154155 ast .BuiltinOp .Not : tvm .tir .Not ,
155156 }
156157
157- def __init__ (self , base_lienno , tir_namespace ):
158+ def __init__ (self , base_lineno , tir_namespace ):
158159 self .context = None
159160
160- self .base_lineno = base_lienno
161+ self .base_lineno = base_lineno
161162 self .current_lineno = 0
162163 self .current_col_offset = 0
163164 self .tir_namespace = tir_namespace
@@ -249,20 +250,23 @@ def parse_arg_list(self, func, node_call):
249250 func : Function
250251 The function that provides the signature
251252
252- node_call: ast.Call
253+ node_call: Union[ ast.Call, ast.TypeApply, ast.TypeCall]
253254 The AST call node that calls into the function.
254255
255256 Returns
256257 -------
257258 arg_list : list
258259 The parsed positional argument.
259260 """
260- assert isinstance (node_call , ast .Call )
261+ assert isinstance (node_call , ( ast .Call , ast . TypeApply , ast . TypeCall ) )
261262 # collect arguments
262263 args = [self .transform (arg ) for arg in node_call .params ]
263- kw_args = {
264- self .transform (k ): self .transform (v ) for k , v in node_call .keyword_params .items ()
265- }
264+ if isinstance (node_call , ast .TypeApply ):
265+ kw_args = {} # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
266+ else :
267+ kw_args = {
268+ self .transform (k ): self .transform (v ) for k , v in node_call .keyword_params .items ()
269+ }
266270 # get the name and parameter list of func
267271 if isinstance (func , (Intrin , ScopeHandler , SpecialStmt )):
268272 func_name , param_list = func .signature ()
@@ -276,6 +280,7 @@ def parse_arg_list(self, func, node_call):
276280 reader = CallArgumentReader (func_name , args , kw_args , self , node_call )
277281 pos_only , kwargs , varargs = param_list
278282 internal_args = list ()
283+
279284 for i , arg_name in enumerate (pos_only ):
280285 internal_args .append (reader .get_pos_only_arg (i + 1 , arg_name ))
281286 for i , arg_info in enumerate (kwargs ):
@@ -439,8 +444,22 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
439444
440445 # add parameters of function
441446 for arg in node .params :
442- arg_var = tvm .te .var (arg .name , self .parse_type (arg .ty , arg ))
443- self .context .update_symbol (arg .name , arg_var , node )
447+ # Note that this case is for T.match_buffer syntax sugar
448+ if isinstance (arg .ty , (ast .TypeCall , ast .TypeApply )):
449+ result = self .handle_match_buffer_type (arg .ty , arg .name )
450+ if not isinstance (result , buffer .Buffer ):
451+ self .report_error (
452+ "The result type of evaluating TypeCall and TypeApply stmt"
453+ f" is wrong: { type (result )} . It should be a Buffer" ,
454+ node .span ,
455+ )
456+ arg_name_with_handle = arg .name + "_handle"
457+ arg_var = tvm .te .var (arg_name_with_handle , tvm .ir .PrimType ("handle" ))
458+ self .context .func_buffer_map [arg_var ] = result
459+ self .context .update_symbol (arg .name , result , node )
460+ else :
461+ arg_var = tvm .te .var (arg .name , self .parse_type (arg .ty , arg ))
462+ self .context .update_symbol (arg .name , arg_var , node )
444463 self .context .func_params .append (arg_var )
445464
446465 if not check_decorator (node .decorators ):
@@ -1110,6 +1129,30 @@ def transform_TypeConstant(self, node):
11101129 """
11111130 return node .value
11121131
1132+ def transform_TypeTuple (self , node ):
1133+ """Tuple value visitor for types.
1134+
1135+ Mostly used in `transform_TypeCall` and `transform_TypeApply`.
1136+ """
1137+ return [self .transform (value ) for value in node .values ]
1138+
1139+ def handle_match_buffer_type (self , node , buffer_name ):
1140+ """special function to handle syntax sugar for match buffer.
1141+
1142+ This method is for buffer declarations in the function parameters.
1143+ """
1144+ func = self .transform (node .func_name )
1145+ assert isinstance (func , SpecialStmt )
1146+
1147+ # parse args and kwargs for TypeCall and TypeApply
1148+ arg_list = self .parse_arg_list (func , node )
1149+ # Note that the third element in arg_list would always be the 'name'
1150+ # TODO: This index is hardcoded as a workaround. Better to make it programmatic
1151+ if arg_list [2 ] is None :
1152+ arg_list [2 ] = buffer_name
1153+ buf = func .handle (node , self .context , arg_list , node .func_name .span )
1154+ return buf
1155+
11131156 def transform_Return (self , node ):
11141157 self .report_error (
11151158 "TVM script does not support return statements. Instead the last statement in any "
0 commit comments