diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py index 151bba868a..ef747aff89 100644 --- a/devito/arch/compiler.py +++ b/devito/arch/compiler.py @@ -977,6 +977,8 @@ def __getitem__(self, key): return super().__getitem__(key) def __contains__(self, k): + if isinstance(k, Compiler): + k = k.name return k in self.keys() or k.startswith('gcc-') diff --git a/tests/test_arch.py b/tests/test_arch.py index 501ff6d8d8..e9af3fa9e5 100644 --- a/tests/test_arch.py +++ b/tests/test_arch.py @@ -1,6 +1,7 @@ import pytest -from devito.arch.compiler import sniff_compiler_version, compiler_registry +from devito import switchconfig, configuration +from devito.arch.compiler import sniff_compiler_version, compiler_registry, GNUCompiler @pytest.mark.parametrize("cc", [ @@ -15,3 +16,16 @@ def test_sniff_compiler_version(cc): @pytest.mark.parametrize("cc", ['gcc-4.9', 'gcc-11', 'gcc', 'gcc-14', 'gcc-123']) def test_gcc(cc): assert cc in compiler_registry + + +def test_switcharch(): + old_compiler = configuration['compiler'] + with switchconfig(compiler='gcc-4.9'): + tmp_comp = configuration['compiler'] + assert isinstance(tmp_comp, GNUCompiler) + assert tmp_comp.suffix == '4.9' + + tmp_comp = configuration['compiler'] + assert isinstance(tmp_comp, old_compiler.__class__) + assert old_compiler.suffix == tmp_comp.suffix + assert old_compiler.name == tmp_comp.name