From 348450ee536fdbb0ad24351a6f410013fc0ea080 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Tue, 1 Aug 2023 04:55:43 +0000
Subject: [PATCH 1/6] Enable 2D parameter 1D activation

---
 examples/pytorch/language-modeling/run_clm.py | 45 +++++++++++++++++++
 .../models/llama/modeling_llama.py            | 14 ++++++
 src/transformers/trainer.py                   |  2 +-
 3 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 681c5836d611..709830bd3657 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -172,6 +172,14 @@ class ModelArguments:
             )
         },
     )
+    spmd_2d_sharding: int = field(
+        default=0,
+        metadata={
+            "help": (
+                "Will apply XLA SPMD to 2D sharding, i.e., weights + activations, and spmd_2d_sharding specifies the model dimension"
+            )
+        },
+    )
 
     def __post_init__(self):
         if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -274,6 +282,7 @@ def main():
     training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding
     training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding
     training_args.spmd_tensor_sharding = model_args.spmd_tensor_sharding
+    training_args.spmd_2d_sharding = model_args.spmd_2d_sharding
 
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -513,6 +522,42 @@ def main():
             else:
                 assert len(param.shape) == 2
                 xs.mark_sharding(param, mesh, range(len(param.shape)))
+    elif model_args.spmd_2d_sharding > 0:
+        print('Applying 2D sharding to all parameters')
+        for name, param in model.named_parameters():
+            # Apply 2D sharding:
+            # embedding (model, data)
+            # attn QKV (data, model)
+            # attn O (model, data)
+            # mlp gate, up (data, model)
+            # mlp down (model, data)
+            print('> Sharding tensor', name, param.shape)
+            model = model_args.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.Mesh(device_ids, (data, model))
+            model_data_mesh = xs.Mesh(device_ids, (model, data))
+
+            # We don't care about layernorm's weights, and
+            # LLaMA doesn't use biases.
+            if len(param.shape) == 1:
+                continue
+
+            if 'embed_tokens' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'q_proj' in name or 'k_proj' in name or 'v_proj' in name:
+                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
+            elif 'o_proj' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'gate_proj' in name or 'up_proj' in name:
+                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
+            elif 'down_proj' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'lm_head' in name: # Not sure what this is but has the same shape as embed_tokens
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+
+            import torch_xla
+            print(torch_xla._XLAC._get_xla_sharding_spec(param))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5ba76d714761..1f6cc831970d 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -370,6 +370,20 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
+        # Apply 2D sharding:
+        # activation (data,, None, model)
+        # import torch_xla.core.xla_model as xm
+        # import torch_xla.experimental.xla_sharding as xs
+        # import torch_xla.runtime as xr
+        # num_devices = xr.global_runtime_device_count()
+        # device_ids = torch.arange(num_devices)
+        # print('> Sharding activations', attn_output.shape)
+        # model = 2
+        # data = num_devices // model
+        # assert model * data == num_devices
+        # data_model_mesh = xs.Mesh(device_ids, (data, model))
+        # xs.mark_sharding(attn_output, data_model_mesh, (0, None, 1))
+
         return attn_output, attn_weights, past_key_value
 
 
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8184df6df64c..8514b712b5a3 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1460,7 +1460,7 @@ def _xla_sharded_dataloader(self, dataloader):
             if self.args.spmd_batch_sharding:
                 mesh = xs.Mesh(device_ids, (num_devices, 1))
                 sharding_spec = xs.ShardingSpec(mesh, (0, 1))
-            elif self.args.spmd_tensor_sharding > 0:
+            elif self.args.spmd_tensor_sharding > 0 or model_args.spmd_2d_sharding > 0:
                 tensor = self.args.spmd_tensor_sharding
                 fsdp = num_devices // tensor
                 mesh = xs.Mesh(device_ids, (fsdp, tensor))

From ebe4e2e578250173556a782189e26ba3549880cb Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Tue, 1 Aug 2023 05:09:40 +0000
Subject: [PATCH 2/6] fix

---
 examples/pytorch/language-modeling/run_clm.py | 10 +++++-----
 src/transformers/trainer.py                   |  5 +++--
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 709830bd3657..a0fdd8285fd7 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -532,11 +532,11 @@ def main():
             # mlp gate, up (data, model)
             # mlp down (model, data)
             print('> Sharding tensor', name, param.shape)
-            model = model_args.spmd_2d_sharding
-            data = num_devices // model
-            assert model * data == num_devices
-            data_model_mesh = xs.Mesh(device_ids, (data, model))
-            model_data_mesh = xs.Mesh(device_ids, (model, data))
+            mod = model_args.spmd_2d_sharding
+            data = num_devices // mod
+            assert mod * data == num_devices
+            data_model_mesh = xs.Mesh(device_ids, (data, mod))
+            model_data_mesh = xs.Mesh(device_ids, (mod, data))
 
             # We don't care about layernorm's weights, and
             # LLaMA doesn't use biases.
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8514b712b5a3..0e20b1f3446e 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1460,8 +1460,9 @@ def _xla_sharded_dataloader(self, dataloader):
             if self.args.spmd_batch_sharding:
                 mesh = xs.Mesh(device_ids, (num_devices, 1))
                 sharding_spec = xs.ShardingSpec(mesh, (0, 1))
-            elif self.args.spmd_tensor_sharding > 0 or model_args.spmd_2d_sharding > 0:
-                tensor = self.args.spmd_tensor_sharding
+            elif self.args.spmd_tensor_sharding > 0 or self.args.spmd_2d_sharding > 0:
+                assert self.args.spmd_tensor_sharding == 0 or self.args.spmd_2d_sharding == 0
+                tensor = self.args.spmd_tensor_sharding + self.args.spmd_2d_sharding
                 fsdp = num_devices // tensor
                 mesh = xs.Mesh(device_ids, (fsdp, tensor))
                 partition_spec = (0, None)

From 1cf52fef7361fcdbabd535eb2088de2f9af53768 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Tue, 1 Aug 2023 19:19:19 +0000
Subject: [PATCH 3/6] shard activations and inputs

---
 .../models/llama/modeling_llama.py | 40 ++++++++++++++-----
 1 file changed, 29 insertions(+), 11 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 1f6cc831970d..5a7d6de13a79 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -372,17 +372,19 @@ def forward(
 
         # Apply 2D sharding:
         # activation (data,, None, model)
-        # import torch_xla.core.xla_model as xm
-        # import torch_xla.experimental.xla_sharding as xs
-        # import torch_xla.runtime as xr
-        # num_devices = xr.global_runtime_device_count()
-        # device_ids = torch.arange(num_devices)
-        # print('> Sharding activations', attn_output.shape)
-        # model = 2
-        # data = num_devices // model
-        # assert model * data == num_devices
-        # data_model_mesh = xs.Mesh(device_ids, (data, model))
-        # xs.mark_sharding(attn_output, data_model_mesh, (0, None, 1))
+        import torch_xla.core.xla_model as xm
+        import torch_xla.experimental.xla_sharding as xs
+        import torch_xla.runtime as xr
+        import torch_xla
+        num_devices = xr.global_runtime_device_count()
+        device_ids = torch.arange(num_devices)
+        print('> Sharding activations', attn_output.shape)
+        model = 2
+        data = num_devices // model
+        assert model * data == num_devices
+        data_model_mesh = xs.Mesh(device_ids, (data, 1, model))
+        xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
+        print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
 
         return attn_output, attn_weights, past_key_value
 
@@ -672,7 +674,23 @@ def forward(
             attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
         )
 
+        # Is this the input to the model?
         hidden_states = inputs_embeds
+        # Apply 2D sharding:
+        # input (data,, None, model)
+        import torch_xla.core.xla_model as xm
+        import torch_xla.experimental.xla_sharding as xs
+        import torch_xla.runtime as xr
+        import torch_xla
+        num_devices = xr.global_runtime_device_count()
+        device_ids = torch.arange(num_devices)
+        print('> Sharding hidden_states', hidden_states.shape)
+        model = 2
+        data = num_devices // model
+        assert model * data == num_devices
+        data_model_mesh = xs.Mesh(device_ids, (data, 1, model))
+        xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
+        print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
 
         if self.gradient_checkpointing and self.training:
             if use_cache:

From c90ead67ae3815cf73bb3f2b06e722d6b107e5a3 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Tue, 1 Aug 2023 19:31:33 +0000
Subject: [PATCH 4/6] Pass the 2d sharding config to the actual model

---
 examples/pytorch/language-modeling/run_clm.py   | 2 ++
 src/transformers/models/llama/modeling_llama.py | 9 +++++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index a0fdd8285fd7..466bfbaeeeab 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -453,6 +453,8 @@ def main():
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
 
+    # Pass the 2d sharding config to the actual model.
+    config.spmd_2d_sharding = model_args.spmd_2d_sharding
     if model_args.model_name_or_path:
         torch_dtype = (
             model_args.torch_dtype
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5a7d6de13a79..73e57ee90476 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -240,6 +240,8 @@ class LlamaAttention(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.config = config
+        # For PyTorch/XLA's SPMD 2D sharding
+        self.spmd_2d_sharding = config.spmd_2d_sharding
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -379,7 +381,7 @@ def forward(
         num_devices = xr.global_runtime_device_count()
         device_ids = torch.arange(num_devices)
         print('> Sharding activations', attn_output.shape)
-        model = 2
+        model = self.spmd_2d_sharding
         data = num_devices // model
         assert model * data == num_devices
         data_model_mesh = xs.Mesh(device_ids, (data, 1, model))
@@ -575,6 +577,9 @@ class LlamaModel(LlamaPreTrainedModel):
 
     def __init__(self, config: LlamaConfig):
         super().__init__(config)
+        # For PyTorch/XLA's SPMD 2D sharding
+        self.spmd_2d_sharding = config.spmd_2d_sharding
+
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
@@ -685,7 +690,7 @@ def forward(
         num_devices = xr.global_runtime_device_count()
         device_ids = torch.arange(num_devices)
         print('> Sharding hidden_states', hidden_states.shape)
-        model = 2
+        model = self.spmd_2d_sharding
         data = num_devices // model
         assert model * data == num_devices
         data_model_mesh = xs.Mesh(device_ids, (data, 1, model))

From 70a798f85ccaa29c39982c8c8e1b48e9c0d3add6 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Tue, 1 Aug 2023 22:04:17 +0000
Subject: [PATCH 5/6] Use hybrid mesh and fix gate, up, down

---
 examples/pytorch/language-modeling/run_clm.py   | 8 ++++----
 src/transformers/models/llama/modeling_llama.py | 4 ++--
 src/transformers/trainer.py                     | 2 +-
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 466bfbaeeeab..6a90e624ba86 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -537,8 +537,8 @@ def main():
             mod = model_args.spmd_2d_sharding
             data = num_devices // mod
             assert mod * data == num_devices
-            data_model_mesh = xs.Mesh(device_ids, (data, mod))
-            model_data_mesh = xs.Mesh(device_ids, (mod, data))
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, mod))
+            model_data_mesh = xs.HybridMesh(ici_mesh_shape=(mod, data))
 
             # We don't care about layernorm's weights, and
             # LLaMA doesn't use biases.
@@ -552,9 +552,9 @@ def main():
             elif 'o_proj' in name:
                 xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
             elif 'gate_proj' in name or 'up_proj' in name:
-                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
-            elif 'down_proj' in name:
                 xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'down_proj' in name:
+                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
             elif 'lm_head' in name: # Not sure what this is but has the same shape as embed_tokens
                 xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 73e57ee90476..ca01a0e3e8bc 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -384,7 +384,7 @@ def forward(
         model = self.spmd_2d_sharding
         data = num_devices // model
         assert model * data == num_devices
-        data_model_mesh = xs.Mesh(device_ids, (data, 1, model))
+        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
         xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
         print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
 
@@ -693,7 +693,7 @@ def forward(
         model = self.spmd_2d_sharding
         data = num_devices // model
         assert model * data == num_devices
-        data_model_mesh = xs.Mesh(device_ids, (data, 1, model))
+        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
         xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
         print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
 
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0e20b1f3446e..cd4c08d5b0f0 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1464,7 +1464,7 @@ def _xla_sharded_dataloader(self, dataloader):
                 assert self.args.spmd_tensor_sharding == 0 or self.args.spmd_2d_sharding == 0
                 tensor = self.args.spmd_tensor_sharding + self.args.spmd_2d_sharding
                 fsdp = num_devices // tensor
-                mesh = xs.Mesh(device_ids, (fsdp, tensor))
+                mesh = xs.HybridMesh(ici_mesh_shape=(fsdp, tensor))
                 partition_spec = (0, None)
                 sharding_spec = xs.ShardingSpec(mesh, partition_spec)
 

From 012ae0c2674ea417bbfe9a7fe1196662041f2ba1 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Tue, 1 Aug 2023 22:27:33 +0000
Subject: [PATCH 6/6] Fix comments

---
 examples/pytorch/language-modeling/run_clm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 6a90e624ba86..299756a9e1ff 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -531,8 +531,8 @@ def main():
             # embedding (model, data)
             # attn QKV (data, model)
             # attn O (model, data)
-            # mlp gate, up (data, model)
-            # mlp down (model, data)
+            # mlp gate, up (model, data)
+            # mlp down (data, model)
             print('> Sharding tensor', name, param.shape)
             mod = model_args.spmd_2d_sharding
             data = num_devices // mod
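
For reference, the parameter and activation sharding these patches wire into run_clm.py and modeling_llama.py reduces to a small torch_xla recipe. The sketch below is illustrative only: it assumes an SPMD-enabled PyTorch/XLA runtime of the same vintage as these patches, reusing only the calls that appear above (xs.Mesh, xs.mark_sharding, xr.global_runtime_device_count); the tensor shapes, the stand-in weight, and the model dimension of 2 are example values, not taken from the patches' test setup.

# Illustrative sketch; shapes, names, and the model dimension are assumptions.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

num_devices = xr.global_runtime_device_count()
device_ids = torch.arange(num_devices)

model = 2                     # corresponds to --spmd_2d_sharding 2
data = num_devices // model   # remaining devices form the data axis
assert model * data == num_devices

# 2D meshes for the weights, mirroring run_clm.py above.
data_model_mesh = xs.Mesh(device_ids, (data, model))
model_data_mesh = xs.Mesh(device_ids, (model, data))

# A stand-in q_proj-like weight, sharded (data, model) like the attention QKV weights.
q_proj_weight = torch.zeros(4096, 4096, device=xm.xla_device())
xs.mark_sharding(q_proj_weight, data_model_mesh, (0, 1))

# A stand-in (batch, seq, hidden) activation, sharded (data, None, model) via a
# (data, 1, model) mesh, matching the attn_output/hidden_states marking above.
activation = torch.zeros(8, 2048, 4096, device=xm.xla_device())
activation_mesh = xs.Mesh(device_ids, (data, 1, model))
xs.mark_sharding(activation, activation_mesh, (0, 1, 2))

On a TPU pod slice the same meshes can instead be built with xs.HybridMesh(ici_mesh_shape=...), as patch 5 does, to account for the device topology when laying out the two axes.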