[Feature] Multicollector interruptor #963
Conversation
LGTM, I left some comments.
torchrl/collectors/collectors.py (Outdated)

@@ -67,6 +68,44 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase:
        return td.set("action", self.action_spec.rand())


class Interruptor:
If it's public it should be in the docs.
In that case I think it makes more sense for Interruptor and InterruptorManager to be private. I don't see the Interruptor class being useful to users beyond the scope of the collector.
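For context, the Interruptor discussed here can be thought of as a small shared stop flag for collector workers. A minimal sketch follows; the method names and internal layout are assumptions based on the diff fragments in this thread, not TorchRL's exact implementation:

```python
class Interruptor:
    """Tells multi-process data collectors whether to keep gathering frames.

    Sketch only: the real class is shared across processes (see the
    InterruptorManager discussion below in this thread).
    """

    def __init__(self):
        # collection is enabled by default
        self._collect = True

    def start_collection(self):
        # (re-)enable collection, e.g. at the start of an iteration
        self._collect = True

    def stop_collection(self):
        # signal stragglers to stop early
        self._collect = False

    def collection_stopped(self):
        # mirrors the `return self._collect is False` line from the diff
        return self._collect is False
```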
torchrl/collectors/collectors.py (Outdated)

        return self._collect is False


class InterruptorManager(SyncManager):
ditto
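The InterruptorManager in the diff above subclasses multiprocessing.managers.SyncManager so the flag can live in a manager process and be shared with workers via proxies. A minimal sketch of that stdlib registration pattern, with a simplified stand-in Interruptor (not TorchRL's actual code):

```python
from multiprocessing.managers import SyncManager


class Interruptor:
    # simplified stand-in flag; the real class carries collector state
    def __init__(self):
        self._collect = True

    def stop_collection(self):
        self._collect = False

    def collection_stopped(self):
        return self._collect is False


class InterruptorManager(SyncManager):
    """Manager whose .Interruptor() returns a proxy usable by workers."""


# expose Interruptor through the manager; proxies forward method calls
InterruptorManager.register("Interruptor", Interruptor)

if __name__ == "__main__":
    with InterruptorManager() as manager:
        flag = manager.Interruptor()  # proxy, shareable across processes
        flag.stop_collection()
        assert flag.collection_stopped()
```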
test/test_collector.py (Outdated)

for batch in collector:
    assert (
        batch["collector"]["traj_ids"][
Let's rewrite this in a more straightforward way; it's a bit hard to read.
What if none is -1? Shouldn't we test that we have at least one traj_id set to -1 in the batch?
I simplified the code for clarity.
Regarding the possibility of having no -1 values, I have addressed that as well by setting the preemptive threshold to 0.0 in the test instead of 0.25. This way, all collectors stop after the first iteration and only the very first frame of each sync collector is valid, so we know for sure there are -1's in the batch.
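The reasoning above can be checked with a toy example. The batch layout here is assumed for illustration and is not the actual test code:

```python
# With preemptive_threshold=0.0, every worker is interrupted right after
# its first frame, so each worker's slice has exactly one valid traj id
# followed by -1 markers (assumed layout, one row per sync collector).
batch_traj_ids = [
    [0, -1, -1, -1],  # worker 0: one valid frame, then interrupted
    [1, -1, -1, -1],  # worker 1: same pattern
]

for worker in batch_traj_ids:
    assert worker[0] != -1   # the very first frame is always valid
    assert -1 in worker[1:]  # interruption marked the remaining frames
```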
Can you merge main into this branch? The tests are failing because they're looking for a deprecated function in tensordict.
Can you have a look at how it would play out with the distributed collectors I integrated yesterday?
We don't necessarily need everything to be fully compatible but I want to make sure that we're not missing some obvious point, e.g. how we handle the traj-ids in the distributed collector may need a bit of refactoring.
(I will merge this sooner than that as the solution is neat and usable as of now)
torchrl/collectors/collectors.py (Outdated)

@@ -1399,6 +1462,18 @@ def iterator(self) -> Iterator[TensorDictBase]:

            i += 1
            max_traj_idx = None

            if self.interruptor:
For clarity, can we have if self.interruptor is not None?
Description
This PR implements a preemptive mechanism to early-stop stragglers in multi-sync data collectors. The invalid data can be identified because its trajectory ids are -1.
For now, the preemptive mechanism is not compatible with split_trajs=True, but this could easily be adapted by ignoring trajectory ids equal to -1 in the split_trajectories method, i.e.:

splits = [(splits == i).sum().item() for i in splits.unique_consecutive() if i != -1]  (changing line 39 in collectors/utils.py)
out_splits = rollout_tensordict.view(-1)[traj_ids != -1].split(splits, 0)  (changing line 54 in collectors/utils.py)
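The two-line adaptation sketched above can be mirrored in plain Python to see its effect, using itertools.groupby in place of torch's unique_consecutive and a manual cut in place of Tensor.split (illustrative stand-in, not the collectors/utils.py code):

```python
from itertools import groupby

# a flat rollout where frames 4-5 were preempted (traj id -1)
traj_ids = [0, 0, 1, 1, -1, -1, 2]
frames = list(range(7))  # stand-in for the rollout data

# lengths of each valid trajectory, skipping the preempted runs,
# mirroring: [(splits == i).sum().item() for i in ... if i != -1]
splits = [len(list(group)) for tid, group in groupby(traj_ids) if tid != -1]

# drop invalid frames, then cut the flat rollout into trajectories,
# mirroring: rollout.view(-1)[traj_ids != -1].split(splits, 0)
valid = [f for f, tid in zip(frames, traj_ids) if tid != -1]
out_splits, start = [], 0
for length in splits:
    out_splits.append(valid[start:start + length])
    start += length

# splits → [2, 2, 1]; out_splits → [[0, 1], [2, 3], [6]]
```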
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213.

Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!