Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def _make_connector_with_fake_worker(
)
worker = connector.connector_worker
assert isinstance(worker.nixl_wrapper, FakeNixlWrapper)
worker.nixl_wrapper.set_cycles_before_xfer_done(cycles_before_done)
worker.kv_cache_layout = "HND"
if do_handshake:
remote_agents = worker._nixl_handshake(
Expand Down
16 changes: 1 addition & 15 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,6 @@ def transfer(self, handle: int) -> str:
def get_xfer_telemetry(self, handle: int) -> dict:
return get_default_xfer_telemetry()

############################################################
# Follow are for changing the behavior during testing.
############################################################

def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done."""
Comment thread
yewentao256 marked this conversation as resolved.


@contextlib.contextmanager
def _make_fake_nixl_pkg():
Expand Down Expand Up @@ -578,10 +571,7 @@ def test_multi_xfer_one_engine(
"""Test case where multiple xfers are initiated to the same engine.

This test triggers the connector to load remote KV for the same
`request_id`. The transfer is not done immediately due to
`set_cycles_before_xfer_done`, so there is a state where there are
multiple transfer states for the same `request_id`, and `get_finished`
should handle it correctly (wait for all transfers to be done).
`request_id`.
"""
vllm_config = create_vllm_config()

Expand All @@ -598,7 +588,6 @@ def test_multi_xfer_one_engine(
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
Expand Down Expand Up @@ -1304,7 +1293,6 @@ def test_scheduler_kv_connector_stats_aggregation():
# Worker stats with transfer metrics
worker_stats = NixlKVConnectorStats()
worker_stats.record_transfer(get_default_xfer_telemetry())
worker_stats.data["remote_tokens"] = []

# Scheduler stats with custom metric (needs dummy transfer to avoid being skipped)
scheduler_stats = NixlKVConnectorStats()
Expand All @@ -1314,7 +1302,6 @@ def test_scheduler_kv_connector_stats_aggregation():
"post_duration": [0],
"bytes_transferred": [0],
"num_descriptors": [0],
"remote_tokens": [128],
}
)

Expand Down Expand Up @@ -1355,7 +1342,6 @@ def test_scheduler_kv_connector_stats_aggregation():
).scheduler_stats.kv_connector_stats
nixl_stats = final_stats["NixlConnector"]
assert nixl_stats.num_successful_transfers == 2
assert nixl_stats.data["remote_tokens"] == [128]


@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
Expand Down
15 changes: 6 additions & 9 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,12 @@ def update_finished_set(
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats = kv_output.kv_connector_stats
elif kv_connector_stats := kv_output.kv_connector_stats:
if aggregated_kv_connector_stats is None:
aggregated_kv_connector_stats = kv_connector_stats
else:
assert isinstance(
aggregated_kv_connector_stats, type(kv_connector_stats)
)
aggregated_kv_connector_stats = (
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)
assert isinstance(
aggregated_kv_connector_stats, type(kv_connector_stats)
)
aggregated_kv_connector_stats = aggregated_kv_connector_stats.aggregate(
kv_connector_stats
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is ok


# Aggregate kv_connector_worker_meta from all workers.
if aggregated_kv_connector_worker_meta is None:
Expand Down
38 changes: 0 additions & 38 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,9 +1251,6 @@ def get_dcp_group() -> GroupCoordinator:
return _DCP


# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group

_PP: GroupCoordinator | None = None


Expand Down Expand Up @@ -1821,31 +1818,6 @@ def model_parallel_is_initialized():
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
"""Patch the tp group temporarily until this function ends.

This method is for draft workers of speculative decoding to run draft model
with different tp degree from that of target model workers.

Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _TP_STATE_PATCHED
assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

_TP_STATE_PATCHED = True
old_tp_group = get_tp_group()
global _TP
_TP = tp_group
try:
yield
finally:
# restore the original state
_TP_STATE_PATCHED = False
_TP = old_tp_group


def get_tensor_model_parallel_world_size() -> int:
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
Expand All @@ -1856,16 +1828,6 @@ def get_tensor_model_parallel_rank() -> int:
return get_tp_group().rank_in_group


def get_decode_context_model_parallel_world_size() -> int:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@youkaichao can we remove these?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DarkLight1337 seems that @youkaichao is quite busy recently, check again this is not used at all, could we land this PR in this case?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

"""Return world size for the decode context model parallel group."""
return get_dcp_group().world_size


def get_decode_context_model_parallel_rank() -> int:
"""Return my rank for the decode context model parallel group."""
return get_dcp_group().rank_in_group


def get_node_count() -> int:
"""Return the total number of nodes in the distributed environment."""
assert _NODE_COUNT is not None, "distributed environment is not initialized"
Expand Down
Loading