Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions verl/trainer/fsdp_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from verl.utils.py_functional import convert_to_regular_types
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup
from verl.utils.attention_imports import index_first_axis, pad_input, rearrange, unpad_input
from verl.utils.tracking import Tracking
from verl.utils.ulysses import (
gather_outputs_and_unpad,
Expand All @@ -74,11 +75,6 @@
)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

if is_cuda_available:
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
elif is_npu_available:
from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN"))

Expand Down
33 changes: 33 additions & 0 deletions verl/utils/attention_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from verl.utils.device import is_cuda_available, is_npu_available

# Select the backend that provides the padding helpers re-exported by this
# module (`index_first_axis`, `pad_input`, `rearrange`, `unpad_input`).
#
# `is_npu_available` may be exposed by `verl.utils.device` either as a plain
# bool or as a zero-argument function. A function object is always truthy, so
# testing it directly would send every non-CUDA host down the NPU branch and
# make the final `else` unreachable. Resolve it to a real bool first.
_npu_available = is_npu_available() if callable(is_npu_available) else bool(is_npu_available)

if is_cuda_available:
    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
elif _npu_available:
    try:
        from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input
    except ImportError:
        # Since transformers v4.55.1, index_first_axis, pad_input, and unpad_input
        # have been consolidated into `transformers.modeling_flash_attention_utils`.
        # `rearrange` is not part of that module, so it is taken from einops.
        from einops import rearrange
        from transformers.modeling_flash_attention_utils import _index_first_axis as index_first_axis
        from transformers.modeling_flash_attention_utils import _pad_input as pad_input
        from transformers.modeling_flash_attention_utils import _unpad_input as unpad_input
else:
    # Neither CUDA nor NPU: fail at import time rather than leaving the
    # helper names undefined for the importing modules.
    raise RuntimeError("Unsupported device type")


__all__ = ["index_first_axis", "pad_input", "rearrange", "unpad_input"]
20 changes: 12 additions & 8 deletions verl/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,18 @@
logger = logging.getLogger(__name__)


def is_torch_npu_available() -> bool:
is_cuda_available = torch.cuda.is_available()

def is_npu_available() -> bool:
"""Check the availability of NPU"""
try:
import torch_npu # noqa: F401

return torch.npu.is_available()
if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
return torch.npu.is_available()
return False
Comment on lines +21 to +23

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this modification :)

except ImportError:
return False


is_cuda_available = torch.cuda.is_available()
is_npu_available = is_torch_npu_available()


def get_visible_devices_keyword() -> str:
"""Function that gets visible devices keyword name.
Returns:
Expand Down Expand Up @@ -93,3 +91,9 @@ def set_expandable_segments(enable: bool) -> None:
"""
if is_cuda_available:
torch.cuda.memory._set_allocator_settings(f"expandable_segments:{enable}")


def __getattr__(name):
    """Module-level attribute fallback (PEP 562 hook).

    Args:
        name: Name of the attribute being requested from this module.

    Returns:
        The result of calling ``is_npu_available()`` when ``name`` is
        ``"is_npu_available"``.

    Raises:
        AttributeError: For every other attribute name.
    """
    # NOTE(review): Python consults a module-level ``__getattr__`` only after
    # normal attribute lookup fails. If this module also binds
    # ``is_npu_available`` at top level, this hook is never reached for that
    # name -- confirm the intended design.
    if name != "is_npu_available":
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return is_npu_available()
14 changes: 1 addition & 13 deletions verl/workers/actor/dp_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,10 @@
from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
from verl.utils.attention_imports import index_first_axis, pad_input, rearrange, unpad_input
from verl.workers.actor import BasePPOActor
from verl.workers.config import ActorConfig

if is_cuda_available:
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
elif is_npu_available:
try:
from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input
except ImportError:
# Since transformers v4.55.1, index_first_axis, pad_input, and unpad_input
# have been consolidated into `transformers.modeling_flash_attention_utils`.
from einops import rearrange
from transformers.modeling_flash_attention_utils import _index_first_axis as index_first_axis
from transformers.modeling_flash_attention_utils import _pad_input as pad_input
from transformers.modeling_flash_attention_utils import _unpad_input as unpad_input


__all__ = ["DataParallelPPOActor"]

Expand Down
6 changes: 1 addition & 5 deletions verl/workers/critic/dp_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@
from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
from verl.utils.torch_functional import masked_mean
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
from verl.utils.attention_imports import index_first_axis, pad_input, rearrange, unpad_input
from verl.workers.critic import BasePPOCritic

if is_cuda_available:
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
elif is_npu_available:
from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

Expand Down
6 changes: 1 addition & 5 deletions verl/workers/engine/fsdp/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,9 @@
from verl.utils.py_functional import convert_to_regular_types
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs
from verl.utils.attention_imports import index_first_axis, pad_input, rearrange, unpad_input
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

if is_cuda_available:
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
elif is_npu_available:
from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input

from verl.trainer.config import CheckpointConfig
from verl.workers.config import FSDPEngineConfig, FSDPOptimizerConfig, HFModelConfig

Expand Down