@@ -566,7 +566,18 @@ def transform_Assign(self, node):
566566 self .context .remove_symbol (var .name )
567567 return tvm .tir .LetStmt (var , value , body , span = tvm_span_from_synr (node .span ))
568568
569- self .report_error ("Unsupported Assign stmt" , node .span )
569+ self .report_error (
570+ """Assignments should be either
571+ 1. A "special statement" with return value
572+ 1.1 Buffer = T.match_buffer()/T.buffer_decl()
573+ 1.2 Var = T.var()
574+ 1.3 Var = T.env_thread()
575+ 2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
576+ 3. A store into a variable: Var[PrimExpr] = PrimExpr
577+ 4. A with scope handler with concise scoping and var def
578+ 4.1 var = T.allocate()""" ,
579+ node .span ,
580+ )
570581
571582 def transform_SubscriptAssign (self , node ):
572583 """Visitor for statements of the form :code:`x[1] = 2`."""
@@ -583,6 +594,12 @@ def transform_SubscriptAssign(self, node):
583594 span = tvm_span_from_synr (node .span ),
584595 )
585596 else :
597+ if symbol .dtype == "handle" and len (indexes ) != 1 :
598+ self .report_error (
599+ "Handles only support one-dimensional indexing. Use `T.match_buffer` to "
600+ "construct a multidimensional buffer from a handle." ,
601+ node .params [0 ].span ,
602+ )
586603 if len (indexes ) != 1 :
587604 self .report_error (
588605 f"Store is only allowed with one index, but { len (indexes )} were provided." ,
@@ -736,9 +753,35 @@ def transform_Call(self, node):
736753 return self .transform_Subscript (node )
737754 if node .func_name .name in self ._binop_maker :
738755 lhs = self .transform (node .params [0 ])
756+ # There is no supertype for everything that can appear in
757+ # an expression, so we manually add what we might get here.
758+ if not isinstance (lhs , (tvm .tir .PrimExpr , BufferSlice )):
759+ # We would really like to report a more specific
760+ # error here, but this parser contains no distinction
761+ # between parsing statements and parsing expressions. All
762+ # rules just call `transform`.
763+ self .report_error (
764+ f"Left hand side of binary op must be a PrimExpr, "
765+ "but it is a {type(lhs).__name__}" ,
766+ node .params [0 ].span ,
767+ )
739768 rhs = self .transform (node .params [1 ])
740- return self ._binop_maker [node .func_name .name ](
741- lhs , rhs , span = tvm_span_from_synr (node .span )
769+ if not isinstance (rhs , (tvm .tir .PrimExpr , BufferSlice )):
770+ self .report_error (
771+ f"Right hand side of binary op must be a PrimExpr, "
772+ "but it is a {type(rhs).__name__}" ,
773+ node .params [1 ].span ,
774+ )
775+ return call_with_error_reporting (
776+ self .report_error ,
777+ node .span ,
778+ lambda node , lhs , rhs , span : self ._binop_maker [node .func_name .name ](
779+ lhs , rhs , span = span
780+ ),
781+ node ,
782+ lhs ,
783+ rhs ,
784+ tvm_span_from_synr (node .span ),
742785 )
743786 if node .func_name .name in self ._unaryop_maker :
744787 rhs = self .transform (node .params [0 ])
@@ -764,6 +807,8 @@ def transform_Call(self, node):
764807 self .transform (k ): self .transform (v ) for k , v in node .keyword_params .items ()
765808 }
766809 if isinstance (func , tvm .tir .op .Op ):
810+ if not "dtype" in kw_args .keys ():
811+ self .report_error (f"{ func } requires a dtype keyword argument." , node .span )
767812 # pattern 2
768813 return tvm .tir .Call (
769814 kw_args ["dtype" ], func , args , span = tvm_span_from_synr (node .span )
@@ -862,15 +907,33 @@ def transform_Subscript(self, node):
862907
863908 indexes = [self .transform (x ) for x in node .params [1 ].values ]
864909 if isinstance (symbol , tvm .tir .expr .Var ):
865- for index in indexes :
866- if not isinstance (index , (tvm .tir .PrimExpr , int )):
867- self .report_error (
868- "Buffer load indexes should be int or PrimExpr, but they are "
869- + type (index ),
870- node .span ,
871- )
872- return tvm .tir .Load (
873- "float32" , symbol , indexes , True , span = tvm_span_from_synr (node .span )
910+ if symbol .dtype == "handle" :
911+ self .report_error (
912+ "Cannot read directly from a handle, use `T.match_buffer` "
913+ "to create a buffer to read from." ,
914+ node .params [0 ].span ,
915+ )
916+ if len (indexes ) > 1 :
917+ self .report_error (
918+ "Only a single index can be provided when indexing into a `var`." ,
919+ node .params [1 ].span ,
920+ )
921+ index = indexes [0 ]
922+ if not isinstance (index , (tvm .tir .PrimExpr , int )):
923+ self .report_error (
924+ "Var load index should be an int or PrimExpr, but it is a" + type (index ),
925+ node .span ,
926+ )
927+
928+ return call_with_error_reporting (
929+ self .report_error ,
930+ node .span ,
931+ tvm .tir .Load ,
932+ "float32" ,
933+ symbol ,
934+ index ,
935+ True ,
936+ span = tvm_span_from_synr (node .span ),
874937 )
875938 elif isinstance (symbol , tvm .tir .Buffer ):
876939 return BufferSlice (
0 commit comments