
[Ray Train] Fine-tune dolly-v2-7b with Ray Train, PyTorch Lightning and FSDP - Dimension mismatch error running it #50145

Open
anindya-saha opened this issue Jan 30, 2025 · 0 comments
Labels
bug Something that is supposed to be working; but isn't train Ray Train Related Issue triage Needs triage (eg: priority, bug/not-bug, and owning component)


@anindya-saha

What happened + What you expected to happen

I am following the "Fine-tune dolly-v2-7b with Ray Train, PyTorch Lightning and FSDP" tutorial at https://docs.ray.io/en/latest/train/examples/lightning/dolly_lightning_fsdp_finetuning.html

Running the official example fails during the backward pass of the first training step with the following dimension mismatch error:

Dataset execution finished in 0.06 seconds: 100%|██████████| 10.0/10.0 [00:00<00:00, 167 row/s]
- MapBatches(split_text): Tasks: 0; Queued blocks: 0; Resources: 0.0 CPU, 0.0B object store: : 21.6k row [00:00, 381k row/s]
- limit=10: Tasks: 0; Queued blocks: 0; Resources: 0.0 CPU, 1.0KB object store: 100%|██████████| 10.0/10.0 [00:00<00:00, 175 row/s]
/usr/local/lib/python3.10/site-packages/lightning_fabric/strategies/fsdp.py:700: `FSDPStrategy(activation_checkpointing=[<class 'transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer'>])` is deprecated, use `FSDPStrategy(activation_checkpointing_policy={<class 'transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer'>})` instead.
(RayTrainWorker pid=3398) Setting up process group for: env:// [rank=0, world_size=4]
(TorchTrainer pid=3280) Started distributed worker processes: 
(TorchTrainer pid=3280) - (node_id=5283a4e22de757e31d16a4fb1d7e69a38ad8de9989e8333dc32a4f2e, ip=192.168.144.2, pid=3398) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=3280) - (node_id=5283a4e22de757e31d16a4fb1d7e69a38ad8de9989e8333dc32a4f2e, ip=192.168.144.2, pid=3399) world_rank=1, local_rank=1, node_rank=0
(TorchTrainer pid=3280) - (node_id=5283a4e22de757e31d16a4fb1d7e69a38ad8de9989e8333dc32a4f2e, ip=192.168.144.2, pid=3401) world_rank=2, local_rank=2, node_rank=0
(TorchTrainer pid=3280) - (node_id=5283a4e22de757e31d16a4fb1d7e69a38ad8de9989e8333dc32a4f2e, ip=192.168.144.2, pid=3400) world_rank=3, local_rank=3, node_rank=0
(RayTrainWorker pid=3398) GPU available: True (cuda), used: True
(RayTrainWorker pid=3398) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=3398) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=3398) You are using a CUDA device ('NVIDIA A10G') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
(RayTrainWorker pid=3398) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
(RayTrainWorker pid=3398) 
(RayTrainWorker pid=3398)   | Name  | Type               | Params | Mode
(RayTrainWorker pid=3398) ----------------------------------------------------
(RayTrainWorker pid=3398) 0 | model | GPTNeoXForCausalLM | 1.7 B  | eval
(RayTrainWorker pid=3398) ----------------------------------------------------
(RayTrainWorker pid=3398) 1.7 B     Trainable params
(RayTrainWorker pid=3398) 0         Non-trainable params
(RayTrainWorker pid=3398) 1.7 B     Total params
(RayTrainWorker pid=3398) 6,856.057 Total estimated model params size (MB)
(RayTrainWorker pid=3398) 64        Modules in train mode
(RayTrainWorker pid=3398) 456       Modules in eval mode
(SplitCoordinator pid=4240) Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-01-30_23-05-40_705157_7/logs/ray-data
(SplitCoordinator pid=4240) Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(split_text)->MapBatches(tokenize)] -> OutputSplitter[split(4, equal=True)]
(pid=4240) Running 0: 0.00 row [00:00, ? row/s]
[{'text': 'Before we proceed any further, hear me speak.'}, {'text': 'Speak, speak.'}, {'text': 'You are all resolved rather to die than to famish?'}, {'text': 'Resolved. resolved.'}, {'text': 'First, you know Caius Marcius is chief enemy to the people.'}, {'text': "We know't, we know't."}, {'text': "Let us kill him, and we'll have corn at our own price."}, {'text': "Is't a verdict?"}, {'text': "No more talking on't; let it be done: away, away!"}, {'text': 'One word, good citizens.'}]
MapBatches(tokenize)
+- MapBatches(split_text)
   +- Dataset(num_rows=1, schema={text: string})

View detailed results here: /mnt/ray_experiments/ray_results/finetune_dolly-v2-7b
To visualize your results with TensorBoard, run: `tensorboard --logdir /tmp/ray/session_2025-01-30_23-05-40_705157_7/artifacts/2025-01-30_23-05-47/finetune_dolly-v2-7b/driver_artifacts`

Training started with configuration:
╭────────────────────────────────────────────────────────────────╮
│ Training config                                                │
├────────────────────────────────────────────────────────────────┤
│ train_loop_config/batch_size_per_worker                     10 │
│ train_loop_config/eps                                    1e-08 │
│ train_loop_config/lr                                     2e-05 │
│ train_loop_config/strategy                ...t 0x7fbe86ec9fc0> │
╰────────────────────────────────────────────────────────────────╯
(RayTrainWorker pid=3398) FullyShardedDataParallel(
(RayTrainWorker pid=3398)   (_fsdp_wrapped_module): DollyV2Model(
(RayTrainWorker pid=3398)     (model): GPTNeoXForCausalLM(
(RayTrainWorker pid=3398)       (gpt_neox): GPTNeoXModel(
(RayTrainWorker pid=3398)         (embed_in): Embedding(50280, 4096)
(RayTrainWorker pid=3398)         (emb_dropout): Dropout(p=0.0, inplace=False)
(RayTrainWorker pid=3398)         (layers): ModuleList(
(RayTrainWorker pid=3398)           (0-31): 32 x FullyShardedDataParallel(
(RayTrainWorker pid=3398)             (_fsdp_wrapped_module): CheckpointWrapper(
(RayTrainWorker pid=3398)               (_checkpoint_wrapped_module): GPTNeoXLayer(
(RayTrainWorker pid=3398)                 (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(RayTrainWorker pid=3398)                 (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(RayTrainWorker pid=3398)                 (post_attention_dropout): Dropout(p=0.0, inplace=False)
(RayTrainWorker pid=3398)                 (post_mlp_dropout): Dropout(p=0.0, inplace=False)
(RayTrainWorker pid=3398)                 (attention): GPTNeoXSdpaAttention(
(RayTrainWorker pid=3398)                   (rotary_emb): GPTNeoXRotaryEmbedding()
(RayTrainWorker pid=3398)                   (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
(RayTrainWorker pid=3398)                   (dense): Linear(in_features=4096, out_features=4096, bias=True)
(RayTrainWorker pid=3398)                   (attention_dropout): Dropout(p=0.0, inplace=False)
(RayTrainWorker pid=3398)                 )
(RayTrainWorker pid=3398)                 (mlp): GPTNeoXMLP(
(RayTrainWorker pid=3398)                   (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
(RayTrainWorker pid=3398)                   (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
(RayTrainWorker pid=3398)                   (act): GELUActivation()
(RayTrainWorker pid=3398)                 )
(RayTrainWorker pid=3398)               )
(RayTrainWorker pid=3398)             )
(RayTrainWorker pid=3398)           )
(RayTrainWorker pid=3398)         )
(RayTrainWorker pid=3398)         (final_layer_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(RayTrainWorker pid=3398)         (rotary_emb): GPTNeoXRotaryEmbedding()
(RayTrainWorker pid=3398)       )
(RayTrainWorker pid=3398)       (embed_out): Linear(in_features=4096, out_features=50280, bias=False)
(RayTrainWorker pid=3398)     )
(RayTrainWorker pid=3398)   )
(RayTrainWorker pid=3398) )
(pid=4240) ✔️  Dataset execution finished in 9.63 seconds: 100%|██████████| 21.6k/21.6k [00:09<00:00, 2.25k row/s]
(pid=4240) - MapBatches(split_text)->MapBatches(tokenize): Tasks: 0; Queued blocks: 0; Resources: 0.0 CPU, 0.0B object store: : 21.6k row [00:09, 2.25k row/s]
(pid=4240) - split(4, equal=True): Tasks: 0; Queued blocks: 0; Resources: 0.0 CPU, 126.7MB object store; [locality disabled]: 100%|██████████| 21.6k/21.6k [00:09<00:00, 2.25k row/s]
2025-01-30 23:06:33,470 ERROR tune_controller.py:1331 -- Trial task failed for trial TorchTrainer_bda8c_00000                                         
Traceback (most recent call last):                                                                                                                                                   
  File "/usr/local/lib/python3.10/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/usr/local/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/ray/_private/worker.py", line 2745, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/site-packages/ray/_private/worker.py", line 901, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::_Inner.train() (pid=3280, ip=192.168.144.2, actor_id=bdf864e50095bc0813d8511001000000, repr=TorchTrainer)
  File "/usr/local/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "/usr/local/lib/python3.10/site-packages/ray/train/_internal/utils.py", line 57, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(RuntimeError): ray::_RayTrainWorker__execute.get_next() (pid=3401, ip=192.168.144.2, actor_id=a0dc5ee4d6d01e161a447dd701000000, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x7fd48693c040>)
  File "/usr/local/lib/python3.10/site-packages/ray/train/_internal/worker_group.py", line 33, in __execute
    raise skipped from exception_cause(skipped)
  File "/usr/local/lib/python3.10/site-packages/ray/train/_internal/utils.py", line 176, in discard_return_wrapper
    train_func(*args, **kwargs)
  File "/workspace/production/fine_tune_dolly/pipeline.py", line 157, in train_func
    pl_trainer.fit(model, train_dataloaders=train_dataloader)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1025, in _run_stage
    self.fit_loop.run()
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 250, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 190, in run
    self._optimizer_step(batch_idx, closure)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 167, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1306, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/fsdp.py", line 150, in optimizer_step
    closure_result = closure()
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 138, in closure
    self._backward_fn(step_output.closure_loss)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 239, in backward_fn
    call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 212, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py", line 72, in backward
    model.backward(tensor, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1101, in backward
    loss.backward(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1125, in unpack_hook
    frame.recompute_fn(*args)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1519, in recompute_fn
    fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 704, in forward
    attention_layer_outputs = self.attention(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 487, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (512) must match the existing size (256) at non-singleton dimension 3.  Target sizes: [10, 32, 256, 512].  Tensor sizes: [10, 1, 256, 256]
2025-01-30 23:06:33,477 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/mnt/ray_experiments/ray_results/finetune_dolly-v2-7b' in 0.0032s.
2025-01-30 23:06:33,479 ERROR tune.py:1037 -- Trials did not complete: [TorchTrainer_bda8c_00000]
Epoch 0: |          | 0/? [00:00<?, ?it/s] 

Training errored after 0 iterations at 2025-01-30 23:06:33. Total running time: 46s
Error file: /tmp/ray/session_2025-01-30_23-05-40_705157_7/artifacts/2025-01-30_23-05-47/finetune_dolly-v2-7b/driver_artifacts/TorchTrainer_bda8c_00000_0_2025-01-30_23-05-47/error.txt


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/production/fine_tune_dolly/pipeline.py", line 190, in <module>
    result = trainer.fit()
  File "/usr/local/lib/python3.10/site-packages/ray/train/base_trainer.py", line 638, in fit
    raise TrainingFailedError(
ray.train.base_trainer.TrainingFailedError: The Ray Train run failed. Please inspect the previous error messages for a cause. After fixing the issue (assuming that the error is not caused by your own application logic, but rather an error such as OOM), you can restart the run from scratch or continue this run.
To continue this run, you can use: `trainer = TorchTrainer.restore("/mnt/ray_experiments/ray_results/finetune_dolly-v2-7b")`.
To start a new run that will retry on training failures, set `train.RunConfig(failure_config=train.FailureConfig(max_failures))` in the Trainer's `run_config` with `max_failures > 0`, or `max_failures = -1` for unlimited retries.
(RayTrainWorker pid=3400) LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3] [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
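The final RuntimeError is a broadcasting failure inside scaled_dot_product_attention: an attention mask of shape [10, 1, 256, 256] cannot be expanded to the score shape [10, 32, 256, 512]. Reading the dimensions as [batch, heads, query length, key length] (an interpretation consistent with batch_size_per_worker=10 and max_length=256 in the script), only the key length disagrees. The sketch below reimplements the expand rule in plain Python for equal-rank shapes; `can_expand` is a hypothetical helper, not a PyTorch API.

```python
# Plain-Python sketch of the rule Tensor.expand applies when broadcasting a
# tensor to a target shape. Hypothetical helper; equal-rank shapes only.
def can_expand(existing, target):
    """Return (ok, failing_dim) for expanding shape `existing` to `target`."""
    if len(existing) != len(target):
        raise ValueError("this sketch only handles equal-rank shapes")
    for dim, (have, want) in enumerate(zip(existing, target)):
        # A size-1 (singleton) dimension can be stretched; any other size must match.
        if have != 1 and have != want:
            return False, dim
    return True, None


mask_shape = [10, 1, 256, 256]     # "Tensor sizes" from the error
scores_shape = [10, 32, 256, 512]  # "Target sizes" from the error

print(can_expand(mask_shape, scores_shape))        # fails at dimension 3
print(can_expand(mask_shape, [10, 32, 256, 256]))  # would broadcast fine
```

Only size-1 dimensions can be stretched, so the same mask would broadcast if the key length were 256; the open question is why the keys reached 512 positions.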

Versions / Dependencies

# all common requirements
--extra-index-url https://download.pytorch.org/whl/cu121
datasets==3.1.0
diffusers==0.31.0
evaluate==0.4.3
kubernetes==31.0.0
matplotlib
mlflow==2.18.0
numpy
omegaconf
pandas==2.2.3
psutil
pynvml
pytorch-lightning==2.4.0
ray[default,tune,rllib]==2.38.0
scikit-learn
sentencepiece
structlog
torch==2.5.0+cu121
torchmetrics==1.6.0
torchvision==0.20.0+cu121
transformers==4.46.3
deepspeed==0.16.0
accelerate==1.1.1

Reproduction script

Below is my script from the tutorial:

# https://docs.ray.io/en/latest/train/examples/lightning/dolly_lightning_fsdp_finetuning.html
import json
import re

import pandas as pd
import pytorch_lightning as pl
import torch
import transformers
from datasets import concatenate_datasets, load_dataset
from deepspeed.ops.adam import DeepSpeedCPUAdam
from transformers import AutoModelForCausalLM, AutoTokenizer

import ray

NUM_WORKERS = 4
CPUS_PER_WORKER = 10

MODEL_NAME = "databricks/dolly-v2-7b"
STORAGE_PATH = "/mnt/ray_experiments/ray_results"
EXPERIMENT_NAME = "fine-tune-vicuna-13b-deepspeed"


def split_text(batch: pd.DataFrame) -> pd.DataFrame:
    text = list(batch["text"])
    flat_text = "".join(text)
    split_text = [
        x.strip()
        for x in flat_text.split("\n")
        if x.strip() and not x.strip()[-1] == ":"
    ]
    return pd.DataFrame(split_text, columns=["text"])


def tokenize(batch: pd.DataFrame) -> dict:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    ret = tokenizer(
        list(batch["text"]),
        truncation=True,
        max_length=256,
        padding="max_length",
        return_tensors="np",
    )
    ret["labels"] = ret["input_ids"].copy()
    return dict(ret)


ray.init()

# load the dataset from huggingface
hf_dataset = load_dataset("tiny_shakespeare", trust_remote_code=True)

# convert it into ray dataset
train_ds = ray.data.from_huggingface(hf_dataset["train"])

# First split the dataset into multiple sentences.
train_ds = train_ds.map_batches(split_text, batch_format="pandas")
print(train_ds.take(10))

# Then tokenize the dataset.
train_ds = train_ds.map_batches(tokenize, batch_format="pandas")
print(train_ds)


# Define the Lightning Model
class DollyV2Model(pl.LightningModule):
    def __init__(self, lr=2e-5, eps=1e-8):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.eps = eps
        self.model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    def forward(self, batch):
        outputs = self.model(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        loss = self.forward(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=True)
        return loss

    def configure_optimizers(self):
        if self.global_rank == 0:
            print(self.trainer.model)
        return torch.optim.AdamW(
            self.trainer.model.parameters(), lr=self.lr, eps=self.eps
        )


import functools

import pytorch_lightning as pl
import ray.train
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.lightning import (
    RayFSDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)
from ray.train.torch import TorchTrainer
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer

# Define the model sharding policy:
# Wrap every GPTNeoXLayer as its own FSDP instance
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={GPTNeoXLayer}
)

fsdp_strategy = RayFSDPStrategy(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    forward_prefetch=True,
    auto_wrap_policy=auto_wrap_policy,
    limit_all_gathers=True,
    activation_checkpointing_policy={GPTNeoXLayer},
)


def train_func(config):
    """Training function for each worker"""

    # unpack the `train_loop_config`
    lr = config["lr"]
    eps = config["eps"]
    strategy = config["strategy"]
    batch_size_per_worker = config["batch_size_per_worker"]
    # accumulate_grad_batches = config["accumulate_grad_batches"]

    # Model
    model = DollyV2Model(lr=lr, eps=eps)

    # prepare ray data ingestion
    train_ds = ray.train.get_dataset_shard("train")
    train_dataloader = train_ds.iter_torch_batches(batch_size=batch_size_per_worker)

    # Lightning Trainer
    pl_trainer = pl.Trainer(
        max_epochs=1,
        devices="auto",
        accelerator="auto",
        precision="16-mixed",
        strategy=strategy,
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_checkpointing=False,
    )
    pl_trainer = prepare_trainer(pl_trainer)

    pl_trainer.fit(model, train_dataloaders=train_dataloader)


# Keep only the most recent Ray Train checkpoint (this example has no validation set)
run_config = RunConfig(
    name="finetune_dolly-v2-7b",
    storage_path=STORAGE_PATH,
    checkpoint_config=CheckpointConfig(num_to_keep=1),
)

# Scale the FSDP training workload across NUM_WORKERS GPUs
# You can change this config based on your compute resources.
scaling_config = ScalingConfig(
    num_workers=NUM_WORKERS, use_gpu=True, trainer_resources={"memory": 100 * 1024**3}
)

# Configuration to pass into train_func
train_config = {
    "lr": 2e-5,
    "eps": 1e-8,
    "strategy": fsdp_strategy,
    "batch_size_per_worker": 10,
}

# Define a TorchTrainer and launch your training workload
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=train_config,
    run_config=run_config,
    scaling_config=scaling_config,
    datasets={"train": train_ds},
)

result = trainer.fit()

print(result)

ray.shutdown()
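One hypothesis consistent with the error shapes, not confirmed in this report: the key length of 512 is exactly twice max_length=256, and the traceback shows the failure inside recompute_fn in torch/utils/checkpoint.py, i.e. while the activation checkpointing configured on RayFSDPStrategy re-runs a GPTNeoXLayer forward during backward. If the model keeps a key/value cache during training (DollyV2Model.forward does not pass `use_cache`, so whatever default the model config carries applies), the recomputed layer could append its 256 keys to the 256 already cached by the original forward pass, while the precomputed mask still covers only 256 keys. The toy sketch below tracks only sequence lengths; `ToyKeyCache`, `toy_layer_forward` and `simulate` are hypothetical names and do not model the real transformers cache classes.

```python
# Toy, plain-Python simulation of a per-layer key cache that grows on every
# forward call. Hypothetical: tracks sequence lengths only, not tensors.
class ToyKeyCache:
    def __init__(self):
        self.key_len = {}  # layer index -> number of cached key positions

    def update(self, layer_idx, new_len):
        # Append the new keys to whatever this layer already cached and
        # return the total key length attention would see.
        self.key_len[layer_idx] = self.key_len.get(layer_idx, 0) + new_len
        return self.key_len[layer_idx]


def toy_layer_forward(cache, layer_idx, seq_len, use_cache):
    """Return the key length one attention layer would use for this call."""
    if use_cache:
        return cache.update(layer_idx, seq_len)
    return seq_len


def simulate(use_cache, seq_len=256, layer_idx=0):
    """Run one forward pass, then one checkpoint-style recompute of the same layer."""
    cache = ToyKeyCache()
    forward_len = toy_layer_forward(cache, layer_idx, seq_len, use_cache)
    # Hypothesis: the recompute during backward reuses the same inputs,
    # including the same (already filled) cache object.
    recompute_len = toy_layer_forward(cache, layer_idx, seq_len, use_cache)
    return forward_len, recompute_len


for flag in (True, False):
    fwd, rec = simulate(flag)
    print(f"use_cache={flag}: forward key_len={fwd}, recompute key_len={rec}")
```

If this hypothesis holds, a workaround worth trying would be to pass `use_cache=False` to `self.model(...)` in `DollyV2Model.forward`, or to set `self.model.config.use_cache = False` after loading the model; this has not been verified here.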

Issue Severity

High: It blocks me from completing my task.

@anindya-saha anindya-saha added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Jan 30, 2025
@anindya-saha anindya-saha changed the title [Ray Train] [Ray Train] Fine-tune dolly-v2-7b with Ray Train, PyTorch Lightning and FSDP - Dimension mismatch error running it Jan 30, 2025
@jcotant1 jcotant1 added the train Ray Train Related Issue label Jan 31, 2025