Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion finetrainers/models/cogview4/base_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,11 @@ def forward(
latents = posterior.sample(generator=generator)
del posterior

latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
if getattr(self.vae_config, "shift_factor", None) is not None:
latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
else:
latents = latents * self.vae_config.scaling_factor

noise = torch.zeros_like(latents).normal_(generator=generator)
timesteps = (sigmas.flatten() * 1000.0).long()

Expand Down
2 changes: 1 addition & 1 deletion finetrainers/parallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from .accelerate import AccelerateParallelBackend
from .ptd import PytorchDTensorParallelBackend
from .utils import apply_ddp_ptd, apply_fsdp2_ptd, dist_max, dist_mean
from .utils import dist_max, dist_mean


ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend]
Expand Down
174 changes: 169 additions & 5 deletions finetrainers/parallel/accelerate.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import datetime
import os
import pathlib
from typing import Optional
import shutil
import time
from typing import Any, Callable, Dict, Optional

import torch
from diffusers.utils import is_accelerate_available

from ..logging import get_logger
from ..utils import get_device_info
from .base import BaseParallelBackend
from .utils import apply_ddp_accelerate
from .base import BaseCheckpointer, BaseParallelBackend


if not is_accelerate_available():
Expand All @@ -23,6 +25,7 @@
DistributedDataParallelKwargs,
InitProcessGroupKwargs,
ProjectConfiguration,
set_seed,
)


Expand Down Expand Up @@ -68,9 +71,31 @@ def __init__(
if dp_degree != world_size:
raise ValueError("Data parallel degree must be equal to world size.")

self._accelerator: Accelerator = None
self._accelerator = None
if world_size == 1:
# Needs special handling for single GPU training
project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
dataloader_config = DataLoaderConfiguration(
split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
)
init_process_group_kwargs = InitProcessGroupKwargs(
backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
)
self._accelerator = Accelerator(
project_config=project_config,
dataloader_config=dataloader_config,
gradient_accumulation_steps=gradient_accumulation_steps,
log_with=None,
kwargs_handlers=[init_process_group_kwargs],
)
if torch.backends.mps.is_available():
self._accelerator.native_amp = False

self._mesh: torch.distributed.DeviceMesh = None

def enable_determinism(self, seed: int) -> None:
set_seed(seed)

def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
project_config = None
ddp_kwargs = None
Expand All @@ -84,7 +109,7 @@ def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
init_process_group_kwargs = InitProcessGroupKwargs(
backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
)
self._accelerator, model = apply_ddp_accelerate(
self._accelerator, model = apply_ddp(
model,
project_config,
ddp_kwargs,
Expand All @@ -96,6 +121,9 @@ def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
logger.debug("Applied AccelerateParallel::apply_ddp to model.")
return model

def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
return self._accelerator.prepare_model(model)

def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
logger.debug("AccelerateParallelBackend::prepare_dataset completed!")
return dataset
Expand Down Expand Up @@ -161,6 +189,9 @@ def _get_mesh():
self._mesh = mesh
return _get_mesh()

def get_checkpointer(self, *args, **kwargs):
return AccelerateCheckpointer(self._accelerator, *args, **kwargs)

@property
def world_size(self):
return self._accelerator.num_processes
Expand Down Expand Up @@ -191,6 +222,8 @@ def wait_for_everyone(self):
self._accelerator.wait_for_everyone()

def destroy(self):
if self.is_main_process:
self.tracker.finish()
self._accelerator.end_training()

@property
Expand All @@ -216,3 +249,134 @@ def context_parallel_enabled(self):
@property
def tensor_parallel_enabled(self):
return self._tp_degree > 1


class AccelerateCheckpointer(BaseCheckpointer):
def __init__(
self,
accelerator: Accelerator,
states: Dict[str, Any],
checkpointing_steps: int,
checkpointing_limit: int,
output_dir: str,
enable: bool = True,
_callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
_prefix: str = "finetrainers_step",
*args,
**kwargs,
) -> None:
self.accelerator = accelerator
self.states = states

self.checkpointing_steps = checkpointing_steps
self.checkpointing_limit = checkpointing_limit
self.output_dir = pathlib.Path(output_dir)
self.enable = enable
self._callback_fn = _callback_fn
self._prefix = _prefix

def save_model_hook(models, weights, output_dir: str) -> None:
if not self.accelerator.is_main_process:
return

# TODO(aryan): this is a temporary assertion since we only support training transformer at the moment.
# Remove it when adding support for training text encoders/vae and more.
assert len(models) == 1

_callback_fn(weights[0])
torch.save(self.states, os.path.join(output_dir, "states.pt"))

def load_model_hook(models, input_dir) -> None:
self.states = torch.load(os.path.join(input_dir, "states.pt"))

self.accelerator.register_save_state_pre_hook(save_model_hook)
self.accelerator.register_load_state_pre_hook(load_model_hook)

logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")

def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str:
if not self._should_checkpoint(step, force):
return None

checkpoint_dir = self._get_checkpoint_dir(step)
begin_time = time.monotonic()
self.accelerator.save_state(checkpoint_dir.as_posix(), safe_serialization=True)
end_time = time.monotonic()
logger.info(
f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
)
self._purge_stale_checkpoints()

return checkpoint_dir.as_posix()

def load(self, step: int = -1) -> bool:
if not self.enable:
return False
if not self.output_dir.exists():
return False
if step != -1 and not self._get_checkpoint_dir(step).exists():
return False

if step == -1:
latest_checkpoint_dir = self._find_latest_checkpoint_dir()
if latest_checkpoint_dir is None:
return False
step = int(latest_checkpoint_dir.name.split("_")[-1])

checkpoint_dir = self._get_checkpoint_dir(step)
logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")

begin_time = time.monotonic()
self.accelerator.load_state(checkpoint_dir.as_posix())
end_time = time.monotonic()
logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")

return True

def _should_checkpoint(self, step: int, force: bool) -> bool:
if not self.enable:
return False
if not force:
if step % self.checkpointing_steps != 0:
return False
return True

def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
return self.output_dir / f"{self._prefix}_{step}"

def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
return checkpoints[-1] if len(checkpoints) > 0 else None

def _purge_stale_checkpoints(self) -> None:
if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
return
checkpoints = sorted(
self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
)
for checkpoint in checkpoints[self.checkpointing_limit :]:
logger.info(f"Deleting stale checkpoint: {checkpoint}")
shutil.rmtree(checkpoint, ignore_errors=True)


def apply_ddp(
model: torch.nn.Module,
project_config: Optional[ProjectConfiguration] = None,
ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
dataloader_config: Optional[DataLoaderConfiguration] = None,
gradient_accumulation_steps: Optional[int] = None,
accelerator: Optional[Accelerator] = None,
) -> torch.nn.Module:
if accelerator is None:
accelerator = Accelerator(
project_config=project_config,
dataloader_config=dataloader_config,
gradient_accumulation_steps=gradient_accumulation_steps,
log_with=None,
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
)
if torch.backends.mps.is_available():
accelerator.native_amp = False
accelerator.prepare_model(model)
return accelerator, model
44 changes: 43 additions & 1 deletion finetrainers/parallel/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import torch

Expand All @@ -11,9 +11,18 @@ class BaseParallelBackend:
Base class that contains properties and methods that should be implemented by different parallel backends.
"""

def enable_determinism(self, seed: int) -> None:
raise NotImplementedError("Method `enable_determinism` must be implemented by subclass.")

def apply_ddp(self, *args, **kwargs) -> torch.nn.Module:
raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.")

def apply_fsdp2(self, *args, **kwargs) -> torch.nn.Module:
raise NotImplementedError("Method `apply_fsdp2` must be implemented by subclass.")

def prepare_model(self, *args, **kwargs) -> Any:
raise NotImplementedError("Method `prepare_model` must be implemented by subclass.")

def prepare_dataset(self, *args, **kwargs) -> Any:
raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.")

Expand All @@ -26,6 +35,9 @@ def prepare_optimizer(self, *args, **kwargs) -> Any:
def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
raise NotImplementedError("Method `get_mesh` must be implemented by subclass.")

def get_checkpointer(self, *args, **kwargs) -> None:
raise NotImplementedError("Method `get_checkpointer` must be implemented by subclass.")

def initialize_trackers(
self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
) -> TrackerType:
Expand Down Expand Up @@ -94,3 +106,33 @@ def context_parallel_enabled(self):
@property
def tensor_parallel_enabled(self):
raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.")


class BaseCheckpointer:
r"""
Base class that contains properties and methods that should be implemented by different parallel backends.
"""

def __init__(
self,
dataloader: torch.utils.data.DataLoader,
model_parts: List[torch.nn.Module],
optimizers: Any,
schedulers: Any,
states: Dict[str, Any],
checkpointing_steps: int,
checkpointing_limit: int,
output_dir: str,
enable: bool = True,
_callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
_prefix: str = "finetrainers_step",
*args,
**kwargs,
) -> None:
raise NotImplementedError("Method `__init__` must be implemented by subclass.")

def save(self, step: int, force: bool, *, _device: Optional[torch.device] = None, _is_main_process: bool) -> str:
raise NotImplementedError("Method `save` must be implemented by subclass.")

def load(self, step: int = -1) -> bool:
raise NotImplementedError("Method `load` must be implemented by subclass.")
Loading