Modify Qwen2 TRL command to avoid OOM. (#1630)
Add --use_flash_attention to avoid OOM for Qwen2
jiminha authored and regisss committed Dec 20, 2024
1 parent c6df122 commit 425bac7
Showing 3 changed files with 11 additions and 14 deletions.
3 changes: 2 additions & 1 deletion examples/text-generation/utils.py
@@ -442,7 +442,8 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):
     # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
     with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
         if (
-            hasattr(config, 'rope_scaling') and config.rope_scaling
+            hasattr(config, "rope_scaling")
+            and config.rope_scaling
             and config.rope_scaling["rope_type"] == "llama3"
             and config.max_position_embeddings > 8192
         ):
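Beyond the switch to double quotes, the rewrite puts one clause of the guard per line. A minimal standalone sketch of what the condition tests, with an illustrative config object (the rope_scaling values below are hypothetical, not from this commit):

```python
from types import SimpleNamespace

# Illustrative stand-in for a transformers model config; values are made up.
config = SimpleNamespace(
    rope_scaling={"rope_type": "llama3", "factor": 8.0},
    max_position_embeddings=131072,
)

# The exact condition from the diff: only Llama-3-style rope scaling with a
# context window beyond 8192 tokens takes the special construction path.
if (
    hasattr(config, "rope_scaling")
    and config.rope_scaling
    and config.rope_scaling["rope_type"] == "llama3"
    and config.max_position_embeddings > 8192
):
    print("llama3 rope scaling path taken")
```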
3 changes: 2 additions & 1 deletion examples/trl/README.md
@@ -39,7 +39,8 @@ $ pip install -U -r requirements.txt
     --lora_dropout=0.05 \
     --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
     --max_seq_length 512 \
-    --adam_epsilon 1e-08
+    --adam_epsilon 1e-08 \
+    --use_flash_attention
 ```
 2. Supervised fine-tuning of the mistralai/Mixtral-8x7B-v0.1 on 4 cards:
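Flash attention computes attention in blocks instead of materializing the full sequence-by-sequence score matrix, which lowers peak activation memory during fine-tuning; making the flag opt-in leaves the default path unchanged for models that never hit the memory ceiling. As a hedged sketch of how such a boolean flag is typically wired into an example script (the parser code below is illustrative, not the repo's exact implementation):

```python
import argparse

parser = argparse.ArgumentParser()
# Boolean switch: passing --use_flash_attention sets it to True, else False.
parser.add_argument(
    "--use_flash_attention",
    action="store_true",
    help="Use a fused/flash attention kernel to reduce attention memory.",
)

args = parser.parse_args(["--use_flash_attention"])
assert args.use_flash_attention  # the amended README command now sets this
```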
19 changes: 7 additions & 12 deletions tests/test_text_generation_example.py
@@ -383,8 +383,8 @@ def test_text_generation_bf16_1x(
         check_output=check_output,
     )
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize(
     "model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline", MODELS_TO_TEST["fp8"]
 )
@@ -413,8 +413,7 @@ def test_text_generation_fp8(
     )
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize(
     "model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline",
     MODELS_TO_TEST["load_quantized_model_with_autogptq"],
@@ -450,23 +449,20 @@ def test_text_generation_deepspeed(model_name: str, baseline: float, world_size:
     _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, batch_size=batch_size)
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["torch_compile"])
 def test_text_generation_torch_compile(model_name: str, baseline: float, token: str):
     _test_text_generation(model_name, baseline, token, torch_compile=True)
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["torch_compile_distributed"])
 def test_text_generation_torch_compile_distributed(model_name: str, baseline: float, token: str):
     world_size = 8
     _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True)
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["distributed_tp"])
 def test_text_generation_distributed_tp(model_name: str, baseline: float, token: str):
     world_size = 8
@@ -489,8 +485,7 @@ def test_text_generation_contrastive_search(
     _test_text_generation(model_name, baseline, token, batch_size, reuse_cache, contrastive_search=True)
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize("model_name, batch_size, reuse_cache, baseline", MODELS_TO_TEST["beam_search"])
 def test_text_generation_beam_search(model_name: str, baseline: float, batch_size: int, reuse_cache: bool, token: str):
     _test_text_generation(model_name, baseline, token, batch_size, reuse_cache, num_beams=3)
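The same one-line gate now precedes every Gaudi2-only test in this file. A natural follow-up, sketched here rather than taken from the commit, is to hoist the repeated decorator into a module-level marker:

```python
import os

import pytest

# Same gate as the repeated decorator in the diff: run only when the GAUDI2_CI
# environment variable parses to a truthy integer (e.g. "1"); otherwise skip.
gaudi2_only = pytest.mark.skipif(
    condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
    reason="Skipping test for G1",
)


@gaudi2_only
def test_placeholder():
    # Stands in for any of the Gaudi2-only tests above.
    assert True
```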
