Skip to content

Commit

Permalink
add missing reset_global_context()
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-petrenko committed Sep 17, 2024
1 parent 12cd0f6 commit a3479be
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ check-codestyle:
.PHONY: test

test:
pytest -s --maxfail=2
pytest -s --maxfail=2 -rA
# ; echo "Tests finished. You might need to type 'reset' and press Enter to fix the terminal window"


Expand Down
4 changes: 4 additions & 0 deletions sample_factory/algo/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def set_global_context(ctx: SampleFactoryContext):


def reset_global_context():
"""
Most useful in tests, call this after any part of the global context has been modified
by a test in any way.
"""
global GLOBAL_CONTEXT
GLOBAL_CONTEXT = SampleFactoryContext()

Expand Down
1 change: 0 additions & 1 deletion sample_factory/model/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def __init__(self):
Optional custom functions for creating parts of the model (encoders, decoders, etc.), or
even overriding the entire actor-critic with a custom model.
"""

self.make_actor_critic_func: MakeActorCriticFunc = default_make_actor_critic_func

# callables user can specify to generate parts of the policy
Expand Down
3 changes: 3 additions & 0 deletions tests/envs/pettingzoo/test_pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from sample_factory.algo.utils.context import reset_global_context
from sample_factory.algo.utils.misc import ExperimentStatus
from sample_factory.train import run_rl
from sample_factory.utils.utils import log
Expand All @@ -15,6 +16,8 @@ class TestPettingZooEnv:
@pytest.fixture(scope="class", autouse=True)
def register_pettingzoo_fixture(self):
register_custom_components()
yield # this is where the actual test happens
reset_global_context()

# noinspection PyUnusedLocal
@staticmethod
Expand Down

0 comments on commit a3479be

Please sign in to comment.