From 191fa67fd8c6bdb7a8ba97f2efd45810b61be322 Mon Sep 17 00:00:00 2001
From: Tal Ben-Nun
Date: Mon, 12 Feb 2024 22:49:28 -0800
Subject: [PATCH] FSDP: Enable limiting scope using trainer grid rows/columns

---
 applications/nlp/transformer/parallelism.py | 18 ++++++++++++++++++
 applications/nlp/transformer/trainer.py     |  4 ++++
 include/lbann/weights/weights.hpp           |  7 +++++++
 python/lbann/core/weights.py                | 18 ++++++++++++++++--
 src/proto/factories/weights_factory.cpp     | 19 +++++++++++++++++++
 src/proto/weights.proto                     |  7 +++++++
 src/weights/data_type_weights.cpp           |  9 +++++----
 src/weights/weights.cpp                     | 11 ++++++++---
 8 files changed, 84 insertions(+), 9 deletions(-)

diff --git a/applications/nlp/transformer/parallelism.py b/applications/nlp/transformer/parallelism.py
index 3c5e90a50c4..7c56faf51eb 100644
--- a/applications/nlp/transformer/parallelism.py
+++ b/applications/nlp/transformer/parallelism.py
@@ -11,6 +11,11 @@
 #############################################################################
 
 
+def _get_sharding_strategy(args: argparse.Namespace) -> lbann.ShardingStrategy:
+    if args.fsdp_ranks > 0:
+        return lbann.ShardingStrategy.GRID_ROWS
+    return lbann.ShardingStrategy.FULL
+
 
 # Fully-sharded data parallelism (MLP only)
 def apply_fsdp_mlp(module: lbann.models.Transformer,
@@ -34,11 +39,14 @@ def apply_fsdp_mlp(module: lbann.models.Transformer,
                         enumerate(module.decoder)):
         for w in submodule.fc1_weights:
             w.sharded = True
+            w.sharding_strategy = _get_sharding_strategy(args)
         for w in submodule.fc2_weights:
             w.sharded = True
+            w.sharding_strategy = _get_sharding_strategy(args)
 
     for w in other_weights:
         w.sharded = True
+        w.sharding_strategy = _get_sharding_strategy(args)
 
 
 # Fully-sharded data parallelism (all weights)
@@ -63,6 +71,7 @@ def apply_fsdp_allweights(model: lbann.Model, args: argparse.Namespace):
         if layer.weights:
             if len(layer.weights) > 0:
                 layer.weights[0].sharded = True
+                layer.weights[0].sharding_strategy = _get_sharding_strategy(args)
 
 
 # Model (FFN tensor) parallelism
@@ -254,6 +263,15 @@ def add_transformer_parallelism_arguments(parser: argparse.Namespace,
         help='Apply Fully-Sharded Data-Parallelism (FSDP) and shard all weights'
     )
 
+    parser.add_argument(
+        '--fsdp-ranks',
+        default=0,
+        type=int,
+        help='Number of consecutive nodes to shard weights in FSDP. This '
+        'setting will modify the LBANN process grid height. (default: 0, shard '
+        'across all ranks)'
+    )
+
     parser.add_argument(
         '--fsdp-mlp',
         action='store_true',
diff --git a/applications/nlp/transformer/trainer.py b/applications/nlp/transformer/trainer.py
index 0a0d040c998..d34a894ecd1 100644
--- a/applications/nlp/transformer/trainer.py
+++ b/applications/nlp/transformer/trainer.py
@@ -215,6 +215,10 @@ def make_batch_script(model: lbann.Model,
         "LBANN_DISABLE_DISTCONV": 1,
     }
 
+    # Set FSDP ranks, if given, by changing the trainer grid height
+    if args.fsdp_ranks > 0:
+        script_params['environment']['LBANN_TRAINER_GRID_HEIGHT'] = args.fsdp_ranks
+
     save_text = args.save_prototext
     filename = 'experiment.prototext' if save_text else 'experiment.protobin'
     # Create Protobuf file
diff --git a/include/lbann/weights/weights.hpp b/include/lbann/weights/weights.hpp
index 745a8233457..9b83dc10e1f 100644
--- a/include/lbann/weights/weights.hpp
+++ b/include/lbann/weights/weights.hpp
@@ -262,6 +262,10 @@ class weights : public Cloneable<HasAbstractFunction<weights>>
   bool is_sharded() const { return m_sharded; }
   /** Set weight sharding configuration. */
   void set_sharded(bool value) { m_sharded = value; }
+  /** Get sharding distribution (VC, MC, MR, or STAR if not sharded). */
+  El::Dist get_sharding_distribution() const { return m_sharding_strategy; }
+  /** Set sharding distribution (VC, MC, MR, or STAR if not sharded). */
+  void set_sharding_distribution(El::Dist dist) { m_sharding_strategy = dist; }
 
   // -----------------------------------------------
   // Freezing
@@ -370,6 +374,9 @@ class weights : public Cloneable<HasAbstractFunction<weights>>
 
   /** Whether weights are sharded across ranks. */
   bool m_sharded;
+
+  /** How weights are sharded across ranks. */
+  El::Dist m_sharding_strategy;
 };
 
 } // namespace lbann
diff --git a/python/lbann/core/weights.py b/python/lbann/core/weights.py
index cdc0816415c..128975fa654 100644
--- a/python/lbann/core/weights.py
+++ b/python/lbann/core/weights.py
@@ -1,6 +1,8 @@
 """Trainable model parameters."""
 import abc
 from lbann import weights_pb2
+from enum import Enum
+from typing import Optional
 import lbann.core.util
 
 class Initializer(abc.ABC):
@@ -22,19 +24,28 @@ def export_proto(self):
 for c in classes:
     globals()[c.__name__] = c
 
+
+class ShardingStrategy(Enum):
+    FULL = 0  # Sharded across all ranks (STAR x VC)
+    GRID_ROWS = 1  # Sharded across the process grid rows (STAR x MC)
+    GRID_COLS = 2  # Sharded across the process grid columns (STAR x MR)
+
+
 class Weights:
     """Trainable parameters for neural network."""
 
     global_count = 0  # Static counter, used for default names
 
-    def __init__(self, initializer=None, optimizer=None, name=None, datatype=None,
-                 sharded=None):
+    def __init__(self, initializer=None, optimizer=None, name=None,
+                 datatype=None, sharded=None,
+                 sharding_strategy: Optional[ShardingStrategy] = None):
         Weights.global_count += 1
         self.name = name if name else 'weights{0}'.format(Weights.global_count)
         self.initializer = initializer
         self.optimizer = optimizer
         self.datatype = datatype
         self.sharded = sharded
+        self.sharding_strategy = sharding_strategy
 
     def export_proto(self):
         """Construct and return a protobuf message."""
@@ -58,4 +69,7 @@ def export_proto(self):
         if self.sharded:
             proto.sharded = self.sharded
 
+        if self.sharding_strategy is not None:
+            proto.sharding_strategy = self.sharding_strategy.value
+
         return proto
diff --git a/src/proto/factories/weights_factory.cpp b/src/proto/factories/weights_factory.cpp
index b16dd46d137..540deb2fcbb 100644
--- a/src/proto/factories/weights_factory.cpp
+++ b/src/proto/factories/weights_factory.cpp
@@ -174,7 +174,26 @@ lbann::proto::construct_weights(lbann_comm* comm,
     w->set_name(name);
   }
 
+  // Set sharding configuration and strategy
   w->set_sharded(proto_weights.sharded());
+  if (proto_weights.sharded()) {
+    El::Dist dist;
+    switch (proto_weights.sharding_strategy()) {
+    case lbann_data::ShardingStrategy::FULL:
+      dist = El::VC;
+      break;
+    case lbann_data::ShardingStrategy::GRID_ROWS:
+      dist = El::MC;
+      break;
+    case lbann_data::ShardingStrategy::GRID_COLS:
+      dist = El::MR;
+      break;
+    default:
+      dist = El::STAR;
+      break;
+    }
+    w->set_sharding_distribution(dist);
+  }
 
   // Set weights initializer and optimizer
   w->set_initializer(std::move(init));
diff --git a/src/proto/weights.proto b/src/proto/weights.proto
index d0a431b207d..8b6cdf53177 100644
--- a/src/proto/weights.proto
+++ b/src/proto/weights.proto
@@ -31,12 +31,19 @@
 import "optimizers.proto";
 package lbann_data;
 
+enum ShardingStrategy {
+  FULL = 0;      // Sharded across all ranks (STAR x VC)
+  GRID_ROWS = 1; // Sharded across the process grid rows (STAR x MC)
+  GRID_COLS = 2; // Sharded across the process grid columns (STAR x MR)
+}
+
 message Weights {
   string name = 1;
   Optimizer optimizer = 2;
   Initializer initializer = 3;
   DataType datatype = 4;
   bool sharded = 5;
+  ShardingStrategy sharding_strategy = 6;
 }
 
 message Initializer {
diff --git a/src/weights/data_type_weights.cpp b/src/weights/data_type_weights.cpp
index de3da85c0e5..0c292f89471 100644
--- a/src/weights/data_type_weights.cpp
+++ b/src/weights/data_type_weights.cpp
@@ -239,16 +239,17 @@ void data_type_weights<TensorDataType>::do_setup_()
   }
 
   // Construct matrix for weight values
-  // If sharded, use STAR_VC distribution (column distributed) or VC_STAR (row
+  // If sharded, use STAR_{VC,MC,MR} distribution or {VC,MC,MR}_STAR (row
   // distributed) if width=1.
+  auto dist = this->get_sharding_distribution();
   auto matrix_dist = this->get_matrix_distribution();
-  bool must_use_vc_star = (this->get_matrix_width() == 1);
+  bool must_use_x_star = (this->get_matrix_width() == 1);
   m_values.reset(AbsDistMatrixType::Instantiate(
     *matrix_dist.grid,
     matrix_dist.root,
-    this->is_sharded() ? (must_use_vc_star ? El::VC : El::STAR)
+    this->is_sharded() ? (must_use_x_star ? dist : El::STAR)
                        : matrix_dist.colDist,
-    this->is_sharded() ? (must_use_vc_star ? El::STAR : El::VC)
+    this->is_sharded() ? (must_use_x_star ? El::STAR : dist)
                        : matrix_dist.rowDist,
     (matrix_dist.blockHeight == 1 && matrix_dist.blockWidth == 1 ? El::ELEMENT
                                                                  : El::BLOCK),
diff --git a/src/weights/weights.cpp b/src/weights/weights.cpp
index 50c97d4487b..f618c74743b 100644
--- a/src/weights/weights.cpp
+++ b/src/weights/weights.cpp
@@ -64,7 +64,11 @@ std::string get_dims_string(const std::vector& matrix_height_dims,
 
 } // namespace
 
-weights::weights() : m_comm(nullptr), m_frozen(false), m_sharded(false)
+weights::weights()
+  : m_comm(nullptr),
+    m_frozen(false),
+    m_sharded(false),
+    m_sharding_strategy(El::STAR)
 {
 
   // Initialize weights name
@@ -84,12 +88,13 @@ weights::weights(lbann_comm& comm) : weights()
 template <typename ArchiveT>
 void weights::serialize(ArchiveT& ar)
 {
-  ar(CEREAL_NVP(m_name), CEREAL_NVP(m_frozen));
+  ar(CEREAL_NVP(m_name), CEREAL_NVP(m_frozen), CEREAL_NVP(m_sharded));
 
   // What about:
   // m_matrix_height_dims
   // m_matrix_width_dims
   // m_matrix_dist
+  // m_sharding_strategy
 }
 
 description weights::get_description() const
@@ -118,7 +123,7 @@
 
   // Sharding state
   if (is_sharded()) {
-    desc.add("Sharded");
+    desc.add("Sharded, distribution", get_sharding_distribution());
   }
 
   // Derived class contribution
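
Reviewer note (not part of the patch): a minimal front-end sketch of how the new fields are intended to fit together. The layer, initializer, and sizes below are illustrative assumptions only.

    import lbann

    # Shard this weight object across the process grid rows (STAR x MC),
    # i.e. the GRID_ROWS strategy introduced in weights.proto.
    # HeNormalInitializer/FullyConnected are just placeholder choices.
    w = lbann.Weights(initializer=lbann.HeNormalInitializer(),
                      sharded=True,
                      sharding_strategy=lbann.ShardingStrategy.GRID_ROWS)

    x = lbann.Input(data_field='samples')
    y = lbann.FullyConnected(x, weights=w, num_neurons=1024, has_bias=False)

With GRID_ROWS (or GRID_COLS), the shard group follows the trainer process grid rows (or columns), so the grid shape must be set at launch; the transformer trainer does this by exporting LBANN_TRAINER_GRID_HEIGHT when --fsdp-ranks is given. With FULL (the default), weights are sharded across all ranks via the VC distribution, matching the previous behavior.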