Skip to content

Commit

Permalink
Revert past six commits
Browse files Browse the repository at this point in the history
This patch reverts the past six commits as they were accidentally pushed
to main.
  • Loading branch information
boomanaiden154 committed Nov 21, 2024
1 parent 31aaf6b commit f6ce59a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 57 deletions.
33 changes: 21 additions & 12 deletions compiler_opt/rl/generate_bc_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

import concurrent.futures
import contextlib
import functools
import gin
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Generator, Union
import json

from absl import app
from absl import flags
from absl import logging
import bisect
Expand All @@ -38,13 +40,15 @@
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step
from tf_agents.specs import tensor_spec
from tf_agents.system import system_multiprocessing as multiprocessing

from compiler_opt.rl import corpus
from compiler_opt.rl import env

from compiler_opt.distributed import worker
from compiler_opt.distributed import buffered_scheduler
from compiler_opt.distributed.local import local_worker_manager

from compiler_opt.tools import generate_test_model # pylint:disable=unused-import

flags.FLAGS['gin_files'].allow_override = True
Expand Down Expand Up @@ -339,18 +343,11 @@ def __init__(
kwargs.pop('reward_key', None)
self._working_dir = None

class MlgoTaskWrapper(mlgo_task_type):
# TODO(391): mlgo_task_type is not gin configurable at env.py
# since it is spawned in a separate process from the main thread
def __init__(self):
super().__init__(**kwargs)

self._env = env.MLGOEnvironmentBase(
clang_path=clang_path,
task_type=MlgoTaskWrapper,
task_type=mlgo_task_type,
obs_spec=obs_spec,
action_spec=action_spec,
interactive_only=True,
)
if self._env.action_spec:
if self._env.action_spec.dtype != tf.int64:
Expand Down Expand Up @@ -531,7 +528,6 @@ def explore_at_state_generator(
exploration and can be used for deciding which actions to explore at
the exploration state.
num_samples: the number of samples to generate
num_samples: the number of samples to generate
Yields:
base_seq: a tf.train.SequenceExample containing a compiled trajectory
Expand Down Expand Up @@ -729,6 +725,7 @@ class ModuleWorker(worker.Worker):
exploration_frac: how often to explore in a trajectory
max_exploration_steps: maximum number of exploration steps
tf_policy_action: list of the action/advice function from loaded policies
exploration_policy_paths: paths to load exploration policies.
explore_on_features: dict of feature names and functions which specify
when to explore on the respective feature
obs_action_specs: optional observation spec annotating TimeStep
Expand All @@ -745,7 +742,7 @@ def __init__(
clang_path: str = gin.REQUIRED,
mlgo_task_type: Type[env.MLGOTask] = gin.REQUIRED,
policy_paths: List[Optional[str]] = [],
exploration_frac: float = 1.0,
exploration_frac: float = gin.REQUIRED,
max_exploration_steps: int = 7,
callable_policies: List[Optional[Callable[[Any], np.ndarray]]] = [],
exploration_policy_paths: Optional[str] = None,
Expand Down Expand Up @@ -872,7 +869,7 @@ def select_best_exploration(
def gen_trajectories(
# pylint: disable=dangerous-default-value
data_path: str = gin.REQUIRED,
delete_flags: Tuple[str, ...] = ('',),
delete_flags: Tuple[str, ...] = gin.REQUIRED,
output_file_name: str = gin.REQUIRED,
output_path: str = gin.REQUIRED,
mlgo_task_type: Type[env.MLGOTask] = gin.REQUIRED,
Expand All @@ -883,7 +880,7 @@ def gen_trajectories(
num_output_files: int = 1,
profiling_file_path: Optional[str] = None,
worker_wait_sec: int = 100,
worker_class_type: Type[ModuleWorker] = ModuleWorker,
worker_class_type=ModuleWorker,
worker_manager_class=local_worker_manager.LocalWorkerPoolManager,
):
"""Generates all trajectories for imitation learning training.
Expand Down Expand Up @@ -1016,3 +1013,15 @@ def gen_trajectories(
modules_processed,
time_compiler_calls,
)


def main(_):
gin.parse_config_files_and_bindings(
FLAGS.gin_files, bindings=FLAGS.gin_bindings, skip_unknown=True)
logging.info(gin.config_str())

gen_trajectories()


if __name__ == '__main__':
multiprocessing.handle_main(functools.partial(app.run, main))
43 changes: 0 additions & 43 deletions compiler_opt/rl/generate_bc_trajectories_main.py

This file was deleted.

4 changes: 2 additions & 2 deletions compiler_opt/rl/inlining/gin_configs/common.gin
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ import compiler_opt.rl.inlining.env
config_registry.get_configuration.implementation=@configs.InliningConfig

launcher_path=None
llvm_size_path='/usr/local/google/home/tvmarinov/Documents/mlgo_compiler_opt/inlining/runfolder.rundir/llvm-size'
clang_path='/usr/local/google/home/tvmarinov/Documents/mlgo_compiler_opt/inlining/chrome_on_android/chromium2/src/third_party/llvm-build/tflite_build_cold/bin/clang'
llvm_size_path=None
clang_path=None

runners.InliningRunner.llvm_size_path=%llvm_size_path
runners.InliningRunner.clang_path=%clang_path
Expand Down

0 comments on commit f6ce59a

Please sign in to comment.