-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Custom action distributions #5164
Conversation
Test FAILed. |
Test PASSed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great! I think two more additions would make this feature more discoverable by users:
- Add an end-to-end runnable example in
rllib/examples/custom_action_dist.py
- Update
rllib-models.rst
to include a section on custom action distributions, and also updaterllib-examples.rst
to link to the example script.
@@ -69,6 +72,25 @@ def sampled_action_prob(self): | |||
"""Returns the log probability of the sampled action.""" | |||
return tf.exp(self.logp(self.sample_op)) | |||
|
|||
@DeveloperAPI | |||
@staticmethod | |||
def parameter_shape_for_action_space(action_space, model_config=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about required_model_output_size
?
@@ -69,6 +72,25 @@ def sampled_action_prob(self): | |||
"""Returns the log probability of the sampled action.""" | |||
return tf.exp(self.logp(self.sample_op)) | |||
|
|||
@DeveloperAPI | |||
@staticmethod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@staticmethod | |
@classmethod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really understand this suggestion. Why do you think this should be a class method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm I guess staticmethod is fine, since you don't really need the class.
model_config (dict): Model's config dict (as defined in catalog.py) | ||
|
||
Returns: | ||
dist_dim (int or np.ndarray of ints): size of the required |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dist_dim (int or np.ndarray of ints): size of the required | |
model_output_size (int or np.ndarray of ints): size of the required |
""" | ||
|
||
@DeveloperAPI | ||
def __init__(self, inputs): | ||
def __init__(self, inputs, model_config=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid accidents where we forget to pass this, consider making it a required argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I lean towards making it required but this may break other users' code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this is DeveloperAPI I would say it's ok to err on the side of avoiding bugs vs backwards compatibility.
@staticmethod | ||
@override(ActionDistribution) | ||
def parameter_shape_for_action_space(action_space, model_config=None): | ||
return action_space.shape[0] * 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np.product(action_space.shape)? here and elsewhere?
python/ray/rllib/models/catalog.py
Outdated
return partial(MultiCategorical, input_lens=action_space.nvec), \ | ||
int(sum(action_space.nvec)) | ||
|
||
return dist, dist.parameter_shape_for_action_space( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like now you could simplify this to return just dist
(or this cleanup can be done later).
Will probably be another day or two before I get a chance to work on the stuff you mentioned. |
@mawright any updates? Let us know if you need help. |
Sorry, have been busy finishing my dissertation the last few weeks and have a busy week ahead of me with neurips reviewer responses. This is still on my to do list. |
Test FAILed. |
Test FAILed. |
Test FAILed. |
Test FAILed. |
Test FAILed. |
Test FAILed. |
Test FAILed. |
Test FAILed. |
The Travis check said this commit errored on a Python 2.7 build but I can't figure out why. From what I can tell the error is occurring in a C++ file: https://travis-ci.com/ray-project/ray/jobs/222169300#L1427 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
The travis tests look ok to me, but not sure if jenkins is fully passing still. Going to wait for the latest build before merging. |
jenkins retest this please |
Test PASSed. |
Test PASSed. |
* custom action dist wip * Test case for custom action dist * ActionDistribution.get_parameter_shape_for_action_space pattern * Edit exception message to also suggest using a custom action distribution * Clean up ModelCatalog.get_action_dist * Pass model config to ActionDistribution constructors * Update custom action distribution test case * Name fix * Autoformatter * parameter shape static methods for torch distributions * Fix docstring * Generalize fake array for graph initialization * Fix action dist constructors * Correct parameter shape static methods for multicategorical and gaussian * Make suggested changes to custom action dist's * Correct instances of not passing model config to action dist * Autoformatter * fix tuple distribution constructor * bugfix
What do these changes do?
Adds support for custom action distributions. They are registered to and looked up from the same global "ModelCatalog" class that are currently used for custom models and preprocessors.
There remain some issues with the handling of the action tensors in preexisting code that I will mention in a new Github issue.
Related issue number
#4895
Linter
scripts/format.sh
to lint the changes in this PR.