[Draft] Add local search sampler #208
base: master
Conversation
Force-pushed from c2d59b3 to c6b9f64.
Force-pushed from c6b9f64 to 1ccf16c.
@hyeok9855 can you fix the merge conflicts?
I noticed you are using
I fixed an issue in the backward mask! Is there anything necessary to do next?
# TODO: Implement the Metropolis-Hastings acceptance criterion.
# The back-and-forth proposal probabilities are
#     p(x->s'->x') = p_B(x->s') p_F(s'->x')
#     p(x'->s'->x) = p_B(x'->s') p_F(s'->x)
# and the acceptance probability is
#     min(1, R(x') p(x'->s'->x) / (R(x) p(x->s'->x'))).
# Note that
#     p(x'->s'->x) / p(x->s'->x')
#     = p_B(x'->s') p_F(s'->x) / (p_B(x->s') p_F(s'->x'))
#     = p_B(x'->s'->s0) p_F(s0->s'->x) / (p_B(x->s'->s0) p_F(s0->s'->x'))
#     = p_B(tau'|x') p_F(tau) / (p_B(tau|x) p_F(tau'))
# where tau (resp. tau') is the complete trajectory s0 -> ... -> x (resp. s0 -> ... -> x').
This is how Metropolis-Hastings works.
For the last equality, we need both the log_pfs and the log_pbs of each trajectory.
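To make this concrete, here is a minimal sketch of the acceptance step once the per-trajectory quantities are available. The function and argument names are illustrative, not part of torchgfn, and it assumes the log-probabilities have already been summed over the steps of each trajectory.

import torch

def metropolis_hastings_accept(
    log_reward,        # log R(x), shape (n,)
    log_reward_prime,  # log R(x'), shape (n,)
    log_pf_tau,        # log p_F(tau), summed over steps of tau (s0 -> x)
    log_pb_tau,        # log p_B(tau|x), summed over steps of tau
    log_pf_tau_prime,  # log p_F(tau'), summed over steps of tau' (s0 -> x')
    log_pb_tau_prime,  # log p_B(tau'|x'), summed over steps of tau'
):
    """Boolean accept mask for the back-and-forth proposal x -> s' -> x'.

    Acceptance prob. = min(1, R(x') p_B(tau'|x') p_F(tau)
                              / (R(x) p_B(tau|x) p_F(tau'))).
    """
    log_accept_ratio = (
        log_reward_prime + log_pb_tau_prime + log_pf_tau
        - log_reward - log_pb_tau - log_pf_tau_prime
    )
    # log U < log ratio  <=>  U < exp(log ratio), i.e. accept with
    # probability min(1, exp(log_accept_ratio)), staying in log-space.
    return torch.rand_like(log_accept_ratio).log() < log_accept_ratio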
if use_metropolis_hastings:
    ### FIXME: I realize that the trajectory needs to be reverted to get the forward probability.
    ### TODO: Resolve the issue first: https://github.com/GFNOrg/torchgfn/issues/109
The backward trajectory (e.g., [2, 1] -> [1, 1] -> [1, 0] -> [0, 0] -> [0, 0] (dummy) -> ...) is quite different from the forward trajectory (e.g., [0, 0] -> [1, 0] -> [1, 1] -> [2, 1] -> [-1, -1] (terminal) -> [-1, -1] -> ...).
Thus, we cannot compute the pf of a backward trajectory in the same way as TrajectoryBasedGFlowNet, which computes the pf of a forward trajectory.
I believe there must be some other way to resolve this, but I guess the simplest fix would be to revert the trajectory direction; see the related issue #109.
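For intuition only, here is a minimal, library-agnostic sketch of what reverting the backward trajectory above would mean; plain tuples stand in for the torchgfn state containers, and the helper is hypothetical, not an existing torchgfn function.

def revert_backward_trajectory(backward_states):
    """Turn a backward state sequence x -> ... -> s0 (possibly padded with
    repeated dummy states at the end) into the forward sequence s0 -> ... -> x,
    to which p_F can then be applied step by step.
    """
    trimmed = list(backward_states)
    # Drop trailing dummy padding that repeats the last state.
    while len(trimmed) > 1 and trimmed[-1] == trimmed[-2]:
        trimmed.pop()
    return list(reversed(trimmed))

print(revert_backward_trajectory([(2, 1), (1, 1), (1, 0), (0, 0), (0, 0)]))
# -> [(0, 0), (1, 0), (1, 1), (2, 1)]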
# Calculate the forward probability if needed (metropolis-hastings).
### COPIED FROM `TrajectoryBasedGFlowNet.get_pfs_and_pbs` ###
if use_metropolis_hastings:
    ### FIXME: I realize that the trajectory needs to be reverted to get the forward probability.
    ### TODO: Resolve the issue first: https://github.com/GFNOrg/torchgfn/issues/109
    valid_states = backward_trajectories.states[
        ~backward_trajectories.states.is_sink_state
    ]
    valid_actions = backward_trajectories.actions[
        ~backward_trajectories.actions.is_dummy
    ]

    if backward_trajectories.conditioning is not None:
        cond_dim = (-1,) * len(backward_trajectories.conditioning.shape)
        traj_len = backward_trajectories.states.tensor.shape[0]
        masked_cond = backward_trajectories.conditioning.unsqueeze(0).expand(
            (traj_len,) + cond_dim
        )[~backward_trajectories.states.is_sink_state]

        # Here, we pass all valid states, i.e., non-sink states.
        with has_conditioning_exception_handler("pf", self.estimator):
            estimator_outputs = self.estimator(valid_states, masked_cond)
    else:
        # Here, we pass all valid states, i.e., non-sink states.
        with no_conditioning_exception_handler("pf", self.estimator):
            estimator_outputs = self.estimator(valid_states)

    # Calculates the log PF of the actions sampled off policy.
    valid_log_pf_actions = self.estimator.to_probability_distribution(
        valid_states, estimator_outputs
    ).log_prob(
        valid_actions.tensor
    )  # Using the actions sampled off-policy.
    log_pf_backward_trajectories = torch.full_like(
        backward_trajectories.actions.tensor[..., 0],
        fill_value=0.0,
        dtype=torch.float,
    )
    log_pf_backward_trajectories[
        ~backward_trajectories.actions.is_dummy
    ] = valid_log_pf_actions
This part is copied from TrajectoryBasedGFlowNet.get_pfs_and_pbs.
I couldn't find any other way to calculate the log_pfs or log_pbs, but this is too ugly and introduces too much redundancy.
We need more modular functionality to minimize the duplication.
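As a rough sketch of the kind of shared helper this is asking for: the name and signature below are hypothetical, the conditioning branch and the exception-handler contexts are omitted for brevity, and it only reuses attributes that already appear in the copied blocks.

import torch

def trajectory_action_logprobs(trajectories, estimator, is_backward=False):
    """Per-step log-probs of the actions in `trajectories` under `estimator`.

    Sketch of a shared code path: with is_backward=False this mirrors the PF
    computation above; with is_backward=True it mirrors the PB one.
    """
    valid_states = trajectories.states[~trajectories.states.is_sink_state]
    valid_actions = trajectories.actions[~trajectories.actions.is_dummy]

    if is_backward:
        # PB is undefined at the initial state and for exit actions.
        eval_states = valid_states[~valid_states.is_initial_state]
        eval_actions = valid_actions[~valid_actions.is_exit]
    else:
        eval_states, eval_actions = valid_states, valid_actions

    estimator_outputs = estimator(eval_states)
    eval_log_probs = estimator.to_probability_distribution(
        eval_states, estimator_outputs
    ).log_prob(eval_actions.tensor)

    if is_backward:
        # Re-insert zeros at the exit steps before scattering back.
        log_probs_slice = torch.full_like(
            valid_actions.tensor[..., 0], fill_value=0.0, dtype=torch.float
        )
        log_probs_slice[~valid_actions.is_exit] = eval_log_probs
    else:
        log_probs_slice = eval_log_probs

    # Zeros at dummy steps, the computed log-probs everywhere else.
    log_probs = torch.full_like(
        trajectories.actions.tensor[..., 0], fill_value=0.0, dtype=torch.float
    )
    log_probs[~trajectories.actions.is_dummy] = log_probs_slice
    return log_probs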
### COPIED FROM `TrajectoryBasedGFlowNet.get_pfs_and_pbs` ###
if use_metropolis_hastings:
    valid_states = new_trajectories.states[
        ~new_trajectories.states.is_sink_state
    ]
    valid_actions = new_trajectories.actions[~new_trajectories.actions.is_dummy]

    non_initial_valid_states = valid_states[~valid_states.is_initial_state]
    non_exit_valid_actions = valid_actions[~valid_actions.is_exit]

    # Using all non-initial states, calculate the backward policy, and the logprobs
    # of those actions.
    if new_trajectories.conditioning is not None:
        # We need to index the conditioning vector to broadcast over the states.
        cond_dim = (-1,) * len(new_trajectories.conditioning.shape)
        traj_len = new_trajectories.states.tensor.shape[0]
        masked_cond = new_trajectories.conditioning.unsqueeze(0).expand(
            (traj_len,) + cond_dim
        )[~new_trajectories.states.is_sink_state][
            ~valid_states.is_initial_state
        ]

        # Pass all valid states, i.e., non-sink states, except the initial state.
        with has_conditioning_exception_handler(
            "pb", self.backward_sampler.estimator
        ):
            estimator_outputs = self.backward_sampler.estimator(
                non_initial_valid_states, masked_cond
            )
    else:
        # Pass all valid states, i.e., non-sink states, except the initial state.
        with no_conditioning_exception_handler(
            "pb", self.backward_sampler.estimator
        ):
            estimator_outputs = self.backward_sampler.estimator(
                non_initial_valid_states
            )

    valid_log_pb_actions = (
        self.backward_sampler.estimator.to_probability_distribution(
            non_initial_valid_states, estimator_outputs
        ).log_prob(non_exit_valid_actions.tensor)
    )

    log_pb_new_trajectories = torch.full_like(
        new_trajectories.actions.tensor[..., 0],
        fill_value=0.0,
        dtype=torch.float,
    )
    log_pb_new_trajectories_slice = torch.full_like(
        valid_actions.tensor[..., 0], fill_value=0.0, dtype=torch.float
    )
    log_pb_new_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
    log_pb_new_trajectories[
        ~new_trajectories.actions.is_dummy
    ] = log_pb_new_trajectories_slice
Ditto: same redundancy concern as for the copied block above.
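With a shared helper along the lines of the sketch above, both copied blocks could collapse to something like the following (still hypothetical, conditioning omitted):

log_pf_backward_trajectories = trajectory_action_logprobs(
    backward_trajectories, self.estimator, is_backward=False
)
log_pb_new_trajectories = trajectory_action_logprobs(
    new_trajectories, self.backward_sampler.estimator, is_backward=True
)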
This is a draft PR for adding LocalSearchSampler. The local search is based on the works [1] and [2].
Test in hypergrid env: