
[Draft] Add local search sampler #208

Draft · wants to merge 4 commits into master from hyeok9855/local-search

Conversation

hyeok9855
Collaborator

@hyeok9855 hyeok9855 commented Oct 29, 2024

This is a draft PR adding LocalSearchSampler.
The local search is based on [1] and [2].

Test in hypergrid env:

python tutorials/examples/train_hypergrid_simple_ls.py --no_cuda

@hyeok9855 hyeok9855 self-assigned this Oct 29, 2024
@hyeok9855 hyeok9855 marked this pull request as draft October 29, 2024 11:03
@hyeok9855 hyeok9855 force-pushed the hyeok9855/local-search branch 2 times, most recently from c2d59b3 to c6b9f64 Compare October 29, 2024 17:15
@josephdviviano
Collaborator

@hyeok9855 can you fix the merge conflicts?

@josephdviviano
Collaborator

I noticed you are using force-push. Be careful with this: it can put the branch in a state that is hard to reconcile with the rest of the work.

https://www.gitkraken.com/learn/git/problems/git-push-force

@hyeok9855
Collaborator Author

I fixed an issue in the backward mask!

Is there anything else that needs to be done next?

Comment on lines +488 to +497
# TODO: Implement the Metropolis-Hastings acceptance criterion.
# p(x->s'->x') = p_B(x->s')p_F(s'->x')
# p(x'->s'->x) = p_B(x'->s')p_F(s'->x)
# The acceptance ratio is
# min(1, R(x')p(x->s'->x') / R(x)p(x'->s'->x))
# Note that
# p(x->s'->x') / p(x'->s'->x)
# = p_B(x->s')p_F(s'->x') / p_B(x'->s')p_F(s'->x)
# = p_B(x->s'->s0)p_F(s0->s'->x') / p_B(x'->s'->s0)p_F(s0->s'->x)
# = p_B(tau|x)p_F(tau') / p_B(tau'|x')p_F(tau)
Collaborator Author


This is how the Metropolis-Hastings acceptance works.
As the last line shows, we need both the log_pfs and log_pbs of each trajectory.
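For reference, the acceptance test sketched in the comment can be written in log space roughly as follows (a minimal pure-Python sketch; `mh_accept` and its argument names are hypothetical, and the real implementation would operate on the batched trajectory tensors):

```python
import math
import random

def mh_accept(log_r_old, log_r_new,
              log_pf_old, log_pb_old,
              log_pf_new, log_pb_new,
              rng=random.random):
    """Metropolis-Hastings acceptance for the back-and-forth proposal.

    Accept with probability
        min(1, R(x') p_B(tau|x) p_F(tau') / (R(x) p_B(tau'|x') p_F(tau))),
    where tau ends at x and tau' ends at x'. Computed in log space
    for numerical stability.
    """
    log_ratio = (log_r_new + log_pb_old + log_pf_new) \
              - (log_r_old + log_pb_new + log_pf_old)
    # Draw u ~ U(0, 1) and accept iff log(u) < min(0, log_ratio).
    return math.log(rng() + 1e-300) < min(0.0, log_ratio)
```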

Comment on lines +326 to +328
if use_metropolis_hastings:
### FIXME: I realize that the trajectory needs to be reverted to get the forward probability.
### TODO: Resolve the issue first https://github.com/GFNOrg/torchgfn/issues/109
Collaborator Author


The backward trajectory (e.g., [2, 1] -> [1, 1] -> [1, 0] -> [0, 0] -> [0, 0] (dummy) -> ...) is laid out quite differently from the forward trajectory (e.g., [0, 0] -> [1, 0] -> [1, 1] -> [2, 1] -> [-1, -1] (terminal) -> [-1, -1] -> ...).

Thus, we cannot compute pf for a backward trajectory the same way TrajectoryBasedGFlowNet calculates pf for a forward trajectory.

There may be other ways to resolve this, but I think the simplest fix would be to revert the trajectory direction; see the related issue #109.
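To make the layout mismatch concrete, here is a toy sketch of the reversal on plain tuples (the helper name and padding conventions are illustrative only; the real Trajectories container holds batched tensors):

```python
def revert_backward_trajectory(bwd_states, s0=(0, 0), sink=(-1, -1)):
    """Turn a padded backward trajectory (x -> ... -> s0, then dummy s0
    padding) into forward order (s0 -> ... -> x -> sink, then sink padding).

    Toy version on tuples, only sketching the intended semantics.
    """
    # Drop the trailing dummy padding: keep states up to the first s0.
    core = []
    for s in bwd_states:
        core.append(s)
        if s == s0:
            break
    fwd = list(reversed(core))  # s0 -> ... -> x
    fwd.append(sink)            # terminal sink marker
    # Re-pad with sink states to one more than the original length.
    fwd += [sink] * (len(bwd_states) + 1 - len(fwd))
    return fwd
```

Applied to the backward trajectory from the example above, this yields exactly the forward layout [0, 0] -> [1, 0] -> [1, 1] -> [2, 1] -> [-1, -1] -> [-1, -1].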

Comment on lines +324 to +364
# Calculate the forward probability if needed (metropolis-hastings).
### COPIED FROM `TrajectoryBasedGFlowNet.get_pfs_and_pbs` ###
if use_metropolis_hastings:
    ### FIXME: I realize that the trajectory needs to be reverted to get the forward probability.
    ### TODO: Resolve the issue first https://github.com/GFNOrg/torchgfn/issues/109
    valid_states = backward_trajectories.states[
        ~backward_trajectories.states.is_sink_state
    ]
    valid_actions = backward_trajectories.actions[
        ~backward_trajectories.actions.is_dummy
    ]

    if backward_trajectories.conditioning is not None:
        cond_dim = (-1,) * len(backward_trajectories.conditioning.shape)
        traj_len = backward_trajectories.states.tensor.shape[0]
        masked_cond = backward_trajectories.conditioning.unsqueeze(0).expand(
            (traj_len,) + cond_dim
        )[~backward_trajectories.states.is_sink_state]

        # Here, we pass all valid states, i.e., non-sink states.
        with has_conditioning_exception_handler("pf", self.estimator):
            estimator_outputs = self.estimator(valid_states, masked_cond)
    else:
        # Here, we pass all valid states, i.e., non-sink states.
        with no_conditioning_exception_handler("pf", self.estimator):
            estimator_outputs = self.estimator(valid_states)

    # Calculates the log PF of the actions sampled off policy.
    valid_log_pf_actions = self.estimator.to_probability_distribution(
        valid_states, estimator_outputs
    ).log_prob(
        valid_actions.tensor
    )  # Using the actions sampled off-policy.
    log_pf_backward_trajectories = torch.full_like(
        backward_trajectories.actions.tensor[..., 0],
        fill_value=0.0,
        dtype=torch.float,
    )
    log_pf_backward_trajectories[
        ~backward_trajectories.actions.is_dummy
    ] = valid_log_pf_actions
Collaborator Author


This part is copied from TrajectoryBasedGFlowNet.get_pfs_and_pbs.

I couldn't find another way to calculate the log_pfs or log_pbs, but this is ugly and highly redundant.

We need more modular functionality to minimize the redundancy.
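One possible shape for such shared functionality: the pattern both copies repeat is "compute log-probs over the valid (non-dummy) steps, then scatter them back into the padded trajectory layout". A plain-Python sketch of that core (the helper name is hypothetical; the real version would work on batched torch tensors and boolean masks):

```python
def scatter_valid_log_probs(is_valid, valid_log_probs, fill=0.0):
    """Scatter per-step log-probs, computed only over valid steps,
    back into the full padded layout, filling dummy slots with `fill`.

    `is_valid` plays the role of masks like ~actions.is_dummy;
    `valid_log_probs` is the flat result of estimator.log_prob(...).
    """
    out = [fill] * len(is_valid)
    it = iter(valid_log_probs)
    for i, valid in enumerate(is_valid):
        if valid:
            out[i] = next(it)
    return out
```

Factoring this (plus the conditioning-aware estimator call) into one shared helper would let both get_pfs_and_pbs and the local search sampler call it instead of duplicating the block.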

Comment on lines +431 to +486
### COPIED FROM `TrajectoryBasedGFlowNet.get_pfs_and_pbs` ###
if use_metropolis_hastings:
    valid_states = new_trajectories.states[
        ~new_trajectories.states.is_sink_state
    ]
    valid_actions = new_trajectories.actions[~new_trajectories.actions.is_dummy]

    non_initial_valid_states = valid_states[~valid_states.is_initial_state]
    non_exit_valid_actions = valid_actions[~valid_actions.is_exit]

    # Using all non-initial states, calculate the backward policy, and the
    # logprobs of those actions.
    if new_trajectories.conditioning is not None:
        # We need to index the conditioning vector to broadcast over the states.
        cond_dim = (-1,) * len(new_trajectories.conditioning.shape)
        traj_len = new_trajectories.states.tensor.shape[0]
        masked_cond = new_trajectories.conditioning.unsqueeze(0).expand(
            (traj_len,) + cond_dim
        )[~new_trajectories.states.is_sink_state][
            ~valid_states.is_initial_state
        ]

        # Pass all valid states, i.e., non-sink states, except the initial state.
        with has_conditioning_exception_handler(
            "pb", self.backward_sampler.estimator
        ):
            estimator_outputs = self.backward_sampler.estimator(
                non_initial_valid_states, masked_cond
            )
    else:
        # Pass all valid states, i.e., non-sink states, except the initial state.
        with no_conditioning_exception_handler(
            "pb", self.backward_sampler.estimator
        ):
            estimator_outputs = self.backward_sampler.estimator(
                non_initial_valid_states
            )

    valid_log_pb_actions = (
        self.backward_sampler.estimator.to_probability_distribution(
            non_initial_valid_states, estimator_outputs
        ).log_prob(non_exit_valid_actions.tensor)
    )

    log_pb_new_trajectories = torch.full_like(
        new_trajectories.actions.tensor[..., 0],
        fill_value=0.0,
        dtype=torch.float,
    )
    log_pb_new_trajectories_slice = torch.full_like(
        valid_actions.tensor[..., 0], fill_value=0.0, dtype=torch.float
    )
    log_pb_new_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
    log_pb_new_trajectories[
        ~new_trajectories.actions.is_dummy
    ] = log_pb_new_trajectories_slice
Collaborator Author


Ditto: the same copied block and the same redundancy concern as above.
