Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"rich>=14.1.0",
"safetensors>=0.6.2",
"tokenizers>=0.21.2",
"transformers>=4.56.1,<5",
"transformers>=5.0.0,<=5.3.0",
"typer>=0.17.4",
# "wandb>=0.22.0",
"peft",
Expand Down Expand Up @@ -72,7 +72,6 @@ skyrl-train = [
"ninja",
"tensorboard",
"func_timeout",
"transformers>=4.51.0",
"hydra-core==1.3.2",
"accelerate",
"torchdata",
Expand Down Expand Up @@ -217,6 +216,7 @@ override-dependencies = [
"causal-conv1d; sys_platform == 'never'",
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"megatron-core==0.16.1; sys_platform == 'linux'",
"transformers>=5.0.0,<=5.3.0; sys_platform == 'linux'",
"ml_dtypes>=0.5.0; sys_platform == 'linux'",
]

Expand Down
55 changes: 41 additions & 14 deletions skyrl/backends/skyrl_train/distributed/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,13 @@ def init_fn(x: torch.nn.Module):
return x


def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None):
from accelerate import init_empty_weights

def cpu_init_weights():
return torch.device("cpu")

if use_meta_tensor:
if mesh is None:
init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights
else:
init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights
else:
init_context = cpu_init_weights
return init_context
def should_use_meta_init(use_meta_tensor=True, mesh: "DeviceMesh | None" = None) -> bool:
    """Decide whether this rank should build the model on the meta device.

    One designated rank (global rank 0, or the rank whose last device-mesh
    coordinate is 0 when a mesh is provided) keeps a real model so it can
    load and broadcast weights; every other rank creates an empty model on
    the meta device to save host memory.

    Args:
        use_meta_tensor: When False, meta init is disabled for every rank
            (the caller passes ``not config.tie_word_embeddings`` here, so
            tied-embedding models always materialize weights).
        mesh: Optional device mesh; its last-dimension coordinate selects the
            rank that keeps real weights. Falls back to the global rank when
            absent — torch.distributed must already be initialized in that case.

    Returns:
        True when this rank should create an empty model on the meta device.
    """
    if not use_meta_tensor:
        return False
    if mesh is None:
        return torch.distributed.get_rank() != 0
    return mesh.get_coordinate()[-1] != 0


def get_fsdp_wrap_policy(module, config=None, is_lora=False):
Expand Down Expand Up @@ -176,6 +169,14 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):

@torch.no_grad()
def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True):
    """Move an FSDP2-wrapped model to CPU, materializing stray meta buffers first.

    Non-persistent buffers (e.g. RotaryEmbedding's ``inv_freq`` created via
    ``from_config`` on the meta device) can still live on "meta"; ``.to()``
    cannot move those, so each one is swapped for a freshly allocated CPU
    tensor of the same shape and dtype (contents uninitialized, as before).
    We deliberately avoid ``model.to_empty()``, which would wipe the
    already-loaded FSDP parameters.
    """
    for submodule in model.modules():
        buffers = submodule._buffers
        for name, tensor in list(buffers.items()):
            if tensor is not None and tensor.device.type == "meta":
                buffers[name] = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu")
    model.to("cpu", non_blocking=True)
    if empty_cache:
        torch.cuda.empty_cache()
Expand Down Expand Up @@ -247,6 +248,27 @@ def get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg):
return nullcontext()


def _sync_non_persistent_buffers(model: torch.nn.Module, loaded_sd: dict):
"""Broadcast non-persistent buffers (e.g. inv_freq) from rank 0 to all ranks.

Non-persistent buffers are excluded from state_dict so they are never loaded
by the parameter broadcast loop. On non-rank-0 meta-init they remain on the
meta device with no data; rank 0 has the correctly computed values.
"""
for module in model.modules():
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
for key in sorted(non_persistent):
buf = module._buffers.get(key)
if buf is None:
continue
if dist.get_rank() == 0:
src = buf.detach().cuda()
else:
src = torch.empty(buf.shape, dtype=buf.dtype, device="cuda")
dist.broadcast(src, src=0)
module._buffers[key] = src.cpu()


# Fsdp2 load full state dict from `accelerate`
# Reference: https://github.com/huggingface/accelerate/blob/0af621bbecc0e43f5d43766a4945d3d2236bb8a9/src/accelerate/utils/fsdp_utils.py#L455
# NOTE (sumanthrh): The original code from `accelerate` assumes init on meta device - with cpu init only on rank 0, but the code is compatible with cpu init on all ranks.
Expand Down Expand Up @@ -324,6 +346,11 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
# we set `assign=True` because our params can be on meta device
model.load_state_dict(sharded_sd, assign=True)

# Broadcast non-persistent buffers (e.g. inv_freq from RotaryEmbedding) that
# are excluded from state_dict. On non-rank-0 meta-init these are still on
# meta device with no data; rank 0 has the correctly computed values.
_sync_non_persistent_buffers(model, sharded_sd)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious do you know why upgrading transformers necessitates this change? Seems a little surprising :)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the claude writeup of why this was needed:

image

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing :)


# If we don't offload FSDP2 Module to CPU and then back to GPU,
# it will occupy a large amount of reserved GPU memory, which cannot be released using torch.cuda.empty_cache()
# even if we are using cpu_offload
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(self, args):
async def run_server(self, **uvicorn_kwargs) -> None:
sock_addr = (self.server_args.host or "", self.server_args.port)
sock = create_server_socket(sock_addr)

set_ulimit()

def signal_handler(*_) -> None:
Expand All @@ -39,7 +38,6 @@ def signal_handler(*_) -> None:

signal.signal(signal.SIGTERM, signal_handler)

# TODO(tgriggs): Move this elsewhere, make configurable.
os.environ["VLLM_USE_V1"] = "1"
engine_args = AsyncEngineArgs.from_cli_args(self.server_args)
engine = AsyncLLMEngine.from_engine_args(
Expand Down Expand Up @@ -147,7 +145,10 @@ async def _destroy_weights_update_group(request: Request):

await shutdown_task

sock.close()
try:
sock.close()
except (AttributeError, OSError):
pass

def run_server_uvloop(self, **uvicorn_kwargs) -> None:
uvloop.run(self.run_server(**uvicorn_kwargs))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ async def _send_chunks_legacy(self, chunks: Iterable[WeightChunk]) -> None:
offset = 0
for name, tensor, shape in zip(chunk.names, chunk.tensors, chunk.shapes):
size = tensor.numel()
packed_tensor[offset : offset + size].copy_(tensor.detach().view(-1))
packed_tensor[offset : offset + size].copy_(tensor.detach().reshape(-1))
offset += size
names.append(name)
dtypes.append(self._init_info.model_dtype_str)
Expand Down
130 changes: 63 additions & 67 deletions skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from skyrl.backends.skyrl_train.distributed.fsdp_strategy import FSDPStrategy
from skyrl.backends.skyrl_train.distributed.fsdp_utils import (
fsdp_version,
get_init_weight_context_manager,
should_use_meta_init,
)
from skyrl.backends.skyrl_train.training_batch import (
TrainingInputBatch,
Expand Down Expand Up @@ -165,37 +165,34 @@ def init_model(self, model_path, num_training_steps: int = None):
self._is_lora = self.cfg.policy.model.lora.rank > 0

model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
init_context = get_init_weight_context_manager(
use_meta = should_use_meta_init(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
)
with init_context():

wrapped_model = HFModelWrapper(
model_path,
use_flash_attention_2=self.cfg.flash_attn,
# NOTE (sumanthrh): Model initialization should always be in fp32
# during training
bf16=False,
lora_rank=self.cfg.policy.model.lora.rank,
lora_alpha=self.cfg.policy.model.lora.alpha,
lora_dropout=self.cfg.policy.model.lora.dropout,
lora_init_method=self.cfg.policy.model.lora.init_method,
target_modules=self.cfg.policy.model.lora.target_modules,
exclude_modules=self.cfg.policy.model.lora.exclude_modules,
sequence_parallel_size=self.cfg.policy.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
use_torch_compile=self.cfg.policy.use_torch_compile,
rope_scaling=get_rope_scaling_config(self.cfg),
rope_theta=get_rope_theta_config(self.cfg),
model_config_kwargs=self.cfg.policy.model_config_kwargs,
)
# in-place patch
self._seq_parallel_monkey_patch(model=wrapped_model.model)

if self.cfg.gradient_checkpointing:
wrapped_model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": self.cfg.gradient_checkpointing_use_reentrant}
)
wrapped_model = HFModelWrapper(
model_path,
use_flash_attention_2=self.cfg.flash_attn,
bf16=False,
lora_rank=self.cfg.policy.model.lora.rank,
lora_alpha=self.cfg.policy.model.lora.alpha,
lora_dropout=self.cfg.policy.model.lora.dropout,
lora_init_method=self.cfg.policy.model.lora.init_method,
target_modules=self.cfg.policy.model.lora.target_modules,
exclude_modules=self.cfg.policy.model.lora.exclude_modules,
sequence_parallel_size=self.cfg.policy.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
use_torch_compile=self.cfg.policy.use_torch_compile,
rope_scaling=get_rope_scaling_config(self.cfg),
rope_theta=get_rope_theta_config(self.cfg),
model_config_kwargs=self.cfg.policy.model_config_kwargs,
meta_init=use_meta,
)
self._seq_parallel_monkey_patch(model=wrapped_model.model)

if self.cfg.gradient_checkpointing:
wrapped_model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": self.cfg.gradient_checkpointing_use_reentrant}
)

self.model, self.optimizer, self.scheduler = strategy.prepare(
(wrapped_model, None, None),
Expand Down Expand Up @@ -342,34 +339,33 @@ def init_model(self, model_path, num_training_steps: int = None):
self.strategy = strategy

model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
init_context = get_init_weight_context_manager(
use_meta = should_use_meta_init(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
)
with init_context():
critic = get_llm_for_sequence_regression(
model_path,
"critic",
use_flash_attention_2=self.cfg.flash_attn,
# NOTE (sumanthrh): Model initialization should always be in fp32
# during training
bf16=False,
lora_rank=self.cfg.critic.model.lora.rank,
lora_alpha=self.cfg.critic.model.lora.alpha,
lora_dropout=self.cfg.critic.model.lora.dropout,
target_modules=self.cfg.critic.model.lora.target_modules,
exclude_modules=self.cfg.critic.model.lora.exclude_modules,
value_head_prefix=self.cfg.algorithm.value_head_prefix,
init_value_head=self.cfg.policy.model.path == self.cfg.critic.model.path,
sequence_parallel_size=self.cfg.critic.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
model_config_kwargs=self.cfg.critic.model_config_kwargs,
)
self._seq_parallel_monkey_patch(model=critic, use_parent_class=True)

if self.cfg.gradient_checkpointing:
critic.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": self.cfg.gradient_checkpointing_use_reentrant}
)
critic = get_llm_for_sequence_regression(
model_path,
"critic",
use_flash_attention_2=self.cfg.flash_attn,
bf16=False,
lora_rank=self.cfg.critic.model.lora.rank,
lora_alpha=self.cfg.critic.model.lora.alpha,
lora_dropout=self.cfg.critic.model.lora.dropout,
target_modules=self.cfg.critic.model.lora.target_modules,
exclude_modules=self.cfg.critic.model.lora.exclude_modules,
value_head_prefix=self.cfg.algorithm.value_head_prefix,
init_value_head=self.cfg.policy.model.path == self.cfg.critic.model.path,
sequence_parallel_size=self.cfg.critic.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
model_config_kwargs=self.cfg.critic.model_config_kwargs,
meta_init=use_meta,
)
self._seq_parallel_monkey_patch(model=critic, use_parent_class=True)

if self.cfg.gradient_checkpointing:
critic.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": self.cfg.gradient_checkpointing_use_reentrant}
)

# prepare models/optimizers...
self.model, self.optimizer, self.scheduler = strategy.prepare(
Expand Down Expand Up @@ -412,22 +408,22 @@ def init_model(self, model_path):
self.strategy = strategy

model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
init_context = get_init_weight_context_manager(
use_meta = should_use_meta_init(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
)

with init_context():
wrapped_model = HFModelWrapper(
model_path,
use_flash_attention_2=self.cfg.flash_attn,
bf16=self.cfg.bf16,
sequence_parallel_size=self.cfg.ref.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
rope_scaling=get_rope_scaling_config(self.cfg),
rope_theta=get_rope_theta_config(self.cfg),
model_config_kwargs=self.cfg.ref.model_config_kwargs,
)
self._seq_parallel_monkey_patch(model=wrapped_model.model)
wrapped_model = HFModelWrapper(
model_path,
use_flash_attention_2=self.cfg.flash_attn,
bf16=self.cfg.bf16,
sequence_parallel_size=self.cfg.ref.sequence_parallel_size,
use_sample_packing=self.cfg.use_sample_packing,
rope_scaling=get_rope_scaling_config(self.cfg),
rope_theta=get_rope_theta_config(self.cfg),
model_config_kwargs=self.cfg.ref.model_config_kwargs,
meta_init=use_meta,
)
self._seq_parallel_monkey_patch(model=wrapped_model.model)

self.model = strategy.prepare(wrapped_model)
self.model.eval()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,14 @@ def init_configs(
if hasattr(provider, "q_lora_rank") and hasattr(hf_config, "q_lora_rank"):
provider.q_lora_rank = hf_config.q_lora_rank

# Workaround for transformers v5 moving rope_theta into rope_parameters
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why this is needed, since megatron-bridge already updated NVIDIA-NeMo/Megatron-Bridge#2068 -- if this is still needed, should we raise an issue against megatron-bridge so we can remove this workaround going forward?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i believe bumping to latest main on megatron-bridge which i'm planning to do here: #1425

should fix this, but i wanted to isolate the transformers v5 update in this PR.

I can look into removing this in #1425

# (previously it was a top-level config attribute). megatron-bridge's
# CONFIG_MAPPING reads config.rope_theta which no longer exists in v5,
# causing it to fall back to the default rotary_base of 10000.
rope_params = getattr(hf_config, "rope_parameters", None) or getattr(hf_config, "rope_scaling", None)
if isinstance(rope_params, dict) and "rope_theta" in rope_params:
provider.rotary_base = rope_params["rope_theta"]

provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size
provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size
provider.pipeline_dtype = torch.bfloat16 if bf16 else torch.float32
Expand Down
Loading
Loading