tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py (13 additions, 14 deletions)

```diff
@@ -25,10 +25,10 @@
 MOE_MODEL = "Qwen/Qwen1.5-MoE-A2.7B"
 
 
-def get_test_actor_config() -> SkyRLTrainConfig:
+def get_test_actor_config(model: str = MODEL) -> SkyRLTrainConfig:
     """Get base config with test-specific overrides."""
     cfg = SkyRLTrainConfig()
-    cfg.trainer.policy.model.path = MODEL
+    cfg.trainer.policy.model.path = model
 
     cfg.generator.sampling_params.temperature = 0.0
     cfg.generator.sampling_params.top_p = 1
```
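Defaulting the new `model` parameter to `MODEL` keeps every existing call site working while letting individual tests swap in a different checkpoint. A self-contained sketch of the pattern, using placeholder values since the dense model constant's value is not visible in this diff:

```python
# Illustrative stand-ins: the real test uses SkyRLTrainConfig and the module's
# MODEL / MOE_MODEL constants; the dense model's value is not shown in this diff.
DENSE_MODEL = "dense-model-placeholder"
MOE_MODEL = "Qwen/Qwen1.5-MoE-A2.7B"


def get_test_actor_config(model: str = DENSE_MODEL) -> dict:
    """Mirrors the new signature: the default keeps old call sites unchanged."""
    return {"policy_model_path": model}


assert get_test_actor_config()["policy_model_path"] == DENSE_MODEL
assert get_test_actor_config(MOE_MODEL)["policy_model_path"] == MOE_MODEL
```

The parametrized generation test then threads the model through explicitly: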
```diff
@@ -90,25 +90,25 @@ async def run_single_generation_with_tokens(client, prompt_token_ids, sampling_p
 @pytest.mark.skipif(_SKYRL_USE_NEW_INFERENCE, reason="New inference pathway doesn't support text based generation")
 @pytest.mark.parametrize(
-    "tp_size,pp_size,dp_size",
+    "tp_size,pp_size,dp_size,model",
     [
-        pytest.param(2, 1, 1),
-        pytest.param(2, 1, 2),
-        pytest.param(2, 2, 1),  # TP=2, PP=2
+        pytest.param(2, 1, 1, MODEL),
+        pytest.param(2, 1, 2, MOE_MODEL),
+        pytest.param(2, 2, 1, MODEL),  # TP=2, PP=2
     ],
-    ids=["tp2_pp1_dp1", "tp2_pp1_dp2", "tp2_pp2_dp1"],
+    ids=["tp2_pp1_dp1", "tp2_pp1_dp2_moe", "tp2_pp2_dp1"],
 )
-def test_inference_engines_generation(ray_init_fixture, tp_size: int, pp_size: int, dp_size: int):
+def test_inference_engines_generation(ray_init_fixture, tp_size: int, pp_size: int, dp_size: int, model: str):
     """
     Tests generation with both remote and ray-wrapped engines.
     """
-    cfg = get_test_actor_config()
+    cfg = get_test_actor_config(model)
 
-    prompts = get_test_prompts(MODEL)
-    tokenizer = AutoTokenizer.from_pretrained(MODEL)
+    prompts = get_test_prompts(model)
+    tokenizer = AutoTokenizer.from_pretrained(model)
 
     try:
-        llm_client, remote_server_process = init_remote_inference_servers(tp_size, "vllm", tokenizer, cfg, MODEL)
+        llm_client, remote_server_process = init_remote_inference_servers(tp_size, "vllm", tokenizer, cfg, model)
         sampling_params = get_sampling_params_for_backend(
             cfg.generator.inference_engine.backend, cfg.generator.sampling_params
         )
```
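Making `model` a parametrize axis ties the MoE checkpoint to the `dp_size=2` case only, and the `_moe` suffix in that case's id makes MoE failures easy to attribute in CI logs. A self-contained sketch of the same pattern (placeholder constants; the hypothetical `test_parallelism_cases` stands in for the real test body):

```python
import pytest

# Placeholder constants: only MOE_MODEL's value is visible in this diff.
MODEL = "dense-model-placeholder"
MOE_MODEL = "Qwen/Qwen1.5-MoE-A2.7B"


@pytest.mark.parametrize(
    "tp_size,pp_size,dp_size,model",
    [
        pytest.param(2, 1, 1, MODEL),
        pytest.param(2, 1, 2, MOE_MODEL),
        pytest.param(2, 2, 1, MODEL),
    ],
    ids=["tp2_pp1_dp1", "tp2_pp1_dp2_moe", "tp2_pp2_dp1"],
)
def test_parallelism_cases(tp_size, pp_size, dp_size, model):
    # Each case needs tp_size * pp_size * dp_size GPUs; all of these need >= 2.
    assert tp_size * pp_size * dp_size >= 2
```

The token-based generation test gets the matching simplification, folding its manual path override into the new parameter: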
```diff
@@ -210,8 +210,7 @@ def test_token_based_generation(
 ):
     """Test generation using prompt_token_ids."""
 
-    cfg = get_test_actor_config()
-    cfg.trainer.policy.model.path = model
+    cfg = get_test_actor_config(model)
 
     prompts = get_test_prompts(model, 3)
     tokenizer = AutoTokenizer.from_pretrained(model)
```
In a second file (filename not captured in this extraction), the GSM8K validation data path moves from cluster storage to the home directory:

```diff
@@ -488,7 +488,7 @@ async def test_generator_multi_turn_gsm8k_router_replay(ray_init_fixture):
         max_prompt_length=2048,
         max_input_length=max_input_length,
         max_generate_length=1000,
-        data_path=os.path.expanduser("/mnt/cluster_storage/data/gsm8k/validation.parquet"),
+        data_path=os.path.expanduser("~/data/gsm8k/validation.parquet"),
```
Review comment (Contributor, severity: medium):

This line explicitly sets `data_path` to the same value as its default in the `run_generator_end_to_end` function signature. To reduce redundancy and improve maintainability, you can remove this line and rely on the default value.
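A self-contained illustration of the reviewer's point, assuming the signature default in `run_generator_end_to_end` really is the same expanded path (the `run` helper below is hypothetical):

```python
import os


# Hypothetical stand-in for run_generator_end_to_end: passing a keyword
# argument equal to the parameter's default is redundant.
def run(data_path: str = os.path.expanduser("~/data/gsm8k/validation.parquet")) -> str:
    return data_path


# Both calls resolve to the same path, so the explicit argument can be dropped.
assert run() == run(data_path=os.path.expanduser("~/data/gsm8k/validation.parquet"))
```

The rest of the hunk is unchanged context: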

```diff
         env_class="gsm8k_multi_turn",
         num_prompts=num_prompts,
         max_turns=2,
```