@@ -172,14 +172,11 @@ uint16_t RoundToNearestEven(float src) {
172172 * Lower cast between bf16 and fp32
173173 * Lower bf16 FloatImm to int16
174174 */
175- class BF16LowerRewriter : StmtExprMutator {
175+ class BF16LowerRewriter : public StmtExprMutator {
176176 public:
// Default constructor with an empty body; it performs no explicit initialisation.
177177 BF16LowerRewriter () {}
178178
179- std::unordered_map<const BufferNode*, Buffer> buffer_remap;
180- std::unordered_map<const VarNode*, Var> var_remap;
181-
// Call operator. The removed (-) line defined it by hand, forwarding the
// statement to VisitStmt. The added (+) line instead brings the base class's
// operator() into this class's scope with a using-declaration.
182- Stmt operator ()(Stmt s) { return VisitStmt (s); }
179+ using StmtExprMutator::operator ();
183180
184181 PrimExpr VisitExpr_ (const CastNode* op) final {
185182 auto op_val = StmtExprMutator::VisitExpr (op->value );
@@ -190,7 +187,6 @@ class BF16LowerRewriter : StmtExprMutator {
190187 auto uint32_v = Cast (uint32_dtype, op_val);
191188 // to be endian invariant.
192189 return Call (op->dtype , builtin::reinterpret (), {uint32_v << 16 });
193-
194190 } else if (op->dtype .is_bfloat16 ()) {
195191 // if is cast_to_bf16, check if op->value is fp32
196192 CHECK (op->value ->dtype .is_float () && op->value ->dtype .bits () == 32 );
@@ -209,104 +205,104 @@ class BF16LowerRewriter : StmtExprMutator {
209205 }
210206
// Variable visitor.
// Added (+) version: wraps `op` in a Var reference and looks it up in
// var_remap_. A recorded replacement is returned if present; otherwise the
// variable itself is returned unchanged. This version never creates new
// variables; it only consumes entries that other code put into var_remap_.
// Removed (-) version: keyed the lookup on the raw VarNode pointer and, on a
// miss, lazily created (and cached) a fresh Var with the same name and dtype
// for any bfloat16-typed variable that had no type annotation, falling back
// to the base-class visitor for everything else.
211207 PrimExpr VisitExpr_ (const VarNode* op) final {
212- auto itr = var_remap.find (op);
213- if (itr != var_remap.end ()) {
208+ Var var = GetRef<Var>(op);
209+
210+ auto itr = var_remap_.find (var);
211+ if (itr != var_remap_.end ()) {
214212 return itr->second ;
213+ } else {
214+ return std::move (var);
215215 }
216- if (op->dtype .is_bfloat16 ()) {
217- CHECK (!op->type_annotation .defined ());
218- auto ret = Var (op->name_hint , op->dtype );
219- var_remap[op] = ret;
220- return std::move (ret);
221- }
222- return StmtExprMutator::VisitExpr_ (op);
223216 }
224217
// Allocation visitor.
// Added (+) version: for a bfloat16 allocation it builds a uint16 dtype with
// the same lane count, creates a replacement buffer variable (same name hint,
// typed as a pointer to that uint16 dtype), records old -> new in var_remap_,
// and then visits an Allocate rebuilt on the new variable and dtype with the
// original extents, condition and body. Non-bfloat16 allocations go straight
// to the base-class visitor.
// NOTE(review): the rebuilt Allocate has a uint16 dtype, so if VisitStmt
// dispatches back to this override it should take the else branch - confirm.
// Removed (-) version: rebuilt the Allocate with a uint16 dtype but kept the
// original buffer_var, then ran the base-class visitor on the result.
225218 Stmt VisitStmt_ (const AllocateNode* op) final {
226- Stmt node_holder;
227- const AllocateNode* newop;
228219 if (op->dtype .is_bfloat16 ()) {
229- auto v = Allocate (op-> buffer_var , DataType::UInt (16 , op->dtype .lanes ()), op-> extents ,
230- op->condition , op-> body );
231- node_holder = v ;
232- newop = static_cast < const AllocateNode*>(v. operator ->( ));
220+ DataType dtype = DataType::UInt (16 , op->dtype .lanes ());
221+ Var buffer_var = Var ( op->buffer_var -> name_hint , PointerType ( PrimType (dtype)) );
222+ var_remap_[op-> buffer_var ] = buffer_var ;
223+ return VisitStmt ( Allocate (buffer_var, dtype, op-> extents , op-> condition , op-> body ));
233224 } else {
234- newop = op ;
225+ return StmtExprMutator::VisitStmt_ (op) ;
235226 }
236- return StmtExprMutator::VisitStmt_ (newop);
237227 }
238228
// BufferStore visitor.
// Added (+) version: runs the base-class visitor first, re-points `op` at the
// resulting node, and only then consults buffer_remap_. If the store's buffer
// has a recorded replacement, a new BufferStore on that replacement (with the
// already-visited value and indices) is returned; otherwise the visited
// statement is returned as is.
// Removed (-) version: did the steps in the opposite order - swapped the
// buffer first, then ran the base-class visitor on the swapped node.
239230 Stmt VisitStmt_ (const BufferStoreNode* op) final {
240- auto itr = buffer_remap. find (op-> buffer . operator ->() );
241- const BufferStoreNode* newop ;
242- BufferStore newop_holder;
243- if (itr != buffer_remap. end ()) {
244- newop_holder = BufferStore (itr-> second , op-> value , op-> indices );
245- newop = newop_holder. operator ->( );
230+ Stmt ret = StmtExprMutator::VisitStmt_ (op);
231+ op = ret. as <BufferStoreNode>() ;
232+
233+ auto it = buffer_remap_. find (op-> buffer );
234+ if (it != buffer_remap_. end ()) {
235+ return BufferStore (it-> second , op-> value , op-> indices );
246236 } else {
247- newop = op ;
237+ return ret ;
248238 }
249- return StmtExprMutator::VisitStmt_ (newop);
250239 }
251240
// AttrStmt visitor.
// Added (+) version: runs the base-class visitor first and re-points `op` at
// the result. The attribute's `node` is then inspected: if it is a Buffer
// with an entry in buffer_remap_, or a Var with an entry in var_remap_, a new
// AttrStmt carrying the replacement node (and the already-visited key, value
// and body) is returned. In every other case the visited statement is
// returned unchanged.
// Removed (-) version: swapped the node first (looking up by raw node
// pointer) and then ran the base-class visitor on the swapped statement.
252241 Stmt VisitStmt_ (const AttrStmtNode* op) final {
253- const AttrStmtNode* newop = op ;
254- Stmt newop_holder ;
255- if ( auto buffer = op-> node . as <BufferNode>()) {
256- auto itr = buffer_remap. find (buffer);
257- if (itr != buffer_remap. end ()) {
258- newop_holder = AttrStmt (itr-> second , op-> attr_key , op-> value , op-> body );
259- newop = newop_holder. as <AttrStmtNode>( );
242+ Stmt ret = StmtExprMutator::VisitStmt_ (op) ;
243+ op = ret. as <AttrStmtNode>() ;
244+
245+ if ( auto * buffer = op-> node . as <BufferNode>()) {
246+ auto it = buffer_remap_. find (GetRef<Buffer>(buffer));
247+ if (it != buffer_remap_. end ()) {
248+ return AttrStmt (it-> second , op-> attr_key , op-> value , op-> body );
260249 }
261- } else if (auto buffer = op->node .as <VarNode>()) {
262- auto itr = var_remap.find (buffer);
263- if (itr != var_remap.end ()) {
264- newop_holder = AttrStmt (itr->second , op->attr_key , op->value , op->body );
265- newop = newop_holder.as <AttrStmtNode>();
250+ } else if (auto * var = op->node .as <VarNode>()) {
251+ auto it = var_remap_.find (GetRef<Var>(var));
252+ if (it != var_remap_.end ()) {
253+ return AttrStmt (it->second , op->attr_key , op->value , op->body );
266254 }
267255 }
268- return StmtExprMutator::VisitStmt_ (newop) ;
256+ return ret ;
269257 }
270258
// BufferRealize visitor, followed by the newly added Store visitor. The diff
// interleaves the two: the removed tail of the old BufferRealize visitor
// appears after the added Store visitor's body.
//
// BufferRealize, added (+) version: runs the base-class visitor first,
// re-points `op` at the result, then returns a BufferRealize on the
// replacement buffer if buffer_remap_ has an entry for op->buffer, else the
// visited statement.
// BufferRealize, removed (-) version: swapped the buffer first, then ran the
// base-class visitor on the swapped node.
271259 Stmt VisitStmt_ (const BufferRealizeNode* op) final {
272- auto itr = buffer_remap.find (op->buffer .operator ->());
273- const BufferRealizeNode* newop;
274- Stmt newop_holder;
275- if (itr != buffer_remap.end ()) {
276- auto v = BufferRealize (itr->second , op->bounds , op->condition , op->body );
277- newop_holder = v;
278- newop = v.operator ->();
260+ Stmt ret = StmtExprMutator::VisitStmt_ (op);
261+ op = ret.as <BufferRealizeNode>();
262+
263+ auto it = buffer_remap_.find (op->buffer );
264+ if (it != buffer_remap_.end ()) {
265+ return BufferRealize (it->second , op->bounds , op->condition , op->body );
279266 } else {
280- newop = op;
267+ return ret;
268+ }
269+ }
270+
// Store visitor (entirely new in this diff): runs the base-class visitor,
// then rebuilds the Store on the replacement variable when var_remap_ has an
// entry for the store's buffer_var; otherwise returns the visited statement.
271+ Stmt VisitStmt_ (const StoreNode* op) final {
272+ // NOTE: we do not explicitly recursively mutate op->buffer_var
273+ Stmt ret = StmtExprMutator::VisitStmt_ (op);
274+ op = ret.as <StoreNode>();
275+
276+ auto it = var_remap_.find (op->buffer_var );
277+ if (it != var_remap_.end ()) {
278+ return Store (it->second , op->value , op->index , op->predicate );
279+ } else {
280+ return ret;
281281 }
// The removed line below is the tail of the old BufferRealize visitor.
282- return StmtExprMutator::VisitStmt_ (newop);
283282 }
284283
// BufferLoad visitor.
// Added (+) version: runs the base-class visitor first, re-points `op` at the
// result, and returns a BufferLoad on the replacement buffer (with the
// already-visited indices) when buffer_remap_ has an entry for op->buffer;
// otherwise the visited expression is returned as is.
// Removed (-) version: swapped the buffer first, then ran the base-class
// visitor on the swapped node.
285284 PrimExpr VisitExpr_ (const BufferLoadNode* op) final {
286- auto itr = buffer_remap. find (op-> buffer . operator ->() );
287- const BufferLoadNode* newop ;
288- BufferLoad newop_holder;
289- if (itr != buffer_remap. end ()) {
290- newop_holder = BufferLoad (itr-> second , op-> indices );
291- newop = newop_holder. operator ->( );
285+ PrimExpr ret = StmtExprMutator::VisitExpr_ (op);
286+ op = ret. as <BufferLoadNode>() ;
287+
288+ auto it = buffer_remap_. find (op-> buffer );
289+ if (it != buffer_remap_. end ()) {
290+ return BufferLoad (it-> second , op-> indices );
292291 } else {
293- newop = op ;
292+ return ret ;
294293 }
295- return StmtExprMutator::VisitExpr_ (newop);
296294 }
297295
// Load visitor.
// Added (+) version: runs the base-class visitor first and re-points `op` at
// the result. If the load's dtype is bfloat16, the buffer variable must have
// an entry in var_remap_ (enforced by the CHECK); the load is then rebuilt
// with a uint16 dtype of the same lane count, reading from the replacement
// variable. Loads of any other dtype are returned as visited.
// Removed (-) version: visited index and predicate by hand and rebuilt the
// Load with a uint16 dtype for bfloat16 loads, but kept the original
// buffer_var rather than a remapped one.
298296 PrimExpr VisitExpr_ (const LoadNode* op) final {
299- bool is_bf16 = false ;
297+ PrimExpr ret = StmtExprMutator::VisitExpr_ (op);
298+ op = ret.as <LoadNode>();
299+
300300 if (op->dtype .is_bfloat16 ()) {
301- is_bf16 = true ;
302- }
303- PrimExpr index = this ->VisitExpr (op->index );
304- PrimExpr predicate = this ->VisitExpr (op->predicate );
305- if (index.same_as (op->index ) && predicate.same_as (op->predicate ) && !is_bf16) {
306- return GetRef<PrimExpr>(op);
301+ auto it = var_remap_.find (op->buffer_var );
302+ CHECK (it != var_remap_.end ()) << " bfloat* var needs to be remapped" ;
303+ return Load (DataType::UInt (16 , op->dtype .lanes ()), it->second , op->index , op->predicate );
307304 } else {
308- return Load (is_bf16 ? DataType::UInt (16 , op->dtype .lanes ()) : op->dtype , op->buffer_var ,
309- index, predicate);
305+ return ret;
310306 }
311307 }
312308
@@ -320,20 +316,31 @@ class BF16LowerRewriter : StmtExprMutator {
320316
// Rewrites the function's buffer_map so that bfloat16 buffers become uint16
// buffers, and records the replacements for the visitors in this class.
// Added (+) version: for every bfloat16 buffer in op->buffer_map it creates a
// uint16 dtype with the same lane count, a new data variable (same name hint,
// typed as a pointer to that dtype) and a new Buffer that copies the
// remaining fields of the old one. It records old buffer -> new buffer in
// buffer_remap_ and old data var -> new data var in var_remap_. Entries whose
// buffer is not bfloat16 are appended to `changes` untouched, so the rebuilt
// map still contains them. op->buffer_map is only replaced when
// buffer_remap_ is non-empty.
// Removed (-) version: reused the old data variable for the new buffer and
// pushed only the bfloat16 entries into `changes`, so the rebuilt map held
// just those entries.
321317 void AlterBuffers (PrimFuncNode* op) {
322318 std::vector<std::pair<Var, Buffer>> changes;
319+
323320 for (auto & itr : op->buffer_map ) {
324321 auto oldbuf = itr.second ;
325322 if (oldbuf->dtype .is_bfloat16 ()) {
326- auto newbuf = Buffer (oldbuf->data , DataType::UInt (16 , oldbuf->dtype .lanes ()), oldbuf->shape ,
327- oldbuf->strides , oldbuf->elem_offset , oldbuf->name , oldbuf->scope ,
328- oldbuf->data_alignment , oldbuf->offset_factor , oldbuf->buffer_type );
329- buffer_remap[oldbuf.operator ->()] = newbuf;
323+ DataType dtype = DataType::UInt (16 , oldbuf->dtype .lanes ());
324+ Var buffer_var = Var (oldbuf->data ->name_hint , PointerType (PrimType (dtype)));
325+ auto newbuf = Buffer (buffer_var, dtype, oldbuf->shape , oldbuf->strides , oldbuf->elem_offset ,
326+ oldbuf->name , oldbuf->scope , oldbuf->data_alignment ,
327+ oldbuf->offset_factor , oldbuf->buffer_type );
328+ buffer_remap_[oldbuf] = newbuf;
329+ var_remap_[oldbuf->data ] = buffer_var;
330330 changes.emplace_back (itr.first , newbuf);
331+ } else {
332+ changes.emplace_back (itr);
331333 }
332334 }
333- if (buffer_remap.size () != 0 ) {
335+
336+ if (buffer_remap_.size () != 0 ) {
334337 op->buffer_map = Map<Var, Buffer>(changes.begin (), changes.end ());
335338 }
336339 }
340+
341+ private:
342+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
343+ std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
337344};
338345
339346namespace transform {
0 commit comments