diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py
index 884d62efc0..b17fbbf727 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py
@@ -25,10 +25,10 @@ MOE_MODEL = "Qwen/Qwen1.5-MoE-A2.7B"
 
 
-def get_test_actor_config() -> SkyRLTrainConfig:
+def get_test_actor_config(model: str = MODEL) -> SkyRLTrainConfig:
     """Get base config with test-specific overrides."""
     cfg = SkyRLTrainConfig()
-    cfg.trainer.policy.model.path = MODEL
+    cfg.trainer.policy.model.path = model
     cfg.generator.sampling_params.temperature = 0.0
     cfg.generator.sampling_params.top_p = 1
 
 
@@ -90,25 +90,25 @@ async def run_single_generation_with_tokens(client, prompt_token_ids, sampling_p
 
 
 @pytest.mark.skipif(_SKYRL_USE_NEW_INFERENCE, reason="New inference pathway doesn't support text based generation")
 @pytest.mark.parametrize(
-    "tp_size,pp_size,dp_size",
+    "tp_size,pp_size,dp_size,model",
     [
-        pytest.param(2, 1, 1),
-        pytest.param(2, 1, 2),
-        pytest.param(2, 2, 1),  # TP=2, PP=2
+        pytest.param(2, 1, 1, MODEL),
+        pytest.param(2, 1, 2, MOE_MODEL),
+        pytest.param(2, 2, 1, MODEL),  # TP=2, PP=2
     ],
-    ids=["tp2_pp1_dp1", "tp2_pp1_dp2", "tp2_pp2_dp1"],
+    ids=["tp2_pp1_dp1", "tp2_pp1_dp2_moe", "tp2_pp2_dp1"],
 )
-def test_inference_engines_generation(ray_init_fixture, tp_size: int, pp_size: int, dp_size: int):
+def test_inference_engines_generation(ray_init_fixture, tp_size: int, pp_size: int, dp_size: int, model: str):
     """
     Tests generation with both remote and ray-wrapped engines.
     """
-    cfg = get_test_actor_config()
+    cfg = get_test_actor_config(model)
 
-    prompts = get_test_prompts(MODEL)
-    tokenizer = AutoTokenizer.from_pretrained(MODEL)
+    prompts = get_test_prompts(model)
+    tokenizer = AutoTokenizer.from_pretrained(model)
     try:
-        llm_client, remote_server_process = init_remote_inference_servers(tp_size, "vllm", tokenizer, cfg, MODEL)
+        llm_client, remote_server_process = init_remote_inference_servers(tp_size, "vllm", tokenizer, cfg, model)
         sampling_params = get_sampling_params_for_backend(
             cfg.generator.inference_engine.backend, cfg.generator.sampling_params
         )
@@ -210,8 +210,7 @@ def test_token_based_generation(
 ):
     """Test generation using prompt_token_ids."""
 
-    cfg = get_test_actor_config()
-    cfg.trainer.policy.model.path = model
+    cfg = get_test_actor_config(model)
 
     prompts = get_test_prompts(model, 3)
     tokenizer = AutoTokenizer.from_pretrained(model)
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py
index a8b782f924..3c5589df7c 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py
@@ -488,7 +488,7 @@ async def test_generator_multi_turn_gsm8k_router_replay(ray_init_fixture):
         max_prompt_length=2048,
         max_input_length=max_input_length,
         max_generate_length=1000,
-        data_path=os.path.expanduser("/mnt/cluster_storage/data/gsm8k/validation.parquet"),
+        data_path=os.path.expanduser("~/data/gsm8k/validation.parquet"),
         env_class="gsm8k_multi_turn",
         num_prompts=num_prompts,
         max_turns=2,