FSDP: Enable limiting scope using trainer grid rows/columns (#2424)
tbennun authored Feb 14, 2024
1 parent 3d61161 commit 2f691af
Showing 8 changed files with 84 additions and 9 deletions.
18 changes: 18 additions & 0 deletions applications/nlp/transformer/parallelism.py
@@ -11,6 +11,11 @@

#############################################################################

def _get_sharding_strategy(args: argparse.Namespace) -> lbann.ShardingStrategy:
if args.fsdp_ranks > 0:
return lbann.ShardingStrategy.GRID_ROWS
return lbann.ShardingStrategy.FULL


# Fully-sharded data parallelism (MLP only)
def apply_fsdp_mlp(module: lbann.models.Transformer,
@@ -34,11 +39,14 @@ def apply_fsdp_mlp(module: lbann.models.Transformer,
enumerate(module.decoder)):
for w in submodule.fc1_weights:
w.sharded = True
w.sharding_strategy = _get_sharding_strategy(args)
for w in submodule.fc2_weights:
w.sharded = True
w.sharding_strategy = _get_sharding_strategy(args)

for w in other_weights:
w.sharded = True
w.sharding_strategy = _get_sharding_strategy(args)


# Fully-sharded data parallelism (all weights)
@@ -63,6 +71,7 @@ def apply_fsdp_allweights(model: lbann.Model, args: argparse.Namespace):
if layer.weights:
if len(layer.weights) > 0:
layer.weights[0].sharded = True
layer.weights[0].sharding_strategy = _get_sharding_strategy(args)


# Model (FFN tensor) parallelism
@@ -254,6 +263,15 @@ def add_transformer_parallelism_arguments(parser: argparse.Namespace,
help='Apply Fully-Sharded Data-Parallelism (FSDP) and shard all weights'
)

parser.add_argument(
'--fsdp-ranks',
default=0,
type=int,
help='Number of consecutive nodes to shard weights in FSDP. This '
'setting will modify the LBANN process grid height. (default: 0, shard '
'across all ranks)'
)

parser.add_argument(
'--fsdp-mlp',
action='store_true',
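The new `--fsdp-ranks` option above drives `_get_sharding_strategy`: any positive value switches the sharded weights from the default `FULL` strategy to `GRID_ROWS`. A minimal sketch of that selection, assuming the LBANN Python frontend is importable; the parser setup below is illustrative, only the flag name and default come from this diff:

```python
import argparse

from lbann.core.weights import ShardingStrategy  # enum added by this commit

# Illustrative stand-in for the transformer argument parser.
parser = argparse.ArgumentParser()
parser.add_argument('--fsdp-ranks', default=0, type=int)
args = parser.parse_args(['--fsdp-ranks', '4'])

# Mirrors _get_sharding_strategy: limit sharding to the process grid rows
# when --fsdp-ranks is positive, otherwise shard across all ranks.
strategy = (ShardingStrategy.GRID_ROWS if args.fsdp_ranks > 0
            else ShardingStrategy.FULL)
print(strategy)  # ShardingStrategy.GRID_ROWS
```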
4 changes: 4 additions & 0 deletions applications/nlp/transformer/trainer.py
@@ -215,6 +215,10 @@ def make_batch_script(model: lbann.Model,
"LBANN_DISABLE_DISTCONV": 1,
}

# Set FSDP ranks, if given, by changing the trainer grid height
if args.fsdp_ranks > 0:
script_params['environment']['LBANN_TRAINER_GRID_HEIGHT'] = args.fsdp_ranks

save_text = args.save_prototext
filename = 'experiment.prototext' if save_text else 'experiment.protobin'
# Create Protobuf file
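To illustrate what the environment variable does: LBANN arranges a trainer's ranks as a 2D process grid, and `LBANN_TRAINER_GRID_HEIGHT` fixes the grid height, so a `GRID_ROWS`-sharded weight is split over `--fsdp-ranks` consecutive ranks instead of the whole trainer. The numbers below are an assumed example, not taken from the commit:

```python
# Assumed example: a 16-rank trainer launched with --fsdp-ranks 4.
total_ranks = 16
fsdp_ranks = 4

grid_height = fsdp_ranks                 # becomes LBANN_TRAINER_GRID_HEIGHT
grid_width = total_ranks // grid_height  # remaining grid dimension

print(f"process grid: {grid_height} x {grid_width}")
print(f"each GRID_ROWS-sharded weight spans {grid_height} ranks, not {total_ranks}")
```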
7 changes: 7 additions & 0 deletions include/lbann/weights/weights.hpp
@@ -262,6 +262,10 @@ class weights : public Cloneable<HasAbstractFunction<weights>>
bool is_sharded() const { return m_sharded; }
/** Set weight sharding configuration. */
void set_sharded(bool value) { m_sharded = value; }
/** Get sharding distribution (VC, MC, MR, or STAR if not sharded). */
El::Dist get_sharding_distribution() const { return m_sharding_strategy; }
/** Set sharding distribution (VC, MC, MR, or STAR if not sharded). */
void set_sharding_distribution(El::Dist dist) { m_sharding_strategy = dist; }

// -----------------------------------------------
// Freezing
@@ -370,6 +374,9 @@ class weights : public Cloneable<HasAbstractFunction<weights>>

/** Whether weights are sharded across ranks. */
bool m_sharded;

/** How weights are sharded across ranks. */
El::Dist m_sharding_strategy;
};

} // namespace lbann
18 changes: 16 additions & 2 deletions python/lbann/core/weights.py
@@ -1,6 +1,8 @@
"""Trainable model parameters."""
import abc
from lbann import weights_pb2
from enum import Enum
from typing import Optional
import lbann.core.util

class Initializer(abc.ABC):
@@ -22,19 +24,28 @@ def export_proto(self):
for c in classes:
globals()[c.__name__] = c


class ShardingStrategy(Enum):
FULL = 0 # Sharded across all ranks (STAR x VC)
GRID_ROWS = 1 # Sharded across the process grid rows (STAR x MC)
GRID_COLS = 2 # Sharded across the process grid columns (STAR x MR)


class Weights:
"""Trainable parameters for neural network."""

global_count = 0 # Static counter, used for default names

def __init__(self, initializer=None, optimizer=None, name=None, datatype=None,
sharded=None):
def __init__(self, initializer=None, optimizer=None, name=None,
datatype=None, sharded=None,
sharding_strategy: Optional[ShardingStrategy] = None):
Weights.global_count += 1
self.name = name if name else 'weights{0}'.format(Weights.global_count)
self.initializer = initializer
self.optimizer = optimizer
self.datatype = datatype
self.sharded = sharded
self.sharding_strategy = sharding_strategy

def export_proto(self):
"""Construct and return a protobuf message."""
@@ -58,4 +69,7 @@ def export_proto(self):
if self.sharded:
proto.sharded = self.sharded

if self.sharding_strategy is not None:
proto.sharding_strategy = self.sharding_strategy.value

return proto
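A short usage sketch of the extended frontend API; only `sharded` and `sharding_strategy` come from this diff, and the weight name is made up:

```python
from lbann.core.weights import ShardingStrategy, Weights

# Shard this weight across the process grid rows rather than all ranks.
fc_weights = Weights(sharded=True,
                     sharding_strategy=ShardingStrategy.GRID_ROWS,
                     name='fc1_weights')

proto = fc_weights.export_proto()
# Expected: proto.sharded == True and proto.sharding_strategy == 1 (GRID_ROWS)
```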
19 changes: 19 additions & 0 deletions src/proto/factories/weights_factory.cpp
@@ -174,7 +174,26 @@ lbann::proto::construct_weights(lbann_comm* comm,
w->set_name(name);
}

// Set sharding configuration and strategy
w->set_sharded(proto_weights.sharded());
if (proto_weights.sharded()) {
El::Dist dist;
switch (proto_weights.sharding_strategy()) {
case lbann_data::ShardingStrategy::FULL:
dist = El::VC;
break;
case lbann_data::ShardingStrategy::GRID_ROWS:
dist = El::MC;
break;
case lbann_data::ShardingStrategy::GRID_COLS:
dist = El::MR;
break;
default:
dist = El::STAR;
break;
}
w->set_sharding_distribution(dist);
}

// Set weights initializer and optimizer
w->set_initializer(std::move(init));
7 changes: 7 additions & 0 deletions src/proto/weights.proto
@@ -31,12 +31,19 @@ import "optimizers.proto";

package lbann_data;

enum ShardingStrategy {
FULL = 0; // Sharded across all ranks (STAR x VC)
GRID_ROWS = 1; // Sharded across the process grid rows (STAR x MC)
GRID_COLS = 2; // Sharded across the process grid columns (STAR x MR)
}

message Weights {
string name = 1;
Optimizer optimizer = 2;
Initializer initializer = 3;
DataType datatype = 4;
bool sharded = 5;
ShardingStrategy sharding_strategy = 6;
}

message Initializer {
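On the wire, these become two new fields of the `Weights` message. A hedged sketch of a populated message through the generated Python bindings; the weight name is illustrative, and `weights_pb2` must be built from this .proto:

```python
from lbann import weights_pb2

msg = weights_pb2.Weights()
msg.name = 'decoder_fc1_weights'  # illustrative name
msg.sharded = True
msg.sharding_strategy = weights_pb2.ShardingStrategy.GRID_COLS

# Prints the fields set above, e.g. sharded: true / sharding_strategy: GRID_COLS
print(msg)
```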
9 changes: 5 additions & 4 deletions src/weights/data_type_weights.cpp
@@ -239,16 +239,17 @@ void data_type_weights<TensorDataType>::do_setup_()
}

// Construct matrix for weight values
// If sharded, use STAR_VC distribution (column distributed) or VC_STAR (row
// If sharded, use STAR_{VC,MC,MR} distribution or {VC,MC,MR}_STAR (row
// distributed) if width=1.
auto dist = this->get_sharding_distribution();
auto matrix_dist = this->get_matrix_distribution();
bool must_use_vc_star = (this->get_matrix_width() == 1);
bool must_use_x_star = (this->get_matrix_width() == 1);
m_values.reset(AbsDistMatrixType::Instantiate(
*matrix_dist.grid,
matrix_dist.root,
this->is_sharded() ? (must_use_vc_star ? El::VC : El::STAR)
this->is_sharded() ? (must_use_x_star ? dist : El::STAR)
: matrix_dist.colDist,
this->is_sharded() ? (must_use_vc_star ? El::STAR : El::VC)
this->is_sharded() ? (must_use_x_star ? El::STAR : dist)
: matrix_dist.rowDist,
(matrix_dist.blockHeight == 1 && matrix_dist.blockWidth == 1 ? El::ELEMENT
: El::BLOCK),
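The distribution choice in `do_setup_()` reduces to three cases: keep the existing distribution if the weight is not sharded, distribute the single column's rows when the weight is a vector (width = 1), and distribute the columns otherwise. A minimal Python sketch of that decision, with strings standing in for the Elemental `El::Dist` values (illustrative only; the unsharded defaults are placeholders):

```python
def pick_weight_dists(sharded: bool, width: int, shard_dist: str,
                      default_col: str = 'STAR', default_row: str = 'STAR'):
    """Mirror of the (colDist, rowDist) selection in do_setup_()."""
    if not sharded:
        return default_col, default_row
    if width == 1:
        # Vector weight: shard its rows, e.g. (MC, STAR) for GRID_ROWS.
        return shard_dist, 'STAR'
    # Matrix weight: shard its columns, e.g. (STAR, MC) for GRID_ROWS.
    return 'STAR', shard_dist

print(pick_weight_dists(True, 1, 'MC'))   # ('MC', 'STAR')
print(pick_weight_dists(True, 64, 'MC'))  # ('STAR', 'MC')
```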
11 changes: 8 additions & 3 deletions src/weights/weights.cpp
@@ -64,7 +64,11 @@ std::string get_dims_string(const std::vector<size_t>& matrix_height_dims,

} // namespace

weights::weights() : m_comm(nullptr), m_frozen(false), m_sharded(false)
weights::weights()
: m_comm(nullptr),
m_frozen(false),
m_sharded(false),
m_sharding_strategy(El::STAR)
{

// Initialize weights name
@@ -84,12 +88,13 @@ weights::weights(lbann_comm& comm) : weights()
template <typename ArchiveT>
void weights::serialize(ArchiveT& ar)
{
ar(CEREAL_NVP(m_name), CEREAL_NVP(m_frozen));
ar(CEREAL_NVP(m_name), CEREAL_NVP(m_frozen), CEREAL_NVP(m_sharded));

// What about:
// m_matrix_height_dims
// m_matrix_width_dims
// m_matrix_dist
// m_sharding_strategy
}

description weights::get_description() const
@@ -118,7 +123,7 @@ description weights::get_description() const

// Sharding state
if (is_sharded()) {
desc.add("Sharded");
desc.add("Sharded, distribution", get_sharding_distribution());
}

// Derived class contribution
