diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index fc26703d..68865ddb 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -206,6 +206,11 @@ def apply_tp(
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    # Not for land in the current state, need to align on best way to expose this
+    # for various AC options. For now just hack it in here to get a clean
+    # measurement.
+    torch.ops.aten.abs.default,
+    torch.ops.aten.max.default,
 }
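
For context, the set being extended here is the save-list that drives selective activation checkpointing (SAC): outputs of ops in the list are saved during the forward pass, while everything else is recomputed in backward. Below is a minimal sketch, not the torchtitan code itself, of how such a save-list plugs into PyTorch's selective checkpointing API (available in torch >= 2.4); the names _save_list, _policy, and mlp are illustrative.

    from functools import partial

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import (
        CheckpointPolicy,
        checkpoint,
        create_selective_checkpoint_contexts,
    )

    # Ops whose outputs are saved in forward instead of recomputed in backward.
    # abs/max are the ops this diff adds to torchtitan's list.
    _save_list = {
        torch.ops.aten.mm.default,
        torch.ops.aten.abs.default,
        torch.ops.aten.max.default,
    }

    def _policy(ctx, op, *args, **kwargs):
        # Save outputs of listed ops; prefer recomputing everything else.
        if op in _save_list:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
    x = torch.randn(8, 64, requires_grad=True)

    # context_fn must build fresh contexts for each checkpointed region,
    # hence the partial rather than calling the function once up front.
    out = checkpoint(
        mlp,
        x,
        use_reentrant=False,
        context_fn=partial(create_selective_checkpoint_contexts, _policy),
    )
    out.sum().backward()

Saving abs/max outputs presumably matters here because those ops show up in the amax computation of scaled (e.g. float8) matmul paths, so recomputing them in backward would distort the measurement the comment mentions; that rationale is an inference from the surrounding diff, not stated in the patch itself.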