diff --git a/torchrec/distributed/train_pipeline/types.py b/torchrec/distributed/train_pipeline/types.py
index 1c37eb88b..3cbaeca7f 100644
--- a/torchrec/distributed/train_pipeline/types.py
+++ b/torchrec/distributed/train_pipeline/types.py
@@ -96,3 +96,72 @@ class PipelineState(Enum):
     IDLE = 0
     CALL_FWD = 1
     CALL_BWD = 2
+
+    def __str__(self) -> str:
+        return self.name
+
+
+@unique
+class PipelinePhase(Enum):
+    """
+    Pipeline phases for the train pipeline.
+
+    Please:
+    1. Keep the phases in the execution order of the base pipeline.
+    2. Add notes to explain the phases where needed.
+
+    """
+
+    def __str__(self) -> str:
+        return self.value
+
+    def __eq__(self, obj: "PipelinePhase") -> bool:
+        return self.value == obj.value
+
+    # placeholder for an empty/unset phase
+    NULL = "null"
+
+    # the data is usually first available on CPU when loaded from the dataloader;
+    # the input batch needs to be moved/copied to the device when training on GPU
+    COPY_BATCH_TO_DEVICE = "copy_batch_to_device"
+
+    # input post-processing is needed for the sparse data dist pipeline, where the sparse
+    # features are traced (built) from the ModelInput via fx tracing
+    INPUT_POST_PROC = "input_post_proc"
+
+    # the sparse features (aka KJTs) are in a jagged format, so their data sizes are unknown
+    # to other ranks; a comms step is needed to exchange the size info, i.e., the splits
+    INPUT_SPLITS_DIST = "input_splits_dist"
+
+    # once a rank knows the data sizes from the other ranks (via splits dist), it can launch
+    # an all-to-all comms to exchange the actual data of the sparse features
+    # NOTE: the splits have to be available on the host side
+    INPUT_DATA_DIST = "input_data_dist"
+
+    # embedding lookup is done in FBGEMM TBE on each rank
+    EMBEDDING_LOOKUP = "embedding_lookup"
+
+    # the embedding lookup results (i.e., the embeddings) are needed on every rank; this is
+    # usually done via the output dist with an all_to_all comms
+    EMBEDDING_OUTPUT_DIST = "embedding_output_dist"
+
+    # a typical DLRM model arch contains a sparse arch and a dense arch; here we treat the model
+    # excluding the "sparse modules" as the dense part, which also includes the dense-sharded embedding tables
+    DENSE_FORWARD = "dense_forward"
+
+    # the model's backward usually uses torch.autograd; the embedding modules' backward is handled by TBE
+    DENSE_BACKWARD = "dense_backward"
+
+    # on each rank, after the dense arch's backward, the gradients w.r.t. the embeddings are available;
+    # a backward of the "embedding output dist" is needed to gather the embedding gradients from all
+    # ranks to the rank that hosts the embedding table
+    EMBEDDING_GRAD_DIST = "embedding_grad_dist"
+
+    # TBE backward usually updates the embedding table weights in place
+    EMBEDDING_BACKWARD = "embedding_backward"
+
+    # the embedding update is decoupled from the backward in case the weight update is not fused with it
+    EMBEDDING_UPDATE = "embedding_update"
+
+    # the optimizer step usually only covers the dense module weights
+    DENSE_OPTIMIZER_STEP = "dense_optimizer_step"
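
For context, a minimal usage sketch of the new PipelinePhase enum (not part of the diff): the run_phase driver below is hypothetical and only illustrates the string-based __str__/__eq__ behavior and the declaration-order iteration that the docstring relies on.

from torchrec.distributed.train_pipeline.types import PipelinePhase


def run_phase(phase: PipelinePhase) -> None:
    # hypothetical per-phase hook of a pipeline driver; a real pipeline would dispatch
    # to the corresponding comms/compute stage instead of printing
    print(f"running phase: {phase}")  # __str__ returns the value, e.g. "input_data_dist"


# members are declared in the execution order of the base pipeline,
# so plain Enum iteration walks the phases in that order
for phase in PipelinePhase:
    if phase == PipelinePhase.NULL:  # __eq__ compares the underlying string values
        continue
    run_phase(phase)

One caveat the sketch works around: because the class overrides __eq__ without also defining __hash__, Python leaves the members unhashable, so the example sticks to equality checks rather than using phases as dict or set keys.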