
Fixed recursion error when using both wrapped PEFT and DeepSpeed #1400

Closed
kplau1128 wants to merge 1 commit into huggingface:main from kplau1128:fix_peft_ds_recursion


Conversation

@kplau1128 (Contributor) commented Oct 6, 2024

This PR fixes a recursion error in compute_loss. The crux of the issue lies in circular references introduced by the combined wrappers in Sentence Transformers.

Optimum-Habana example:
optimum-habana/examples/sentence-transformers-training/sts

Command Line:
python ../../gaudi_spawn.py --use_deepspeed --world_size 2 training_stsbenchmark.py --peft
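The failure mode can be reproduced without any of the actual libraries. Below is a minimal sketch with toy classes (not the real Sentence Transformers, PEFT, or DeepSpeed APIs): the loss stores a reference to the model, the combined wrapper routes calls back through the loss, and once the wrapper is installed as loss.model the call chain never bottoms out.

```python
class Loss:
    """Toy stand-in for a Sentence Transformers loss that stores the model."""
    def __init__(self, model):
        self.model = model

    def __call__(self, x):
        return self.model(x)

class Wrapper:
    """Toy stand-in for the combined PEFT/DeepSpeed wrappers, which
    effectively delegate back through the loss."""
    def __init__(self, loss):
        self.loss = loss

    def __call__(self, x):
        return self.loss(x)  # loss -> loss.model (this wrapper) -> loss -> ...

base_model = lambda x: x + 1
loss = Loss(base_model)
loss.model = Wrapper(loss)  # circular reference: loss.model.loss is loss

reproduced = False
try:
    loss(1)
except RecursionError:
    reproduced = True
print("RecursionError reproduced:", reproduced)
```

The cycle here mirrors what the command line above triggers: each wrapper assumes it is the outermost layer, so combining them closes a loop.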

@kplau1128 kplau1128 requested a review from regisss as a code owner October 6, 2024 00:34
@yafshar (Contributor) commented Oct 8, 2024

@kplau1128 thanks for the fix.

  • Please run make style and fix any errors or issues.
  • Please run make test_installs; python -m pytest tests/sentence_transformers/test_training_stsbenchmark.py and make sure there are no new issues.

for name, child in loss.named_children():
    if isinstance(child, torch.nn.Module):
        # Avoid replacing the model again if it's already the desired model
        if not (name == "model" and child is model):
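The traversal above can be exercised without torch. This is a sketch with a toy Module class that mimics named_children() (the class names are illustrative, not the real API): references to the unwrapped model are replaced, the "model" attribute is left alone if it already holds the desired object, and unrelated children stay untouched.

```python
class Module:
    """Toy stand-in for torch.nn.Module with a named_children()-like API."""
    def named_children(self):
        return [(k, v) for k, v in vars(self).items() if isinstance(v, Module)]

class Model(Module):
    pass

class Loss(Module):
    def __init__(self, model):
        self.model = model
        self.projector = Model()  # unrelated child that must not be replaced

base = Model()     # the unwrapped model
wrapped = Model()  # the wrapped (e.g. distributed) model
loss = Loss(base)

# Same logic as the snippet above: replace children that hold the
# unwrapped model, but avoid replacing the model again if it is
# already the desired model.
for name, child in loss.named_children():
    if isinstance(child, Module):
        if not (name == "model" and child is wrapped):
            if child is base:
                setattr(loss, name, wrapped)
```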
Suggested change:
-if not (name == "model" and child is model):
+if name != "model" or child is not model:

Can we simplify the condition? (Personal suggestion)
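The suggested rewrite is just De Morgan's law applied to the original condition; a quick exhaustive check over the four name/identity combinations confirms the two forms agree (the helper names below are illustrative):

```python
def original(name, child, model):
    return not (name == "model" and child is model)

def simplified(name, child, model):
    return name != "model" or child is not model

model, other = object(), object()
cases = [(n, c) for n in ("model", "loss") for c in (model, other)]

# Both forms agree on every combination, so the rewrite is safe.
assert all(original(n, c, model) == simplified(n, c, model) for n, c in cases)
```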

    and model != self.model  # Only if the model is wrapped
    and hasattr(loss_fn, "model")  # Only if the loss stores the model
    and loss_fn.model != model  # Only if the wrapped model is not already stored
    and loss_fn.model != self.model  # Assign the original model instead
@kplau1128 this does not sound correct to me. Here the goal is to insert the wrapped model (distributed or compiled) into the loss function, not the original model.

@yafshar (Contributor) Oct 8, 2024
I see that without this fix the recursion happens, but the workaround is different from the original intention.

Here it does not do anything since loss_fn.model is self.model

@kplau1128 (Author) Oct 9, 2024

@yafshar You are right.

The crux of the issue lies in circular references introduced by the combined wrappers. I will rework the workaround so that it only modifies the override_model_in_loss function to skip overriding the loss when the wrappers are combined.
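A hedged sketch of what that rework could look like. The function name override_model_in_loss comes from the comment above, but the signature and the is_peft_wrapped flag are illustrative assumptions, not the actual optimum-habana API: when the model is PEFT-wrapped and combined with DeepSpeed, the override is skipped entirely so the circular reference is never created.

```python
def override_model_in_loss(loss, model, original_model, is_peft_wrapped):
    """Install the wrapped model into the loss, unless doing so would
    create the circular reference described above.

    NOTE: hypothetical signature for illustration only.
    """
    if is_peft_wrapped:
        # Combined PEFT + DeepSpeed wrappers introduce a cycle
        # (loss -> wrapper -> loss), so leave the loss untouched.
        return loss
    if getattr(loss, "model", None) is original_model:
        loss.model = model
    return loss

class _Loss:
    """Toy loss holding a model reference, for demonstration."""
    def __init__(self, model):
        self.model = model

orig_model, wrapped_model = object(), object()

# PEFT-wrapped: the override is skipped, the original reference survives.
skipped = override_model_in_loss(_Loss(orig_model), wrapped_model,
                                 orig_model, is_peft_wrapped=True)
# Not PEFT-wrapped: the wrapped model is installed as usual.
replaced = override_model_in_loss(_Loss(orig_model), wrapped_model,
                                  orig_model, is_peft_wrapped=False)
```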

@kplau1128 kplau1128 force-pushed the fix_peft_ds_recursion branch from 4359ed9 to 92ce0c9 Compare October 9, 2024 22:59
@kplau1128 (Author) commented Oct 9, 2024

Reworked the workaround to check for a PEFT-wrapped model and skip overriding the model in the loss.
This prevents the recursion error.

Ran make style: no errors.
Ran make test_installs; python -m pytest tests/sentence_transformers/test_training_stsbenchmark.py: passed, no new issues.

- Recursion error in compute_loss. The crux of the issue lies in
  circular references introduced by the combined wrappers.
@kplau1128 kplau1128 force-pushed the fix_peft_ds_recursion branch from 92ce0c9 to 0f8adbf Compare October 10, 2024 19:18
@kplau1128 (Author)

Reworked as @nngokhale recommended: just added a check condition in compute_loss before overriding loss_fn.

@kplau1128 kplau1128 requested a review from yafshar October 11, 2024 01:14
@yafshar (Contributor) commented Oct 11, 2024

@kplau1128 this does not sound correct to me. Let me check it further; I will update here.

@yafshar (Contributor) commented Oct 16, 2024

@kplau1128 I made another PR #1428 which I think addresses the root cause of this issue. Please take a look and close this PR if you agree with that solution.

@kplau1128 (Author) replied:

> @kplau1128 I made another PR #1428 which I think addresses the root cause of this issue. Please take a look and close this PR if you agree with that solution.

Thanks @yafshar. Looks similar to what my previous rework (92ce0c9) tried to do, but your PR is better. I will close this one.

@kplau1128 kplau1128 closed this Oct 16, 2024