Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"rich>=14.1.0",
"safetensors>=0.6.2",
"tokenizers>=0.21.2",
"transformers>=4.56.1,<5",
"transformers>=5.0.0,<=5.3.0",
"typer>=0.17.4",
# "wandb>=0.22.0",
"peft",
Expand Down Expand Up @@ -72,7 +72,6 @@ skyrl-train = [
"ninja",
"tensorboard",
"func_timeout",
"transformers>=4.51.0",
"hydra-core==1.3.2",
"accelerate",
"torchdata",
Expand Down Expand Up @@ -217,6 +216,7 @@ override-dependencies = [
"causal-conv1d; sys_platform == 'never'",
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"megatron-core==0.16.1; sys_platform == 'linux'",
"transformers>=5.0.0,<=5.3.0; sys_platform == 'linux'",
"ml_dtypes>=0.5.0; sys_platform == 'linux'",
]

Expand Down
55 changes: 41 additions & 14 deletions skyrl/backends/skyrl_train/distributed/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,13 @@ def init_fn(x: torch.nn.Module):
return x


def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None):
from accelerate import init_empty_weights

def cpu_init_weights():
return torch.device("cpu")

if use_meta_tensor:
if mesh is None:
init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights
else:
init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights
else:
init_context = cpu_init_weights
return init_context
def should_use_meta_init(use_meta_tensor=True, mesh: DeviceMesh = None) -> bool:
"""Return True when this rank should create an empty model on meta device."""
if not use_meta_tensor:
return False
if mesh is None:
return torch.distributed.get_rank() != 0
return mesh.get_coordinate()[-1] != 0


def get_fsdp_wrap_policy(module, config=None, is_lora=False):
Expand Down Expand Up @@ -176,6 +169,14 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):

@torch.no_grad()
def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True):
# Materialize any leftover meta buffers (e.g. non-persistent inv_freq from
# RotaryEmbedding created via from_config on meta device). We must NOT call
# model.to_empty() because that would wipe already-loaded FSDP parameters.
for module in model.modules():
for key in list(module._buffers.keys()):
buf = module._buffers[key]
if buf is not None and buf.device.type == "meta":
module._buffers[key] = torch.empty(buf.shape, dtype=buf.dtype, device="cpu")
model.to("cpu", non_blocking=True)
if empty_cache:
torch.cuda.empty_cache()
Expand Down Expand Up @@ -247,6 +248,27 @@ def get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg):
return nullcontext()


def _sync_non_persistent_buffers(model: torch.nn.Module, loaded_sd: dict):
"""Broadcast non-persistent buffers (e.g. inv_freq) from rank 0 to all ranks.

Non-persistent buffers are excluded from state_dict so they are never loaded
by the parameter broadcast loop. On non-rank-0 meta-init they remain on the
meta device with no data; rank 0 has the correctly computed values.
"""
for module in model.modules():
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
for key in sorted(non_persistent):
buf = module._buffers.get(key)
if buf is None:
continue
if dist.get_rank() == 0:
src = buf.detach().cuda()
else:
src = torch.empty(buf.shape, dtype=buf.dtype, device="cuda")
dist.broadcast(src, src=0)
module._buffers[key] = src.cpu()


# Fsdp2 load full state dict from `accelerate`
# Reference: https://github.com/huggingface/accelerate/blob/0af621bbecc0e43f5d43766a4945d3d2236bb8a9/src/accelerate/utils/fsdp_utils.py#L455
# NOTE (sumanthrh): The original code from `accelerate` assumes init on meta device - with cpu init only on rank 0, but the code is compatible with cpu init on all ranks.
Expand Down Expand Up @@ -324,6 +346,11 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
# we set `assign=True` because our params can be on meta device
model.load_state_dict(sharded_sd, assign=True)

# Broadcast non-persistent buffers (e.g. inv_freq from RotaryEmbedding) that
# are excluded from state_dict. On non-rank-0 meta-init these are still on
# meta device with no data; rank 0 has the correctly computed values.
_sync_non_persistent_buffers(model, sharded_sd)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious do you know why upgrading transformers necessitates this change? Seems a little surprising :)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the claude writeup of why this was needed:

image

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing :)


# If we don't offload FSDP2 Module to CPU and then back to GPU,
# it will occupy a large amount of reserved GPU memory,which can not be released using torch.cuda.empty_cache()
# even if we are using cpu_offload
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(self, args):
async def run_server(self, **uvicorn_kwargs) -> None:
sock_addr = (self.server_args.host or "", self.server_args.port)
sock = create_server_socket(sock_addr)

set_ulimit()

def signal_handler(*_) -> None:
Expand All @@ -39,7 +38,6 @@ def signal_handler(*_) -> None:

signal.signal(signal.SIGTERM, signal_handler)

# TODO(tgriggs): Move this elsewhere, make configurable.
os.environ["VLLM_USE_V1"] = "1"
engine_args = AsyncEngineArgs.from_cli_args(self.server_args)
engine = AsyncLLMEngine.from_engine_args(
Expand Down Expand Up @@ -147,7 +145,10 @@ async def _destroy_weights_update_group(request: Request):

await shutdown_task

sock.close()
try:
sock.close()
except (AttributeError, OSError):
pass

def run_server_uvloop(self, **uvicorn_kwargs) -> None:
uvloop.run(self.run_server(**uvicorn_kwargs))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ async def _send_chunks_legacy(self, chunks: Iterable[WeightChunk]) -> None:
offset = 0
for name, tensor, shape in zip(chunk.names, chunk.tensors, chunk.shapes):
size = tensor.numel()
packed_tensor[offset : offset + size].copy_(tensor.detach().view(-1))
packed_tensor[offset : offset + size].copy_(tensor.detach().reshape(-1))
offset += size
names.append(name)
dtypes.append(self._init_info.model_dtype_str)
Expand Down
130 changes: 63 additions & 67 deletions skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from skyrl.backends.skyrl_train.distributed.fsdp_strategy import FSDPStrategy
from skyrl.backends.skyrl_train.distributed.fsdp_utils import (
fsdp_version,
get_init_weight_context_manager,
should_use_meta_init,
)
from skyrl.backends.skyrl_train.training_batch import (
TrainingInputBatch,
Expand Down Expand Up @@ -165,37 +165,34 @@ def init_model(self, model_path, num_training_steps: int = None):
self._is_lora = self.cfg.policy.model.lora.rank > 0

model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
init_context = get_init_weight_context_manager(
use_meta = should_use_meta_init(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
)
with init_context():

wrapped_model = HFModelWrapper(
model_path,
use_flash_attention_2=self.cfg.flash_attn,
# NOTE (sumanthrh): Model initialization should always be in fp32
# during training
bf16=False,
lora_rank=self.cfg.policy.model.lora.rank,
lora_alpha=self.cfg.policy.model.lora.alpha,
lora_dropout=self.cfg.policy.model.lora.dropout,
lora_init_method=self.cfg.policy.model.lora.init_method,
target_modules=self.cfg.policy.model.lora.target_modules,
exclude_modules=self.cfg.policy.model.lora.exclude_modules,
sequence_parallel_size=self.cfg.policy.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
use_torch_compile=self.cfg.policy.use_torch_compile,
rope_scaling=get_rope_scaling_config(self.cfg),
rope_theta=get_rope_theta_config(self.cfg),
model_config_kwargs=self.cfg.policy.model_config_kwargs,
)
# in-place patch
self._seq_parallel_monkey_patch(model=wrapped_model.model)

if self.cfg.gradient_checkpointing:
wrapped_model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": self.cfg.gradient_checkpointing_use_reentrant}
)
wrapped_model = HFModelWrapper(
model_path,
use_flash_attention_2=self.cfg.flash_attn,
bf16=False,
lora_rank=self.cfg.policy.model.lora.rank,
lora_alpha=self.cfg.policy.model.lora.alpha,
lora_dropout=self.cfg.policy.model.lora.dropout,
lora_init_method=self.cfg.policy.model.lora.init_method,
target_modules=self.cfg.policy.model.lora.target_modules,
exclude_modules=self.cfg.policy.model.lora.exclude_modules,
sequence_parallel_size=self.cfg.policy.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
use_torch_compile=self.cfg.policy.use_torch_compile,
rope_scaling=get_rope_scaling_config(self.cfg),
rope_theta=get_rope_theta_config(self.cfg),
model_config_kwargs=self.cfg.policy.model_config_kwargs,
meta_init=use_meta,
)
self._seq_parallel_monkey_patch(model=wrapped_model.model)

if self.cfg.gradient_checkpointing:
wrapped_model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": self.cfg.gradient_checkpointing_use_reentrant}
)

self.model, self.optimizer, self.scheduler = strategy.prepare(
(wrapped_model, None, None),
Expand Down Expand Up @@ -342,34 +339,33 @@ def init_model(self, model_path, num_training_steps: int = None):
self.strategy = strategy

model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
init_context = get_init_weight_context_manager(
use_meta = should_use_meta_init(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
)
with init_context():
critic = get_llm_for_sequence_regression(
model_path,
"critic",
use_flash_attention_2=self.cfg.flash_attn,
# NOTE (sumanthrh): Model initialization should always be in fp32
# during training
bf16=False,
lora_rank=self.cfg.critic.model.lora.rank,
lora_alpha=self.cfg.critic.model.lora.alpha,
lora_dropout=self.cfg.critic.model.lora.dropout,
target_modules=self.cfg.critic.model.lora.target_modules,
exclude_modules=self.cfg.critic.model.lora.exclude_modules,
value_head_prefix=self.cfg.algorithm.value_head_prefix,
init_value_head=self.cfg.policy.model.path == self.cfg.critic.model.path,
sequence_parallel_size=self.cfg.critic.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
model_config_kwargs=self.cfg.critic.model_config_kwargs,
)
self._seq_parallel_monkey_patch(model=critic, use_parent_class=True)

if self.cfg.gradient_checkpointing:
critic.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": self.cfg.gradient_checkpointing_use_reentrant}
)
critic = get_llm_for_sequence_regression(
model_path,
"critic",
use_flash_attention_2=self.cfg.flash_attn,
bf16=False,
lora_rank=self.cfg.critic.model.lora.rank,
lora_alpha=self.cfg.critic.model.lora.alpha,
lora_dropout=self.cfg.critic.model.lora.dropout,
target_modules=self.cfg.critic.model.lora.target_modules,
exclude_modules=self.cfg.critic.model.lora.exclude_modules,
value_head_prefix=self.cfg.algorithm.value_head_prefix,
init_value_head=self.cfg.policy.model.path == self.cfg.critic.model.path,
sequence_parallel_size=self.cfg.critic.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
model_config_kwargs=self.cfg.critic.model_config_kwargs,
meta_init=use_meta,
)
self._seq_parallel_monkey_patch(model=critic, use_parent_class=True)

if self.cfg.gradient_checkpointing:
critic.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": self.cfg.gradient_checkpointing_use_reentrant}
)

# prepare models/optimizers...
self.model, self.optimizer, self.scheduler = strategy.prepare(
Expand Down Expand Up @@ -412,22 +408,22 @@ def init_model(self, model_path):
self.strategy = strategy

model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
init_context = get_init_weight_context_manager(
use_meta = should_use_meta_init(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
)

with init_context():
wrapped_model = HFModelWrapper(
model_path,
use_flash_attention_2=self.cfg.flash_attn,
bf16=self.cfg.bf16,
sequence_parallel_size=self.cfg.ref.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
rope_scaling=get_rope_scaling_config(self.cfg),
rope_theta=get_rope_theta_config(self.cfg),
model_config_kwargs=self.cfg.ref.model_config_kwargs,
)
self._seq_parallel_monkey_patch(model=wrapped_model.model)
wrapped_model = HFModelWrapper(
model_path,
use_flash_attention_2=self.cfg.flash_attn,
bf16=self.cfg.bf16,
sequence_parallel_size=self.cfg.ref.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
rope_scaling=get_rope_scaling_config(self.cfg),
rope_theta=get_rope_theta_config(self.cfg),
model_config_kwargs=self.cfg.ref.model_config_kwargs,
meta_init=use_meta,
)
self._seq_parallel_monkey_patch(model=wrapped_model.model)

self.model = strategy.prepare(wrapped_model)
self.model.eval()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,14 @@ def init_configs(
if hasattr(provider, "q_lora_rank") and hasattr(hf_config, "q_lora_rank"):
provider.q_lora_rank = hf_config.q_lora_rank

# Workaround for transformers v5 moving rope_theta into rope_parameters

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why this is needed, since megatron-bridge already updated NVIDIA-NeMo/Megatron-Bridge#2068 -- if this is still needed, should we raise an issue against megatron-bridge so we can remove this workaround going forward?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i believe bumping to latest main on megatron-bridge which i'm planning to do here: #1425

should fix this, but i wanted to isolate the transformers v5 update in this PR.

I can look into removing this in #1425

# (previously it was a top-level config attribute). megatron-bridge's
# CONFIG_MAPPING reads config.rope_theta which no longer exists in v5,
# causing it to fall back to the default rotary_base of 10000.
rope_params = getattr(hf_config, "rope_parameters", None) or getattr(hf_config, "rope_scaling", None)
if isinstance(rope_params, dict) and "rope_theta" in rope_params:
provider.rotary_base = rope_params["rope_theta"]

provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size
provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size
provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32
Expand Down
Loading
Loading