[BC] Update explore_at_state_generator with new exploration strategy (#389)

Updating the exploration strategy as discussed with @alekh.
The strategy works as follows: at the current exploration state, the
exploration policy returns an array of logits. The logit of the action
currently taken is set to -np.Inf, and an exploration action is sampled
from the remaining available actions. This change required updating the
tests for ModuleExplorerWorker. Additionally, a module_explorer_type
argument was added to ModuleWorker so the exploration strategy can be
overridden by overriding the ModuleExplorer class.
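
In rough outline, the masking-and-sampling step described above looks like the sketch below. This is only an illustration: the helper name pick_exploration_action is made up here and is not part of the change.

```python
from typing import Optional

import numpy as np
import scipy.special


def pick_exploration_action(distr_logits: np.ndarray,
                            taken_action: int) -> Optional[int]:
  """Sketch: mask the taken action's logit and sample from the rest."""
  distr_logits = np.array(distr_logits, dtype=np.float64)
  # Rule out the action the trajectory already took at this step.
  distr_logits[taken_action] = -np.inf
  # If every action is now masked out, there is nothing left to explore.
  if np.all(np.isneginf(distr_logits)):
    return None
  # Sample an alternative action from the softmax over the remaining logits.
  return int(np.random.choice(
      distr_logits.shape[0], p=scipy.special.softmax(distr_logits)))
```

For example, pick_exploration_action(np.array([2.0, 0.5, -1.0]), taken_action=0) can only return 1 or 2, never the already-taken action 0.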
tvmarino authored Nov 18, 2024
1 parent 8d152bc commit 07405ed
Showing 2 changed files with 40 additions and 22 deletions.
45 changes: 30 additions & 15 deletions compiler_opt/rl/generate_bc_trajectories.py
@@ -32,6 +32,8 @@
 
 import math
 import numpy as np
+import scipy
+import scipy.special
 import tensorflow as tf
 from tf_agents import policies
 from tf_agents.typing import types as tf_types
@@ -502,10 +504,13 @@ def explore_function(
     return seq_example_list, working_dir_names, loss_idx, base_seq_loss
 
   def explore_at_state_generator(
-      self, replay_prefix: List[np.ndarray], explore_step: int,
+      self,
+      replay_prefix: List[np.ndarray],
+      explore_step: int,
       explore_state: time_step.TimeStep,
       policy: Callable[[Optional[time_step.TimeStep]], np.ndarray],
-      explore_policy: Callable[[time_step.TimeStep], policy_step.PolicyStep]
+      explore_policy: Callable[[time_step.TimeStep], policy_step.PolicyStep],
+      num_samples: int = 1,
   ) -> Generator[Tuple[tf.train.SequenceExample, ExplorationWithPolicy], None,
                  None]:
     """Generate sequence examples and next exploration policy while exploring.
@@ -522,21 +527,28 @@ def explore_at_state_generator(
       explore_policy: randomized policy which is used to compute the gap for
         exploration and can be used for deciding which actions to explore at
         the exploration state.
+      num_samples: the number of samples to generate
     Yields:
       base_seq: a tf.train.SequenceExample containing a compiled trajectory
       base_policy: the policy used to determine the next exploration step
     """
-    del explore_state
-    replay_prefix[explore_step] = 1 - replay_prefix[explore_step]
-    base_policy = ExplorationWithPolicy(
-        replay_prefix,
-        policy,
-        explore_policy,
-        self._explore_on_features,
-    )
-    base_seq = self.compile_module(base_policy.get_advice)
-    yield base_seq, base_policy
+    distr_logits = explore_policy(explore_state).action.logits.numpy()[0]
+    for _ in range(num_samples):
+      distr_logits[replay_prefix[explore_step]] = -np.Inf
+      if all(-np.Inf == logit for logit in distr_logits):
+        break
+      replay_prefix[explore_step] = np.random.choice(
+          range(distr_logits.shape[0]), p=scipy.special.softmax(distr_logits))
+      base_policy = ExplorationWithPolicy(
+          replay_prefix,
+          policy,
+          explore_policy,
+          self._explore_on_features,
+      )
+      base_seq = self.compile_module(base_policy.get_advice)
+      yield base_seq, base_policy
 
   def _build_replay_prefix_list(self, seq_ex):
     ret_list = []
@@ -703,6 +715,7 @@ class ModuleWorker(worker.Worker):
   by the exploration policy if given.
   Attributes:
+    module_explorer_type: type of the module explorer
     clang_path: path to clang
     mlgo_task_type: the type of compilation task
     policy_paths: list of policies to load and use for forming the trajectories
@@ -718,13 +731,14 @@ class ModuleWorker(worker.Worker):
    obs_action_specs: optional observation spec annotating TimeStep
    base_path: root path to save best compiled binaries for linking
    partitions: a tuple of limits defining the buckets, see partition_for_loss
-    env_args: additional arguments to pass to the InliningTask, used in creating
-      the environment. This has to include the reward_key
+    env_args: additional arguments to pass to the ModuleExplorer, used in
+      creating the environment. This has to include the reward_key
   """
 
   def __init__(
       # pylint: disable=dangerous-default-value
       self,
+      module_explorer_type: Type[ModuleExplorer] = ModuleExplorer,
      clang_path: str = gin.REQUIRED,
      mlgo_task_type: Type[env.MLGOTask] = gin.REQUIRED,
      policy_paths: List[Optional[str]] = [],
@@ -746,6 +760,7 @@ def __init__(
       raise AssertionError("""At least one policy needs to be specified in
           policy paths or callable_policies""")
     logging.info('Environment args: %s', envargs)
+    self._module_explorer_type: Type[ModuleExplorer] = module_explorer_type
     self._clang_path: str = clang_path
     self._mlgo_task_type: Type[env.MLGOTask] = mlgo_task_type
     self._policy_paths: List[Optional[str]] = policy_paths
@@ -807,7 +822,7 @@ def select_best_exploration(
     logging.info('Processing module: %s', loaded_module_spec.name)
     start = timeit.default_timer()
     work = list(zip(self._tf_policy_action, self._exploration_policy_distrs))
-    exploration_worker = ModuleExplorer(
+    exploration_worker = self._module_explorer_type(
        loaded_module_spec=loaded_module_spec,
        clang_path=self._clang_path,
        mlgo_task_type=self._mlgo_task_type,
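Because ModuleWorker now instantiates self._module_explorer_type instead of ModuleExplorer directly (see the last hunk above), the exploration strategy can be swapped out by subclassing. The sketch below is hypothetical: the subclass GreedyModuleExplorer does not exist in the repository, and the import path is assumed from the file location.

```python
import numpy as np

from compiler_opt.rl import generate_bc_trajectories


class GreedyModuleExplorer(generate_bc_trajectories.ModuleExplorer):
  """Hypothetical explorer that always takes the best remaining action."""

  def explore_at_state_generator(self, replay_prefix, explore_step,
                                 explore_state, policy, explore_policy,
                                 num_samples=1):
    del num_samples  # a single deterministic alternative is enough here
    distr_logits = explore_policy(explore_state).action.logits.numpy()[0]
    # Mask the action that was taken, then pick the best remaining one.
    distr_logits[replay_prefix[explore_step]] = -np.inf
    replay_prefix[explore_step] = int(np.argmax(distr_logits))
    base_policy = generate_bc_trajectories.ExplorationWithPolicy(
        replay_prefix, policy, explore_policy, self._explore_on_features)
    yield self.compile_module(base_policy.get_advice), base_policy
```

The worker would then be constructed with module_explorer_type=GreedyModuleExplorer alongside the usual clang_path, mlgo_task_type, and policy arguments, and select_best_exploration picks up the custom explorer without further changes.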
17 changes: 10 additions & 7 deletions compiler_opt/rl/generate_bc_trajectories_test.py
@@ -385,7 +385,7 @@ def _get_seq_example_list_comp(self):
       generate_bc_trajectories.add_float_feature(seq_example_comp, 47.0,
                                                  'reward')
       if i == 4:
-        generate_bc_trajectories.add_int_feature(seq_example_comp, -3, 'action')
+        generate_bc_trajectories.add_int_feature(seq_example_comp, 0, 'action')
       else:
         generate_bc_trajectories.add_int_feature(seq_example_comp, np.mod(i, 5),
                                                  'action')
@@ -401,9 +401,9 @@ def _get_seq_example_list_comp(self):
       generate_bc_trajectories.add_float_feature(seq_example_comp, 47.0,
                                                  'reward')
       if i == 4:
-        generate_bc_trajectories.add_int_feature(seq_example_comp, -3, 'action')
+        generate_bc_trajectories.add_int_feature(seq_example_comp, 0, 'action')
       elif i == 5:
-        generate_bc_trajectories.add_int_feature(seq_example_comp, 1, 'action')
+        generate_bc_trajectories.add_int_feature(seq_example_comp, 4, 'action')
       else:
         generate_bc_trajectories.add_int_feature(seq_example_comp, np.mod(i, 5),
                                                  'action')
@@ -419,11 +419,11 @@ def _get_seq_example_list_comp(self):
       generate_bc_trajectories.add_float_feature(seq_example_comp, 47.0,
                                                  'reward')
       if i == 4:
-        generate_bc_trajectories.add_int_feature(seq_example_comp, -3, 'action')
+        generate_bc_trajectories.add_int_feature(seq_example_comp, 0, 'action')
       elif i == 5:
-        generate_bc_trajectories.add_int_feature(seq_example_comp, 1, 'action')
+        generate_bc_trajectories.add_int_feature(seq_example_comp, 4, 'action')
       elif i == 9:
-        generate_bc_trajectories.add_int_feature(seq_example_comp, -3, 'action')
+        generate_bc_trajectories.add_int_feature(seq_example_comp, 0, 'action')
       else:
         generate_bc_trajectories.add_int_feature(seq_example_comp, np.mod(i, 5),
                                                  'action')
Expand All @@ -450,7 +450,10 @@ def _explore_policy(state: time_step.TimeStep):
# will explore every 4-th step
logits = [[
4.0 + 1e-3 * float(env_test._NUM_STEPS - times_called),
float(np.mod(times_called, 5))
-np.Inf,
-np.Inf,
-np.Inf,
float(np.mod(times_called, 5)),
]]
return policy_step.PolicyStep(
action=tfp.distributions.Categorical(logits=logits))
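The reworked _explore_policy logits above also make the sampled exploration action deterministic in the test: three of the five logits are already -np.Inf, so once the logit of the action the trajectory originally took (index 0, the largest) is masked out, the softmax places all probability mass on the single remaining finite entry (index 4). A quick illustrative check:

```python
import numpy as np
import scipy.special

logits = np.array([4.0, -np.inf, -np.inf, -np.inf, 3.0])
logits[0] = -np.inf  # mask the action the trajectory originally took
print(scipy.special.softmax(logits))  # [0. 0. 0. 0. 1.] -> action 4 every time
```

This is also why the generator breaks out of the sampling loop once every logit is -np.Inf: the softmax of an all-(-inf) vector is undefined, so there is nothing valid left to sample.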
