examples/pytorch/language-modeling/run_clm.py: 47 additions & 0 deletions
@@ -172,6 +172,14 @@ class ModelArguments:
)
},
)
spmd_2d_sharding: int = field(
default=0,
metadata={
"help": (
"Will apply XLA SPMD to 2D sharding, i.e., weights + activations, and spmd_2d_sharding specifies the model dimension"
)
},
)
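
For illustration (not part of the diff): the flag gives the size of the model axis of a 2D device mesh, and the data axis is whatever remains. A minimal sketch of the arithmetic, assuming 8 XLA devices and --spmd_2d_sharding 2:

# Assumed setup: 8 XLA devices, --spmd_2d_sharding 2 passed on the command line.
num_devices = 8
spmd_2d_sharding = 2                            # model-axis size
data = num_devices // spmd_2d_sharding          # -> 4, the data/batch axis size
assert data * spmd_2d_sharding == num_devices   # the flag must divide the device count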

def __post_init__(self):
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -274,6 +282,7 @@ def main():
training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding
training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding
training_args.spmd_tensor_sharding = model_args.spmd_tensor_sharding
training_args.spmd_2d_sharding = model_args.spmd_2d_sharding

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -444,6 +453,8 @@ def main():
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)

# Pass the 2d sharding config to the actual model.
config.spmd_2d_sharding = model_args.spmd_2d_sharding
if model_args.model_name_or_path:
torch_dtype = (
model_args.torch_dtype
@@ -513,6 +524,42 @@ def main():
else:
assert len(param.shape) == 2
xs.mark_sharding(param, mesh, range(len(param.shape)))
elif model_args.spmd_2d_sharding > 0:
print('Applying 2D sharding to all parameters')
for name, param in model.named_parameters():
# Apply 2D sharding:
# embedding (model, data)
# attn QKV (data, model)
# attn O (model, data)
# mlp gate, up (model, data)
# mlp down (data, model)
print('> Sharding tensor', name, param.shape)
mod = model_args.spmd_2d_sharding
data = num_devices // mod
assert mod * data == num_devices
data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, mod))
model_data_mesh = xs.HybridMesh(ici_mesh_shape=(mod, data))

# We don't care about layernorm's weights, and
# LLaMA doesn't use biases.
if len(param.shape) == 1:
continue

if 'embed_tokens' in name:
xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
elif 'q_proj' in name or 'k_proj' in name or 'v_proj' in name:
xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
elif 'o_proj' in name:
xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
elif 'gate_proj' in name or 'up_proj' in name:
xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
elif 'down_proj' in name:
xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
elif 'lm_head' in name: # output projection; same (vocab_size, hidden_size) shape as embed_tokens, so shard it the same way
xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))

import torch_xla
print(torch_xla._XLAC._get_xla_sharding_spec(param))

# Preprocessing the datasets.
# First we tokenize all the texts.
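For reference, a condensed sketch of the same parameter-sharding scheme with the two meshes built once instead of per parameter; apply_2d_sharding is a hypothetical helper, while the xs/xr calls are the ones already used in the diff above.

import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

def apply_2d_sharding(model, model_dim):
    # Build the (data, model) and (model, data) meshes once.
    num_devices = xr.global_runtime_device_count()
    data_dim = num_devices // model_dim
    assert data_dim * model_dim == num_devices
    data_model = xs.HybridMesh(ici_mesh_shape=(data_dim, model_dim))
    model_data = xs.HybridMesh(ici_mesh_shape=(model_dim, data_dim))
    # Same name-to-mesh mapping as the loop above.
    model_data_keys = ('embed_tokens', 'o_proj', 'gate_proj', 'up_proj', 'lm_head')
    data_model_keys = ('q_proj', 'k_proj', 'v_proj', 'down_proj')
    for name, param in model.named_parameters():
        if len(param.shape) != 2:
            continue  # skip 1-D layernorm weights; LLaMA linears have no biases
        if any(key in name for key in model_data_keys):
            xs.mark_sharding(param, model_data, (0, 1))
        elif any(key in name for key in data_model_keys):
            xs.mark_sharding(param, data_model, (0, 1))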
src/transformers/models/llama/modeling_llama.py: 37 additions & 0 deletions
@@ -240,6 +240,8 @@ class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
# For PyTorch/XLA's SPMD 2D sharding
self.spmd_2d_sharding = config.spmd_2d_sharding
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@@ -370,6 +372,22 @@ def forward(
if not output_attentions:
attn_weights = None

# Apply 2D sharding:
# activation (data, None, model)
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr
import torch_xla
num_devices = xr.global_runtime_device_count()
device_ids = torch.arange(num_devices)
print('> Sharding activations', attn_output.shape)
model = self.spmd_2d_sharding
data = num_devices // model
assert model * data == num_devices
data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))

return attn_output, attn_weights, past_key_value
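
For reference, a minimal sketch of the same activation annotation factored into a helper that builds the mesh from the config and skips the default spmd_2d_sharding == 0 case; shard_activation_2d is a hypothetical name, while the xs/xr calls are the ones used above.

import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

def shard_activation_2d(tensor, model_dim):
    # tensor is (batch, seq_len, hidden): shard batch over the data axis and
    # hidden over the model axis, leaving seq_len replicated.
    if model_dim <= 0:
        return tensor  # 2D sharding disabled (the default of 0)
    num_devices = xr.global_runtime_device_count()
    data_dim = num_devices // model_dim
    assert data_dim * model_dim == num_devices
    mesh = xs.HybridMesh(ici_mesh_shape=(data_dim, 1, model_dim))
    xs.mark_sharding(tensor, mesh, (0, 1, 2))
    return tensor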


@@ -559,6 +577,9 @@ class LlamaModel(LlamaPreTrainedModel):

def __init__(self, config: LlamaConfig):
super().__init__(config)
# For PyTorch/XLA's SPMD 2D sharding
self.spmd_2d_sharding = config.spmd_2d_sharding

self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

@@ -658,7 +679,23 @@ def forward(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

# hidden_states starts out as the input embeddings.
hidden_states = inputs_embeds
# Apply 2D sharding:
# input (data, None, model)
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr
import torch_xla
num_devices = xr.global_runtime_device_count()
device_ids = torch.arange(num_devices)
print('> Sharding hidden_states', hidden_states.shape)
model = self.spmd_2d_sharding
data = num_devices // model
assert model * data == num_devices
data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))

if self.gradient_checkpointing and self.training:
if use_cache:
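For intuition, an illustrative calculation under assumed sizes (not taken from the diff): with 8 devices and spmd_2d_sharding = 2, the (data, 1, model) mesh is (4, 1, 2), so the (0, 1, 2) partition spec splits hidden_states along the batch and hidden dimensions only.

# Assumed sizes: 8 devices, model axis 2, activation shape (16, 2048, 4096).
batch, seq_len, hidden = 16, 2048, 4096
data_axis, model_axis = 4, 2                     # mesh (data, 1, model) = (4, 1, 2)
per_device_shard = (batch // data_axis, seq_len, hidden // model_axis)
print(per_device_shard)                          # -> (4, 2048, 2048)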
src/transformers/trainer.py: 4 additions & 3 deletions
@@ -1460,10 +1460,11 @@ def _xla_sharded_dataloader(self, dataloader):
if self.args.spmd_batch_sharding:
mesh = xs.Mesh(device_ids, (num_devices, 1))
sharding_spec = xs.ShardingSpec(mesh, (0, 1))
- elif self.args.spmd_tensor_sharding > 0:
- tensor = self.args.spmd_tensor_sharding
+ elif self.args.spmd_tensor_sharding > 0 or self.args.spmd_2d_sharding > 0:
+ assert self.args.spmd_tensor_sharding == 0 or self.args.spmd_2d_sharding == 0
+ tensor = self.args.spmd_tensor_sharding + self.args.spmd_2d_sharding
fsdp = num_devices // tensor
- mesh = xs.Mesh(device_ids, (fsdp, tensor))
+ mesh = xs.HybridMesh(ici_mesh_shape=(fsdp, tensor))
partition_spec = (0, None)
sharding_spec = xs.ShardingSpec(mesh, partition_spec)

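For context, a hedged sketch of how a ShardingSpec like the one built above is typically attached to the XLA input pipeline; wrap_loader is a hypothetical helper, and input_sharding assumes a torch_xla build whose MpDeviceLoader forwards that argument.

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def wrap_loader(dataloader, sharding_spec):
    # Each host batch is sharded across devices according to sharding_spec
    # as it is transferred to the device.
    return pl.MpDeviceLoader(dataloader, xm.xla_device(), input_sharding=sharding_spec)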