Skip to content

Tensor parallel distributed strategy without using deepspeed#1121

Merged
regisss merged 10 commits into
huggingface:mainfrom
kalyanjk:up_tp_strategy
Jul 30, 2024
Merged

Tensor parallel distributed strategy without using deepspeed#1121
regisss merged 10 commits into
huggingface:mainfrom
kalyanjk:up_tp_strategy

Conversation

@kalyanjk
Copy link
Copy Markdown
Contributor

@kalyanjk kalyanjk commented Jul 3, 2024

Tensor parallel by extending GaudiLlamaAttention -> TPGaudiLlamaAttention and GaudiLlamaMLP -> TPGaudiLlamaMLP

use parameter --distributed_strategy="tp" to invoke this code path

code design reference: https://github.com/foundation-model-stack/foundation-model-stack/tree/main

@kalyanjk kalyanjk requested review from libinta and mandy-li as code owners July 3, 2024 14:20
@kalyanjk kalyanjk requested a review from a user July 3, 2024 14:20
@kalyanjk kalyanjk requested a review from regisss as a code owner July 3, 2024 14:20
@libinta libinta added the synapse1.17 PR that should be available along with Synapse 1.17 but have no dependency on Synapse 1.17 content. label Jul 9, 2024
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Copy Markdown
Collaborator

@regisss regisss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left a few comments. Additionally, can you:

Comment thread examples/text-generation/utils.py Outdated
Comment thread examples/text-generation/utils.py Outdated
Comment thread examples/text-generation/utils.py Outdated
Comment thread optimum/habana/distributed/tp_wrapping.py Outdated
Comment thread optimum/habana/distributed/tp.py Outdated
Comment thread optimum/habana/transformers/models/llama/modeling_llama.py Outdated
Comment thread optimum/habana/transformers/models/llama/modeling_llama.py Outdated
Comment thread optimum/habana/transformers/models/llama/modeling_llama.py Outdated
Comment thread optimum/habana/transformers/models/llama/modeling_llama.py Outdated
Comment thread optimum/habana/transformers/models/llama/modeling_llama.py Outdated
@kalyanjk
Copy link
Copy Markdown
Contributor Author

I left a few comments. Additionally, can you:

Done

@kalyanjk kalyanjk force-pushed the up_tp_strategy branch 3 times, most recently from 8ecceca to 2b5f46e Compare July 18, 2024 04:53
@kalyanjk
Copy link
Copy Markdown
Contributor Author

@regisss Addressed all your comments. Please review the changes. Thank you!

@kalyanjk kalyanjk requested a review from regisss July 19, 2024 04:40
Copy link
Copy Markdown
Collaborator

@regisss regisss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update your main branch and merge it into this PR? To have the whole CI working again.

Comment thread tests/test_text_generation_example.py Outdated
Comment thread tests/test_text_generation_example.py Outdated
Comment thread examples/text-generation/README.md Outdated

You will also need to add `--torch_compile` in your command.

### Running with Tesor parallel strategy
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
### Running with Tesor parallel strategy
### Running with tensor-parallel strategy

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment thread examples/text-generation/README.md Outdated
### Running with Tesor parallel strategy
#### Attribution

This repository includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
This repository includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.
> [!NOTE]
> This strategy includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment thread examples/text-generation/README.md Outdated
You will also need to add `--torch_compile` in your command.

### Running with Tesor parallel strategy
#### Attribution
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can remove that line, let's put it in a "box" as suggested below

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added in the box, but not sure if this syntax had to be preserved [!WARNING]

Comment thread examples/text-generation/README.md Outdated

This repository includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.

torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models. To enable
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models. To enable
> [!WARNING]
> torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models.
To enable...

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, Added Note and Warning in box

@libinta
Copy link
Copy Markdown
Collaborator

libinta commented Jul 24, 2024

@kalyanjk can you update based on review comments?

@kalyanjk
Copy link
Copy Markdown
Contributor Author

@kalyanjk can you update based on review comments?

@libinta, I have addressed all. Anything i have missed?

@kalyanjk kalyanjk requested a review from regisss July 26, 2024 05:28
Copy link
Copy Markdown
Collaborator

@regisss regisss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left a few more comments to address.
Also, the test fails on my instance with this error:

Traceback (most recent call last):                                                                                                           
  File "/root/workspace/fork/examples/text-generation/run_generation.py", line 674, in <module>                                              
    main()                                                                                                                                   
  File "/root/workspace/fork/examples/text-generation/run_generation.py", line 317, in main                                                  
    model, assistant_model, tokenizer, generation_config = initialize_model(args, logger)                                                    
  File "/root/workspace/fork/examples/text-generation/utils.py", line 592, in initialize_model                                               
    else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger)                                                                 
  File "/root/workspace/fork/examples/text-generation/utils.py", line 281, in setup_distributed_model_tp                                     
    lazy_sd = serialization.load_state_dict(                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/optimum/habana/distributed/serialization.py", line 191, in load_state_dict                   
    assert len(checkpoints) > 0, f"Can't find the requested checkpoint data at {model_path}"                                                 
AssertionError: Can't find the requested checkpoint data at meta-llama/Llama-2-7b-hf

Any idea about what's going on? It seems like a serialization issue. Or is it because it requires Synapse 1.17? I'm running 1.16.

Comment thread examples/text-generation/README.md Outdated

You will also need to add `--torch_compile` in your command.

### Running with tesor-parallel strategy
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
### Running with tesor-parallel strategy
### Running with tensor-parallel strategy

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment thread examples/text-generation/README.md Outdated
Comment on lines +269 to +271
```bash
NOTE: This strategy includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.
```
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
```bash
NOTE: This strategy includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.
```
> [!NOTE]
> This strategy includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated with the suggested format

Comment on lines +272 to +276

```bash
WARNING: torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models.
```
To enable torch.compile with tensor parallel strategy, please set the following environment variables before running the
command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This will enable tensor parallel strategy without deepspeed.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
```bash
WARNING: torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models.
```
To enable torch.compile with tensor parallel strategy, please set the following environment variables before running the
command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This will enable tensor parallel strategy without deepspeed.
> [!WARNING]
> torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models.
To enable torch.compile with tensor parallel strategy, please set the following environment variables before running the
command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This will enable tensor parallel strategy without deepspeed.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated with the suggested format

Comment thread examples/text-generation/README.md Outdated

Here is an example:
```bash
python ../gaudi_spawn.py --world_size 8 run_generation.py \
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
python ../gaudi_spawn.py --world_size 8 run_generation.py \
PT_ENABLE_INT64_SUPPORT=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py --world_size 8 run_generation.py \

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

@regisss
Copy link
Copy Markdown
Collaborator

regisss commented Jul 29, 2024

Please run make style too

@kalyanjk
Copy link
Copy Markdown
Contributor Author

I left a few more comments to address. Also, the test fails on my instance with this error:

Traceback (most recent call last):                                                                                                           
  File "/root/workspace/fork/examples/text-generation/run_generation.py", line 674, in <module>                                              
    main()                                                                                                                                   
  File "/root/workspace/fork/examples/text-generation/run_generation.py", line 317, in main                                                  
    model, assistant_model, tokenizer, generation_config = initialize_model(args, logger)                                                    
  File "/root/workspace/fork/examples/text-generation/utils.py", line 592, in initialize_model                                               
    else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger)                                                                 
  File "/root/workspace/fork/examples/text-generation/utils.py", line 281, in setup_distributed_model_tp                                     
    lazy_sd = serialization.load_state_dict(                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/optimum/habana/distributed/serialization.py", line 191, in load_state_dict                   
    assert len(checkpoints) > 0, f"Can't find the requested checkpoint data at {model_path}"                                                 
AssertionError: Can't find the requested checkpoint data at meta-llama/Llama-2-7b-hf

Any idea about what's going on? It seems like a serialization issue. Or is it because it requires Synapse 1.17? I'm running 1.16.

Can you provide with absolute path for meta-llama/Llama-2-7b-hf. All my testing is on 1.17, I will verify on 1.16 and update.

@kalyanjk
Copy link
Copy Markdown
Contributor Author

I left a few more comments to address. Also, the test fails on my instance with this error:

Traceback (most recent call last):                                                                                                           
  File "/root/workspace/fork/examples/text-generation/run_generation.py", line 674, in <module>                                              
    main()                                                                                                                                   
  File "/root/workspace/fork/examples/text-generation/run_generation.py", line 317, in main                                                  
    model, assistant_model, tokenizer, generation_config = initialize_model(args, logger)                                                    
  File "/root/workspace/fork/examples/text-generation/utils.py", line 592, in initialize_model                                               
    else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger)                                                                 
  File "/root/workspace/fork/examples/text-generation/utils.py", line 281, in setup_distributed_model_tp                                     
    lazy_sd = serialization.load_state_dict(                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/optimum/habana/distributed/serialization.py", line 191, in load_state_dict                   
    assert len(checkpoints) > 0, f"Can't find the requested checkpoint data at {model_path}"                                                 
AssertionError: Can't find the requested checkpoint data at meta-llama/Llama-2-7b-hf

Any idea about what's going on? It seems like a serialization issue. Or is it because it requires Synapse 1.17? I'm running 1.16.

Can you provide with absolute path for meta-llama/Llama-2-7b-hf. All my testing is on 1.17, I will verify on 1.16 and update.

@regisss successfully verified the sanity test for the 1.16 release using both the 7b and 70b models. Everything is working fine.

@regisss
Copy link
Copy Markdown
Collaborator

regisss commented Jul 29, 2024

There is no absolute path, this is the hub model id and I really think this use case should work as not everybody has the models stored locally. If the absolute path to the model is needed, there should be some code to find the model in the Transformers cache. You can get the default path to cache with:

from huggingface_hub.constants import HF_HUB_CACHE

More information about the structure of the cache here: https://huggingface.co/docs/huggingface_hub/v0.24.2/en/guides/manage-cache#understand-caching

Also, I see I forgot to mention it, can you replace the arg distributed_strategy by parallelism_strategy or something similar everywhere please? We already have distribution_strategy defined here and it would add a lot of confusion to have both distributed_strategy and distribution_strategy. Sorry for noticing this now, I thought I had commented on it before.

@kalyanjk
Copy link
Copy Markdown
Contributor Author

There is no absolute path, this is the hub model id and I really think this use case should work as not everybody has the models stored locally. If the absolute path to the model is needed, there should be some code to find the model in the Transformers cache. You can get the default path to cache with:

from huggingface_hub.constants import HF_HUB_CACHE

More information about the structure of the cache here: https://huggingface.co/docs/huggingface_hub/v0.24.2/en/guides/manage-cache#understand-caching

Also, I see I forgot to mention it, can you replace the arg distributed_strategy by parallelism_strategy or something similar everywhere please? We already have distribution_strategy defined here and it would add a lot of confusion to have both distributed_strategy and distribution_strategy. Sorry for noticing this now, I thought I had commented on it before.

Updated : renamed the distributed_strategy to parallel_strategy.
In process : Is there a way i can test the relative path for model_name - > meta-llama/Llama-2-7b-hf.

@kalyanjk
Copy link
Copy Markdown
Contributor Author

There is no absolute path, this is the hub model id and I really think this use case should work as not everybody has the models stored locally. If the absolute path to the model is needed, there should be some code to find the model in the Transformers cache. You can get the default path to cache with:

from huggingface_hub.constants import HF_HUB_CACHE

More information about the structure of the cache here: https://huggingface.co/docs/huggingface_hub/v0.24.2/en/guides/manage-cache#understand-caching
Also, I see I forgot to mention it, can you replace the arg distributed_strategy by parallelism_strategy or something similar everywhere please? We already have distribution_strategy defined here and it would add a lot of confusion to have both distributed_strategy and distribution_strategy. Sorry for noticing this now, I thought I had commented on it before.

Updated : renamed the distributed_strategy to parallel_strategy. In process : Is there a way i can test the relative path for model_name - > meta-llama/Llama-2-7b-hf.

Updated cache_dir setting for parallel_strategy = tp

@regisss can you please verify if you are able to load the data now

Copy link
Copy Markdown
Collaborator

@regisss regisss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes, it looks good to me!
One last thing, as written in the comment below, the test fails on my instance because the throughput I get is too low. Maybe due to a different version of Synapse?

Comment thread tests/test_text_generation_example.py Outdated
@regisss regisss merged commit 139ad89 into huggingface:main Jul 30, 2024
kalyanjk added a commit to kalyanjk/optimum-habana-fork that referenced this pull request Jul 31, 2024
kalyanjk added a commit to kalyanjk/optimum-habana-fork that referenced this pull request Jul 31, 2024
ghost pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Jul 31, 2024
* Revert "Tensor parallel  distributed strategy without using deepspeed (#280) (#299)"

This reverts commit 32c86d3.

* Tensor parallel distributed strategy without using deepspeed (huggingface#1121)

Co-authored-by: Kalyan <kkumar@habana.ai>

---------

Co-authored-by: Kalyan <kkumar@habana.ai>
ghost pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Jul 31, 2024
* Revert "Tensor parallel  distributed strategy without using deepspeed (#280)"

This reverts commit c6e5f9c.

* Tensor parallel distributed strategy without using deepspeed (huggingface#1121)

Co-authored-by: Kalyan <kkumar@habana.ai>

---------

Co-authored-by: Kalyan <kkumar@habana.ai>
xinyu-intel pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Mar 4, 2025
* Revert "Tensor parallel  distributed strategy without using deepspeed (#280)"

This reverts commit c6e5f9c.

* Tensor parallel distributed strategy without using deepspeed (huggingface#1121)

Co-authored-by: Kalyan <kkumar@habana.ai>

---------

Change-Id: Ic30c85e697dbd6a51767e21e1c06c9a20120d9f6
Co-authored-by: Kalyan <kkumar@habana.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

synapse1.17 PR that should be available along with Synapse 1.17 but have no dependency on Synapse 1.17 content.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants