Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 29 additions & 28 deletions tests/experimental/transfer_queue/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,41 +220,42 @@ def test_get_prompt_metadata(self, setup_teardown_register_controller_info):
mode="insert",
)
)
metadata.reorder([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
assert metadata.global_indexes == [
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
30,
29,
28,
27,
26,
25,
24,
23,
22,
21,
20,
19,
18,
17,
16,
]
assert metadata.local_indexes == [
8,
9,
10,
11,
12,
13,
14,
15,
8,
9,
10,
11,
12,
13,
14,
13,
12,
11,
10,
9,
8,
15,
14,
13,
12,
11,
10,
9,
8,
]
storage_ids = metadata.storage_ids
assert len(set(storage_ids[: len(storage_ids) // 2])) == 1
Expand Down
35 changes: 35 additions & 0 deletions verl/experimental/transfer_queue/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,41 @@ def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMet

return BatchMeta(samples=merged_samples, extra_info=merged_extra_info)

def reorder(self, indices: list[int]):
"""
Reorder the SampleMeta in the BatchMeta according to the given indices.

The operation is performed in-place, modifying the current BatchMeta's SampleMeta order.

Args:
indices : list[int]
A list of integers specifying the new order of SampleMeta. Each integer
represents the current index of the SampleMeta in the BatchMeta.
"""
# Reorder the samples
reordered_samples = [self.samples[i] for i in indices]
object.__setattr__(self, "samples", reordered_samples)

# Update necessary attributes
self._update_after_reorder()

def _update_after_reorder(self) -> None:
"""Update related attributes specifically for the reorder operation"""
# Update batch_index for each sample
for idx, sample in enumerate(self.samples):
object.__setattr__(sample, "_batch_index", idx)

# Update cached index lists
object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])
object.__setattr__(self, "_local_indexes", [sample.local_index for sample in self.samples])
object.__setattr__(self, "_storage_ids", [sample.storage_id for sample in self.samples])

# Rebuild storage groups
storage_meta_groups = self._build_storage_meta_groups()
object.__setattr__(self, "_storage_meta_groups", storage_meta_groups)

# Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder

@classmethod
def from_samples(
cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None
Expand Down