diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index 99184725112..edb3c896d57 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -40,7 +40,15 @@
 if is_cuda_available:
     from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
 elif is_npu_available:
-    from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input
+    try:
+        from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input
+    except ImportError:
+        # Since transformers v4.55.1, index_first_axis, pad_input, and unpad_input
+        # have been consolidated into `transformers.modeling_flash_attention_utils`.
+        from einops import rearrange
+        from transformers.modeling_flash_attention_utils import _index_first_axis as index_first_axis
+        from transformers.modeling_flash_attention_utils import _pad_input as pad_input
+        from transformers.modeling_flash_attention_utils import _unpad_input as unpad_input
 
 
 __all__ = ["DataParallelPPOActor"]