Skip to content

Commit

Permalink
Handle gin.bind_parameter in tests (#395)
Browse files Browse the repository at this point in the history
If a test needs to `gin.bind_parameter` it should also clear that, or
make the binding within a `gin.config_scope`, otherwise the bindings may
leak to other tests if all the tests are run by e.g. pytest.

This patch uses both approaches: for tests that set up bindings in `setUp`
we clear in `tearDown`, else we use the scoped approach.

Issue #394
  • Loading branch information
mtrofin authored Dec 5, 2024
1 parent 82dc72a commit 14edc8a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 30 deletions.
13 changes: 7 additions & 6 deletions compiler_opt/distributed/worker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def __init__(self, argument):
class WorkerTest(absltest.TestCase):

def test_gin_args(self):
with gin.unlock_config():
gin.bind_parameter('_test.SomeType.argument', 42)
real_args = worker.get_full_worker_args(
SomeType, more_args=2, even_more_args='hi')
self.assertDictEqual(real_args,
dict(argument=42, more_args=2, even_more_args='hi'))
with gin.config_scope('worker_test'):
with gin.unlock_config():
gin.bind_parameter('_test.SomeType.argument', 42)
real_args = worker.get_full_worker_args(
SomeType, more_args=2, even_more_args='hi')
self.assertDictEqual(real_args,
dict(argument=42, more_args=2, even_more_args='hi'))


if __name__ == '__main__':
Expand Down
3 changes: 3 additions & 0 deletions compiler_opt/es/blackbox_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
class BlackboxLearnerTests(absltest.TestCase):
"""Tests for blackbox_learner"""

def tearDown(self):
gin.clear_config()

def setUp(self):
super().setUp()

Expand Down
56 changes: 32 additions & 24 deletions compiler_opt/rl/agent_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,42 @@ def setUp(self):
super().setUp()

def test_create_behavioral_cloning_agent(self):
gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
gin.bind_parameter('BehavioralCloningAgent.optimizer',
tf.compat.v1.train.AdamOptimizer())
tf_agent = agent_config.create_agent(
agent_config.BCAgentConfig(
time_step_spec=self._time_step_spec, action_spec=self._action_spec),
preprocessing_layer_creator=_observation_processing_layer)
self.assertIsInstance(tf_agent,
behavioral_cloning_agent.BehavioralCloningAgent)
with gin.config_scope('test_create_behavioral_cloning_agent'):
gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
gin.bind_parameter('BehavioralCloningAgent.optimizer',
tf.compat.v1.train.AdamOptimizer())
tf_agent = agent_config.create_agent(
agent_config.BCAgentConfig(
time_step_spec=self._time_step_spec,
action_spec=self._action_spec),
preprocessing_layer_creator=_observation_processing_layer)
self.assertIsInstance(tf_agent,
behavioral_cloning_agent.BehavioralCloningAgent)

def test_create_dqn_agent(self):
gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
gin.bind_parameter('DqnAgent.optimizer', tf.compat.v1.train.AdamOptimizer())
tf_agent = agent_config.create_agent(
agent_config.DQNAgentConfig(
time_step_spec=self._time_step_spec, action_spec=self._action_spec),
preprocessing_layer_creator=_observation_processing_layer)
self.assertIsInstance(tf_agent, dqn_agent.DqnAgent)
with gin.config_scope('test_create_dqn_agent'):
gin.bind_parameter('create_agent.policy_network', q_network.QNetwork)
gin.bind_parameter('DqnAgent.optimizer',
tf.compat.v1.train.AdamOptimizer())
tf_agent = agent_config.create_agent(
agent_config.DQNAgentConfig(
time_step_spec=self._time_step_spec,
action_spec=self._action_spec),
preprocessing_layer_creator=_observation_processing_layer)
self.assertIsInstance(tf_agent, dqn_agent.DqnAgent)

def test_create_ppo_agent(self):
gin.bind_parameter('create_agent.policy_network',
actor_distribution_network.ActorDistributionNetwork)
gin.bind_parameter('PPOAgent.optimizer', tf.compat.v1.train.AdamOptimizer())
tf_agent = agent_config.create_agent(
agent_config.PPOAgentConfig(
time_step_spec=self._time_step_spec, action_spec=self._action_spec),
preprocessing_layer_creator=_observation_processing_layer)
self.assertIsInstance(tf_agent, ppo_agent.PPOAgent)
with gin.config_scope('test_create_ppo_agent'):
gin.bind_parameter('create_agent.policy_network',
actor_distribution_network.ActorDistributionNetwork)
gin.bind_parameter('PPOAgent.optimizer',
tf.compat.v1.train.AdamOptimizer())
tf_agent = agent_config.create_agent(
agent_config.PPOAgentConfig(
time_step_spec=self._time_step_spec,
action_spec=self._action_spec),
preprocessing_layer_creator=_observation_processing_layer)
self.assertIsInstance(tf_agent, ppo_agent.PPOAgent)


if __name__ == '__main__':
Expand Down

0 comments on commit 14edc8a

Please sign in to comment.