Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 120 additions & 80 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
AutocastKwargs,
DataLoaderConfiguration,
DeepSpeedPlugin,
DistributedDataParallelKwargs,
DistributedType,
DynamoBackend,
FP8RecipeKwargs,
Expand All @@ -64,13 +63,13 @@
LoggerType,
MegatronLMPlugin,
MSAMPRecipeKwargs,
ParallelismConfig,
PrecisionType,
ProfileKwargs,
ProjectConfiguration,
RNGType,
TERecipeKwargs,
TorchDynamoPlugin,
TorchTensorParallelPlugin,
apply_fp8_autowrap,
check_os_kernel,
clean_state_dict_for_safetensors,
Expand Down Expand Up @@ -206,9 +205,6 @@ class Accelerator:
fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
Tweak your torch tensor parallel. This argument is optional and can be configured directly using
*accelerate config*
megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
directly using *accelerate config*
Expand Down Expand Up @@ -279,7 +275,6 @@ def __init__(
dataloader_config: DataLoaderConfiguration | None = None,
deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
torch_tp_plugin: TorchTensorParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
Expand All @@ -291,6 +286,7 @@ def __init__(
dynamo_backend: DynamoBackend | str | None = None,
dynamo_plugin: TorchDynamoPlugin | None = None,
deepspeed_plugins: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
parallelism_config: ParallelismConfig | None = None,
):
self.trackers = []
if project_config is not None:
Expand Down Expand Up @@ -364,13 +360,6 @@ def __init__(
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")

if isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")

if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")

if fsdp_plugin is None: # init from env variables
fsdp_plugin = (
FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
Expand All @@ -384,9 +373,6 @@ def __init__(
if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")

if torch_tp_plugin is not None and not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")

if megatron_lm_plugin is None: # init from env variables
megatron_lm_plugin = (
MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
Expand All @@ -401,7 +387,6 @@ def __init__(
raise ImportError("Megatron is not installed. please build it from source.")

# Kwargs handlers
self.ddp_handler = None
self.scaler_handler = None
self.init_handler = None
self.fp8_recipe_handler = None
Expand All @@ -414,7 +399,6 @@ def __init__(

found_handlers = set()
handler_class_to_attr = {
DistributedDataParallelKwargs: "ddp_handler",
GradScalerKwargs: "scaler_handler",
InitProcessGroupKwargs: "init_handler",
FP8RecipeKwargs: "fp8_recipe_handler",
Expand All @@ -439,19 +423,30 @@ def __init__(
if "recipe_handler" in handler_attr and not self.has_fp8_handler:
self.has_fp8_handler = True

parallelism_config = parallelism_config or ParallelismConfig()
parallelism_config._init_from_kwargs(kwargs_handlers)

kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
self.state = AcceleratorState(
mixed_precision=mixed_precision,
cpu=cpu,
dynamo_plugin=dynamo_plugin,
deepspeed_plugin=deepspeed_plugins,
fsdp_plugin=fsdp_plugin,
torch_tp_plugin=torch_tp_plugin,
megatron_lm_plugin=megatron_lm_plugin,
parallelism_config=parallelism_config,
_from_accelerator=True,
**kwargs,
)

# Helper flag to check if we are in a composable parallelism setup
# Later we can add DeepSpeed, etc
self._composable_parallelism_enabled = self.is_fsdp2

# This is a bit clunky, as this needs to be called after `AcceleratorState` is initialized, but _init_from_kwargs has to be called before
parallelism_config._validate(self)
self._set_device_mesh()

self.fp8_enabled = self.state.mixed_precision == "fp8" or mixed_precision == "fp8"

# Check for automatic FP8 recipe creation
Expand Down Expand Up @@ -623,10 +618,26 @@ def use_distributed(self):
"""
return self.state.use_distributed

@property
def multi_device(self):
return self.use_distributed and self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_MLU,
DistributedType.MULTI_SDAA,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
)

@property
def distributed_type(self):
return self.state.distributed_type

@property
def parallelism_config(self):
return self.state.parallelism_config

@property
def num_processes(self):
return self.state.num_processes
Expand Down Expand Up @@ -707,6 +718,14 @@ def mixed_precision(self):
def is_fsdp2(self):
return self.state.is_fsdp2

@property
def is_composable_parallelism_enabled(self):
return self.is_fsdp2

@property
def device_mesh(self):
return self.state.device_mesh

@contextmanager
def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
"""
Expand Down Expand Up @@ -1211,15 +1230,7 @@ def join_uneven_inputs(self, joinables, even_batches=None):
... optimizer.zero_grad()
```
"""
if self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
DistributedType.MULTI_MLU,
DistributedType.MULTI_SDAA,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
):
if self.multi_device:
dl_even_batches_values = []

if even_batches is not None:
Expand Down Expand Up @@ -1621,43 +1632,37 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
elif device_placement and not self.verify_device_map(model):
model = model.to(self.device)
if not evaluation_mode:
if self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_MLU,
DistributedType.MULTI_SDAA,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
):
if any(p.requires_grad for p in model.parameters()):
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
# TODO: Look at enabling native TP training directly with a proper config
if os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true":
if self.device.type == "hpu":
device_ids, output_device = [self.device.index], self.device.index
if self.multi_device:
if self.parallelism_config.dp_enabled:
if any(p.requires_grad for p in model.parameters()):
kwargs = (
self.parallelism_config.dp_handler.to_kwargs()
if self.parallelism_config.dp_handler is not None
else {}
)
# TODO: Look at enabling native TP training directly with a proper config
if os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true":
if self.device.type == "hpu":
device_ids, output_device = [self.device.index], self.device.index
else:
device_ids, output_device = [self.local_process_index], self.local_process_index
else:
device_ids, output_device = [self.local_process_index], self.local_process_index
else:
device_ids, output_device = None, None

model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=device_ids, output_device=output_device, **kwargs
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
if not hasattr(model, "tp_size"):
raise NotImplementedError(
"Model should undergo tensor parallel before passing it to accelerate."
"You can use .from_pretrained(..., tp_plan='auto') if the model supports"
)
if model.tp_size != self.state.torch_tp_plugin.tp_size:
raise ValueError(
f"tp_size in the plugin {self.state.torch_tp_plugin.tp_size} should be same as model's tp size {model.tp_size}"
)
device_ids, output_device = None, None
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=device_ids, output_device=output_device, **kwargs
)
if self.parallelism_config.dp_handler is not None:
self.parallelism_config.dp_handler.register_comm_hook(model)
elif self.parallelism_config.tp_enabled:
if not hasattr(model, "tp_size"):
raise NotImplementedError(
"Model should undergo tensor parallel before passing it to accelerate."
"You can use .from_pretrained(..., tp_plan='auto') if the model supports"
)
if model.tp_size != self.parallelism_config.tp_size:
raise ValueError(
f"tp_size in the plugin {self.parallelism_config.tp_size} should be same as model's tp size {model.tp_size}"
)
elif self.is_fsdp2:
raise ValueError(
"FSDP2 preparation should be done via `accelerate.prepare()`, as it requires a model and an optimizer."
Expand Down Expand Up @@ -1791,11 +1796,11 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
del self._models[-2]
self._models[-1] = model
elif self.distributed_type == DistributedType.MULTI_CPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
elif self.distributed_type == DistributedType.MULTI_CPU and self.parallelism_config.dp_enabled:
kwargs = self.parallelism_config.dp_handler.to_kwargs() if self.parallelism_config.dp_handler else {}
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
if self.parallelism_config.dp_handler is not None:
self.parallelism_config.dp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
model = xmp.MpModelWrapper(model).to(self.device)
# Now we can apply the FP8 autocast
Expand All @@ -1809,6 +1814,49 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
return model

def _set_device_mesh(self):
mesh_dims = {}
pc = self.parallelism_config

if pc is not None and pc.tp_enabled:
mesh_dims["tp"] = pc.tp_size
if pc is not None and pc.dp_enabled:
mesh_dims["dp"] = pc.dp_size

if self.is_fsdp2:
mesh_dims["fsdp"] = self.num_processes // (pc.tp_size * pc.dp_size)

# mesh_dims["cp"] = 1
# mesh_dims["pp"] = 1
# mesh_dims["ep"] = 1

if len(mesh_dims) == 0:
self.state.device_mesh = None
return

# Sort mesh_dims by the canonical order: "dp", "fsdp", "pp", "cp", "tp", "ep"
mesh_order = ["pp", "dp", "fsdp", "cp", "tp", "ep"]
sorted_items = sorted(
mesh_dims.items(), key=lambda x: mesh_order.index(x[0]) if x[0] in mesh_order else len(mesh_order)
)
sorted_items = [(name, value) for name, value in sorted_items if value > 1]
mesh_names = tuple(name for name, _ in sorted_items)
mesh_values = tuple(value for _, value in sorted_items)

device_mesh = torch.distributed.init_device_mesh(
self.device.type, mesh_shape=mesh_values, mesh_dim_names=mesh_names
)

# device_mesh[("fsdp", "cp")]._flatten("fsdp_cp")
if all(name in device_mesh.mesh_dim_names for name in ("fsdp", "dp")):
device_mesh[("dp", "fsdp")]._flatten("dp_fsdp")

# ("cp", "dp", "pp" and "ep") will be used for CP, DP, PP and EP respectively
# ("fsdp", "cp") compose a "fsdp_cp" submesh, over which model is going to be sharded
# ("dp", "fsdp") compose a "dp_fsdp" submesh, over which data is going to be replicated
# ("dp", "fsdp_cp") compose a 2D mesh, where model is sharded over "fsdp_cp" and replicated over "dp", resulting in a HSDP support
self.state.device_mesh = device_mesh

def _prepare_ao(self, *args):
if not is_torchao_available():
raise ImportError(
Expand Down Expand Up @@ -2319,11 +2367,10 @@ def _prepare_device_mesh(self):
Prepare the device mesh for distributed training. The dataloader will determine how to load data based on the
device mesh.
"""
if self.state.torch_tp_plugin:
return self.state.torch_tp_plugin.torch_device_mesh
elif self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
if self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
return self.state.ds_device_mesh
return None
else:
return self.state.device_mesh

def _prepare_msamp(self, *args, device_placement):
if not is_msamp_available():
Expand Down Expand Up @@ -3544,14 +3591,7 @@ def _inner(folder):

map_location = load_model_func_kwargs.pop("map_location", None)
if map_location is None:
if self.num_processes > 1 and self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_MLU,
DistributedType.MULTI_SDAA,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_HPU,
):
if self.num_processes > 1 and self.multi_device and self.distributed_type != DistributedType.MULTI_XPU:
map_location = "on_device"
else:
map_location = "cpu"
Expand Down
9 changes: 5 additions & 4 deletions src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DistributedType,
DynamoBackend,
GradientAccumulationPlugin,
ParallelismConfig,
check_cuda_fp8_capability,
check_cuda_p2p_ib_support,
deepspeed_required,
Expand Down Expand Up @@ -866,6 +867,8 @@ class AcceleratorState:
- **device** (`torch.device`) -- The device to use.
- **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
in use.
- **parallelism_config** ([`~accelerate.utils.ParallelismConfig`]) -- The parallelism configuration
for the current training environment. This is used to configure the distributed training environment.
- **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
- **local_process_index** (`int`) -- The index of the current process on the current server.
- **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
Expand Down Expand Up @@ -894,8 +897,8 @@ def __init__(
dynamo_plugin=None,
deepspeed_plugin=None,
fsdp_plugin=None,
torch_tp_plugin=None,
megatron_lm_plugin=None,
parallelism_config: ParallelismConfig | None = None,
_from_accelerator: bool = False,
**kwargs,
):
Expand All @@ -908,8 +911,8 @@ def __init__(
self._check_initialized(mixed_precision, cpu)
if not self.initialized:
self.deepspeed_plugins = None
self.parallelism_config = parallelism_config
self.use_ipex = None
self.torch_tp_plugin = torch_tp_plugin
mixed_precision = (
parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
if mixed_precision is None
Expand Down Expand Up @@ -985,8 +988,6 @@ def __init__(
self.distributed_type = DistributedType.MEGATRON_LM
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
self.megatron_lm_plugin = megatron_lm_plugin
if self.torch_tp_plugin is not None:
self.distributed_type = DistributedType.TP
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
if is_ipex_available():
# check if user disables it explicitly
Expand Down
Loading