diff --git a/tests/conftest.py b/tests/conftest.py index 9e101909b..696ea9a87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -125,7 +125,7 @@ def remote_openai_server(request): if 'tp_size' in params: tp_size = params['tp_size'] - skip_unsupported_tp_size(int(tp_size)) + skip_unsupported_tp_size(int(tp_size), backend) server_args.extend(["--tensor-parallel-size", str(tp_size)]) try: diff --git a/tests/e2e/test_spyre_basic.py b/tests/e2e/test_spyre_basic.py index 3609b8931..d386eaefe 100644 --- a/tests/e2e/test_spyre_basic.py +++ b/tests/e2e/test_spyre_basic.py @@ -23,7 +23,8 @@ pytest.param(2, marks=pytest.mark.multi), pytest.param(4, marks=pytest.mark.multi), pytest.param(8, marks=pytest.mark.multi), -]) +], + ids=lambda val: f"TP({val})") @pytest.mark.parametrize("backend", get_spyre_backend_list()) def test_output( model: str, @@ -45,7 +46,7 @@ def test_output( After debugging, DISABLE_ASSERTS should be reset to 'False'. ''' - skip_unsupported_tp_size(tp_size) + skip_unsupported_tp_size(tp_size, backend) prompts = get_chicken_soup_prompts(4) diff --git a/tests/e2e/test_spyre_online.py b/tests/e2e/test_spyre_online.py index 4c3c876b1..689dc4b72 100644 --- a/tests/e2e/test_spyre_online.py +++ b/tests/e2e/test_spyre_online.py @@ -9,7 +9,8 @@ pytest.param(2, marks=pytest.mark.multi), pytest.param(4, marks=pytest.mark.multi), pytest.param(8, marks=pytest.mark.multi), -]) +], + ids=lambda val: f"TP({val})") @pytest.mark.parametrize("backend", get_spyre_backend_list()) @pytest.mark.parametrize("warmup_shape", [[ (64, 20, 1), diff --git a/tests/e2e/test_spyre_prompt_logprobs.py b/tests/e2e/test_spyre_prompt_logprobs.py index b9ea6bd48..cebfb772c 100644 --- a/tests/e2e/test_spyre_prompt_logprobs.py +++ b/tests/e2e/test_spyre_prompt_logprobs.py @@ -19,10 +19,11 @@ @pytest.mark.parametrize("backend", get_spyre_backend_list()) @pytest.mark.parametrize("model", get_spyre_model_list()) @pytest.mark.parametrize("tp_size", [ - pytest.param(1, id="tp_size"), - pytest.param(2, marks=pytest.mark.multi, id="tp_size"), - pytest.param(4, marks=pytest.mark.multi, id="tp_size") -]) + pytest.param(1), + pytest.param(2, marks=pytest.mark.multi), + pytest.param(4, marks=pytest.mark.multi) +], + ids=lambda val: f"TP({val})") def test_prompt_logprobs( backend: str, model: str, @@ -33,7 +34,7 @@ def test_prompt_logprobs( This test checks the prompt_logprobs output from vllm against a reference implementation using huggingface. ''' - skip_unsupported_tp_size(tp_size) + skip_unsupported_tp_size(tp_size, backend) num_prompt_logprobs = 5 prompts = get_chicken_soup_prompts(4) diff --git a/tests/spyre_util.py b/tests/spyre_util.py index 07b14fcb8..31a463cfa 100644 --- a/tests/spyre_util.py +++ b/tests/spyre_util.py @@ -548,7 +548,13 @@ def create_random_request( **extra_kwargs) -def skip_unsupported_tp_size(size: int): +def skip_unsupported_tp_size(size: int, backend: str): + if backend in ["eager", "inductor"]: + # Spyre cards aren't required for running TP on CPU backends + # But it's really slow to run tp > 2 + if size > 2: + pytest.skip("Skipping TP test on CPU with TP size > 2") + return cards = int(os.getenv("AIU_WORLD_SIZE", "0")) if cards < size: pytest.skip(f"Cannot run TP size {size}: " diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index f63c1824d..628014363 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -397,10 +397,6 @@ def execute_model( masks=model_input.input_masks, is_prompt=model_input.is_prompt) - # Only perform sampling in the driver worker. - if not self.is_driver_worker: - return EMPTY_MODEL_RUNNER_OUTPUT - # Compute the logits. logits = self.model.compute_logits(hidden_states, None) @@ -434,6 +430,10 @@ def execute_model( prompt_logprobs_dicts = self._get_prompt_logprobs_dict( logits=logits, model_inputs=model_input) + # Only return outputs from the driver worker + if not self.is_driver_worker: + return EMPTY_MODEL_RUNNER_OUTPUT + model_output = ModelRunnerOutput( req_ids=list(req_id_to_index.keys()), req_id_to_index=req_id_to_index,