[Feature] Multicollector interruptor #963

Merged
merged 19 commits into from
Mar 17, 2023

@albertbou92 albertbou92 commented Mar 13, 2023

Description

This PR implements a preemptive mechanism that stops stragglers early in multi-sync data collectors. The resulting invalid data can be identified by its trajectory IDs, which are set to -1.

For now, the preemptive mechanism is not compatible with split_trajs=True, but it could easily be adapted by ignoring trajectory IDs equal to -1 in the split_trajectories method.

i.e.
splits = [(splits == i).sum().item() for i in splits.unique_consecutive() if i != -1] (changing line 39 in collectors/utils.py)
out_splits = rollout_tensordict.view(-1)[traj_ids != -1].split(splits, 0) (changing line 54 in collectors/utils.py)
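
The filtering idea in the two lines above can be sketched without tensors. This is only an illustration under stated assumptions: plain Python lists stand in for the rollout tensordict and its trajectory IDs, and `split_valid_trajectories` is a hypothetical helper, not a function from collectors/utils.py.

```python
from itertools import groupby


def split_valid_trajectories(frames, traj_ids, invalid_id=-1):
    """Group frames by consecutive trajectory id, dropping invalid frames.

    Hypothetical helper: `frames` and `traj_ids` are plain lists standing in
    for the flattened rollout and its per-frame trajectory ids.
    """
    if len(frames) != len(traj_ids):
        raise ValueError("frames and traj_ids must have the same length")
    # Drop every frame whose trajectory id marks it as invalid (-1).
    valid = [(f, t) for f, t in zip(frames, traj_ids) if t != invalid_id]
    # Group the remaining frames by runs of equal trajectory id.
    return [[f for f, _ in run] for _, run in groupby(valid, key=lambda p: p[1])]


frames = ["a0", "a1", "pad", "b0", "pad", "pad"]
traj_ids = [0, 0, -1, 1, -1, -1]
trajectories = split_valid_trajectories(frames, traj_ids)
```

Here the three `"pad"` frames carry an ID of -1 and are discarded before the split, so only the two real trajectories remain.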

Motivation and Context

Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 13, 2023

@vmoens vmoens left a comment

LGTM, I left some comments.

@@ -67,6 +68,44 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase:
return td.set("action", self.action_spec.rand())


class Interruptor:

if it's public it should be in the doc

@albertbou92 albertbou92 Mar 16, 2023

In that case I think it makes more sense for Interruptor and InterruptorManager to be private. I don't see the Interruptor class being useful for users beyond the scope of the collector.

return self._collect is False


class InterruptorManager(SyncManager):

ditto
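
For readers following the thread, here is a minimal sketch of what a stop flag like the one discussed above might look like. Only the class name and the `self._collect is False` check come from the diff; the method names, the lock, and the `collect_rollout` helper are assumptions made for illustration. It uses a thread lock and a single process for simplicity, whereas the PR shares the object across processes through a `SyncManager` subclass.

```python
import threading


class Interruptor:
    """Hypothetical thread-safe flag telling workers to stop collecting."""

    def __init__(self) -> None:
        self._collect = True
        self._lock = threading.Lock()

    def start_collection(self) -> None:
        with self._lock:
            self._collect = True

    def stop_collection(self) -> None:
        with self._lock:
            self._collect = False

    def collection_stopped(self) -> bool:
        with self._lock:
            return self._collect is False


def collect_rollout(num_frames, interruptor, traj_id, interrupt_at=None):
    """Return one trajectory id per frame; frames never collected get -1."""
    traj_ids = []
    for step in range(num_frames):
        if interrupt_at is not None and step == interrupt_at:
            # Stands in for the main process signalling the stragglers.
            interruptor.stop_collection()
        if interruptor.collection_stopped():
            # Mark every remaining frame of this rollout as invalid.
            traj_ids.extend([-1] * (num_frames - step))
            break
        traj_ids.append(traj_id)
    return traj_ids


interruptor = Interruptor()
full = collect_rollout(4, interruptor, traj_id=0)
cut = collect_rollout(4, interruptor, traj_id=1, interrupt_at=2)
```

The first rollout runs to completion, while the second is interrupted at step 2 and pads its last two frames with -1.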


for batch in collector:
    assert (
        batch["collector"]["traj_ids"][

let's rewrite this in a more straightforward way, it's a bit hard to read
What if none is -1? Shouldn't we test that we have at least one traj_id set to -1 in the batch?

@albertbou92 albertbou92 Mar 16, 2023

I simplified the code for more clarity.

Regarding the possibility of having no -1 values, I have addressed that as well by setting the preemptive threshold to 0.0 in the test instead of 0.25. This way, all collectors stop after the first iteration and only the very first frame of each sync collector is valid, so we know for sure there are -1s in the batch.
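
The property described above (first frame valid, everything after it invalid, and at least one -1 present) could be checked roughly as follows. This is a sketch, not the test from the PR: nested lists stand in for the batch's trajectory IDs, one inner list per sync collector, and `is_fully_preempted` is a hypothetical name.

```python
def is_fully_preempted(traj_ids_per_worker, invalid_id=-1):
    """Check the batch shape expected with a preemptive threshold of 0.0.

    Hypothetical helper: each inner list holds one worker's per-frame
    trajectory ids.
    """
    saw_invalid = False
    for worker_ids in traj_ids_per_worker:
        if not worker_ids or worker_ids[0] == invalid_id:
            return False  # the very first frame must be valid
        if any(i != invalid_id for i in worker_ids[1:]):
            return False  # every later frame must be marked invalid
        saw_invalid = saw_invalid or len(worker_ids) > 1
    # Guard against the case raised in review: a batch with no -1 at all.
    return saw_invalid


preempted = is_fully_preempted([[0, -1, -1], [1, -1, -1]])
not_preempted = is_fully_preempted([[0, 0, 0], [1, 1, -1]])
no_invalid = is_fully_preempted([[0], [1]])
```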

vmoens commented Mar 16, 2023

Can you merge main into this branch? The tests are failing because they're looking for a deprecated function in tensordict.

@vmoens vmoens added the enhancement New feature or request label Mar 17, 2023

@vmoens vmoens left a comment

Can you have a look at how it would play out with the distributed collectors I integrated yesterday?
We don't necessarily need everything to be fully compatible but I want to make sure that we're not missing some obvious point, e.g. how we handle the traj-ids in the distributed collector may need a bit of refactoring.

(I will merge this sooner than that as the solution is neat and usable as of now)

@@ -1399,6 +1462,18 @@ def iterator(self) -> Iterator[TensorDictBase]:

i += 1
max_traj_idx = None

if self.interruptor:

for clarity can we have if self.interruptor is not None?
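
A small illustration of why the explicit check can matter in general: an object that defines `__len__` or `__bool__` may be falsy even though it is not None. The `EmptyContainer` class and both helper functions below are hypothetical and unrelated to the collector code; nothing here implies the actual interruptor object is falsy.

```python
class EmptyContainer:
    """Hypothetical object that exists but evaluates as falsy."""

    def __len__(self) -> int:
        return 0


def status_by_truthiness(obj) -> str:
    # Relies on bool(obj), which falls back to len(obj) != 0.
    return "set" if obj else "unset"


def status_by_identity(obj) -> str:
    # Only checks whether an object was provided at all.
    return "set" if obj is not None else "unset"


container = EmptyContainer()
truthy_result = status_by_truthiness(container)
identity_result = status_by_identity(container)
```

With the truthiness test the empty container is treated as absent, while the identity test reports it as present; both agree when the argument is None.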

@vmoens vmoens merged commit e79f15e into pytorch:main Mar 17, 2023
@albertbou92 albertbou92 deleted the multicollector_interruptor branch January 18, 2024 10:08
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
3 participants