
mlm training fails due to large message size for nested_gather on torch_xla #16005

@miladm

Description


The PyTorch/XLA/TPU HF tests for mlm-bert and mlm-roberta fail as described below. I have tested this extensively on both 2VM and 1VM machines. On both, the test passes as expected with --num_cores 1 and fails with the error below with --num_cores 8.

This error suggests that the mesh_reduce API, reached via evaluate() > evaluation_loop() > _nested_gather() > nested_xla_mesh_reduce(), communicates larger-than-expected tensor payloads: the serialized tensors come to ~906 MiB (950146944 bytes) per rendezvous, while the mesh service's gRPC channel rejects anything over 4 MiB (4194304 bytes).
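
For context, here is a minimal sketch of what that path does, pieced together from the traceback frames below (not the verbatim transformers / torch_xla source): each core torch.save()s its gathered tensors into a BytesIO buffer and exchanges the raw bytes through the mesh service's gRPC rendezvous, so the payload scales with the full logits tensors rather than staying under the 4 MiB channel limit.

# Simplified sketch of the failing path, pieced together from the traceback
# frames below; not the verbatim transformers / torch_xla source.
import io

import torch
import torch_xla.core.xla_model as xm

def mesh_reduce_sketch(tag, data, reduce_fn):
    # torch_xla/core/xla_model.py: each core serializes its full payload
    # (here: the per-step eval logits) ...
    bio = io.BytesIO()
    torch.save(data, bio)
    # ... and exchanges the raw bytes through the gRPC mesh service.
    # Per-core MLM logits of roughly [batch=2, seq_len=512, vocab=28996]
    # float32 are ~113 MiB; across 8 cores that is ~906 MiB, which matches
    # the 950146944 bytes in the error and dwarfs the 4 MiB gRPC cap.
    xdata = xm.rendezvous(tag, bio.getvalue())
    return reduce_fn([torch.load(io.BytesIO(x)) for x in xdata])

# transformers/trainer_pt_utils.py:163 then does, per tensor:
#   xm.mesh_reduce(name, tensors, torch.cat)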

Reference to an older issue which sounds relevant here.

Repro command:

python3 examples/pytorch/xla_spawn.py --num_cores 8 \
  examples/pytorch/language-modeling/run_mlm.py \
  --logging_dir ./tensorboard-metric \
  --cache_dir ./cache_dir \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --do_train \
  --do_eval \
  --overwrite_output_dir \
  --output_dir language-modeling \
  --logging_steps 30 \
  --save_steps 3000 \
  --overwrite_cache \
  --tpu_metrics_debug \
  --model_type=bert \
  --tokenizer=bert-base-cased \
  --num_train_epochs 1 \
  --per_device_train_batch_size 16 \
  --per_device_eval_batch_size 4

Error message:

***** train metrics *****
  epoch                    =        1.0
  train_loss               =      8.969
  train_runtime            = 0:02:58.03
  train_samples            =       4771
  train_samples_per_second =     26.798
  train_steps_per_second   =      0.213
03/09/2022 03:22:36 - INFO - run_mlm - *** Evaluate ***
[INFO|trainer.py:570] 2022-03-09 03:22:36,278 >> The following columns in the evaluation set  don't have a corresponding argument in `BertForMaskedLM.forward` and have been ignored: special_tokens_mask. If special_tokens_mask are not expected by `BertForMaskedLM.forward`,  you can safely ignore this message.
[INFO|trainer.py:2403] 2022-03-09 03:22:36,281 >> ***** Running Evaluation *****
[INFO|trainer.py:2405] 2022-03-09 03:22:36,281 >>   Num examples = 493
[INFO|trainer.py:2408] 2022-03-09 03:22:36,281 >>   Batch size = 2
Exception in device=TPU:7: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)
Exception in device=TPU:2: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)
Exception in device=TPU:0: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)
Exception in device=TPU:3: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)

...

Each failing core raises the same trace (interleaved across worker processes in the raw log):

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/home/miladmo/transformers/examples/pytorch/language-modeling/run_mlm.py", line 582, in _mp_fn
    main()
  File "/home/miladmo/transformers/examples/pytorch/language-modeling/run_mlm.py", line 545, in main
    metrics = trainer.evaluate()
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2271, in evaluate
    output = eval_loop(
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2460, in evaluation_loop
    logits = self._nested_gather(logits)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2546, in _nested_gather
    tensors = nested_xla_mesh_reduce(tensors, name)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", line 163, in nested_xla_mesh_reduce
    return xm.mesh_reduce(name, tensors, torch.cat)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 974, in mesh_reduce
    xdata = rendezvous(tag, bio.getvalue())
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 926, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)
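
A possible direction for a fix or workaround (only a sketch, not verified against this trainer version): gather the eval tensors on device with xm.all_gather instead of serializing them through the CPU-side mesh-service rendezvous, so the payload never goes over the gRPC channel at all. The helper below is hypothetical and not part of transformers.

import torch
import torch_xla.core.xla_model as xm

def gather_on_device(tensor: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Hypothetical replacement for the mesh_reduce-based gather:
    # xm.all_gather runs as an XLA collective and concatenates the per-core
    # tensors along `dim`, so nothing is serialized through the gRPC mesh
    # service and the 4 MiB message cap does not apply.
    return xm.all_gather(tensor.detach(), dim=dim)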
