@@ -96,3 +96,72 @@ class PipelineState(Enum):
     IDLE = 0
     CALL_FWD = 1
     CALL_BWD = 2
+
+    def __str__(self) -> str:
+        return self.name
+
+
+@unique
+class PipelinePhase(Enum):
+    """
+    Pipeline phase for the train pipeline.
+
+    Please:
+        1. Order the phases in the order of execution of the base pipeline.
+        2. Add notes to explain the phases if needed.
+
+    """
+
+    def __str__(self) -> str:
+        return self.value
+
+    def __eq__(self, obj: "PipelinePhase") -> bool:
+        return self.value == obj.value
+
+    # placeholder for empty
+    NULL = "null"
+
+    # the data is usually first available on the CPU when loaded from the dataloader;
+    # the input batch needs to be moved/copied to the device for GPU training
+    COPY_BATCH_TO_DEVICE = "copy_batch_to_device"
+
+    # input post-processing is needed for the sparse data dist pipeline, where the sparse
+    # features are traced (built) from the ModelInput via fx tracing
+    INPUT_POST_PROC = "input_post_proc"
+
+    # the sparse features (AKA KJTs) are in a jagged format, so their sizes are unknown to
+    # the other ranks; a comm is needed to exchange the size info, i.e., the splits
+    INPUT_SPLITS_DIST = "input_splits_dist"
+
+    # once a rank knows the data sizes of the other ranks (via the splits dist), it can
+    # issue an all-to-all comm to exchange the actual data of the sparse features
+    # NOTE: the splits have to be available on the host side
+    INPUT_DATA_DIST = "input_data_dist"
+
+    # the embedding lookup is done by FBGEMM's TBE on each rank
+    EMBEDDING_LOOKUP = "embedding_lookup"
+
+    # the embedding lookup results (i.e., the embeddings) are needed on every rank; this is
+    # usually done with the output dist, an all_to_all comm
+    EMBEDDING_OUTPUT_DIST = "embedding_output_dist"
+
+    # a typical DLRM model arch contains a sparse arch and a dense arch; here we treat the model
+    # excluding the "sparse modules" as the dense part, which also includes the dense-sharded embedding tables
+    DENSE_FORWARD = "dense_forward"
+
+    # the model's backward usually uses torch.autograd; the embedding modules' backward is handled by TBE
+    DENSE_BACKWARD = "dense_backward"
+
+    # on each rank, after the dense arch's backward, the gradients for the embedding tables are available;
+    # a backward of the "embedding output dist" is needed to gather the embedding gradients from all ranks
+    # to the rank where the embedding table is hosted
+    EMBEDDING_GRAD_DIST = "embedding_grad_dist"
+
+    # the TBE backward usually updates the embedding table weights in place
+    EMBEDDING_BACKWARD = "embedding_backward"
+
+    # the embedding update is decoupled from the backward in case the weight update is not fused with it
+    EMBEDDING_UPDATE = "embedding_update"
+
+    # the optimizer step usually only includes the dense module weights
+    DENSE_OPTIMIZER_STEP = "dense_optimizer_step"
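
A small usage sketch of the new enum's behavior, assuming `PipelinePhase` as defined in the diff above is in scope (the variable names here are illustrative only):

```python
# Assumes PipelinePhase from the diff above is importable / in scope.
# Enum preserves declaration order, which the docstring asks to match the
# base pipeline's execution order, so iterating the class walks the phases.
execution_order = [p for p in PipelinePhase if p is not PipelinePhase.NULL]

for phase in execution_order:
    print(str(phase))  # __str__ returns the value, e.g. "copy_batch_to_device"

# __eq__ compares by value, so looking a phase up by its string round-trips:
assert PipelinePhase("embedding_lookup") == PipelinePhase.EMBEDDING_LOOKUP
```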
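
The INPUT_SPLITS_DIST / INPUT_DATA_DIST comments describe a two-step exchange. Below is a minimal sketch of that pattern with `torch.distributed`, assuming an already-initialized process group whose backend supports `all_to_all_single`; the function and variable names are illustrative, not the pipeline's actual API:

```python
from typing import List

import torch
import torch.distributed as dist


def exchange_jagged_values(
    local_values: torch.Tensor,  # 1-D values this rank holds, ordered by destination rank
    input_splits: List[int],     # how many values go to each rank, len == world_size
) -> torch.Tensor:
    world_size = dist.get_world_size()
    device = local_values.device

    # 1) "splits dist": each rank tells every other rank how many values it will send.
    in_splits = torch.tensor(input_splits, dtype=torch.int64, device=device)
    out_splits = torch.empty(world_size, dtype=torch.int64, device=device)
    dist.all_to_all_single(out_splits, in_splits)

    # the splits must reach the host side before the data exchange can be sized
    output_splits = out_splits.tolist()

    # 2) "data dist": exchange the actual jagged values using the now-known splits.
    output = torch.empty(sum(output_splits), dtype=local_values.dtype, device=device)
    dist.all_to_all_single(output, local_values, output_splits, input_splits)
    return output
```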