diff --git a/tests/experimental/transfer_queue/test_controller.py b/tests/experimental/transfer_queue/test_controller.py index 6577cd9e163..3b45da2a561 100644 --- a/tests/experimental/transfer_queue/test_controller.py +++ b/tests/experimental/transfer_queue/test_controller.py @@ -220,41 +220,42 @@ def test_get_prompt_metadata(self, setup_teardown_register_controller_info): mode="insert", ) ) + metadata.reorder([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]) assert metadata.global_indexes == [ - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - 27, - 28, - 29, - 30, 31, + 30, + 29, + 28, + 27, + 26, + 25, + 24, + 23, + 22, + 21, + 20, + 19, + 18, + 17, + 16, ] assert metadata.local_indexes == [ - 8, - 9, - 10, - 11, - 12, - 13, - 14, 15, - 8, - 9, - 10, - 11, - 12, - 13, 14, + 13, + 12, + 11, + 10, + 9, + 8, 15, + 14, + 13, + 12, + 11, + 10, + 9, + 8, ] storage_ids = metadata.storage_ids assert len(set(storage_ids[: len(storage_ids) // 2])) == 1 diff --git a/verl/experimental/transfer_queue/metadata.py b/verl/experimental/transfer_queue/metadata.py index 7346c292116..6d81e7f2ca3 100644 --- a/verl/experimental/transfer_queue/metadata.py +++ b/verl/experimental/transfer_queue/metadata.py @@ -480,6 +480,41 @@ def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMet return BatchMeta(samples=merged_samples, extra_info=merged_extra_info) + def reorder(self, indices: list[int]): + """ + Reorder the SampleMeta in the BatchMeta according to the given indices. + + The operation is performed in-place, modifying the current BatchMeta's SampleMeta order. + + Args: + indices : list[int] + A list of integers specifying the new order of SampleMeta. Each integer + represents the current index of the SampleMeta in the BatchMeta. + """ + # Reorder the samples + reordered_samples = [self.samples[i] for i in indices] + object.__setattr__(self, "samples", reordered_samples) + + # Update necessary attributes + self._update_after_reorder() + + def _update_after_reorder(self) -> None: + """Update related attributes specifically for the reorder operation""" + # Update batch_index for each sample + for idx, sample in enumerate(self.samples): + object.__setattr__(sample, "_batch_index", idx) + + # Update cached index lists + object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) + object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples]) + object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples]) + + # Rebuild storage groups + storage_meta_groups = self._build_storage_meta_groups() + object.__setattr__(self, "_storage_meta_groups", storage_meta_groups) + + # Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder + @classmethod def from_samples( cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None