diff --git a/megatron/arguments.py b/megatron/arguments.py
index ba1d0c9a1..5f4e2b53f 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -22,7 +22,7 @@
 import deepspeed
 
 from megatron.enums import PositionEmbeddingType
-from megatron.model.glu_activations import GLU_ACTIVATIONS
+import megatron
 
 
 def parse_args(extra_args_provider=None, defaults={},
@@ -315,7 +315,7 @@ def _add_network_size_args(parser):
                        help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.'
                        )
     group.add_argument('--glu-activation', type=str,
-                       choices=GLU_ACTIVATIONS.keys(),
+                       choices=megatron.model.glu_activations.GLU_ACTIVATIONS.keys(),
                        help='GLU activations to use.'
                        )
 
diff --git a/tests/test_activations.py b/tests/test_activations.py
index a1763d7b4..34097ad22 100644
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -43,7 +43,7 @@ def test_swiglu(self):
 
     # from megatron.testing_utils import require_torch_bf16
     # @require_torch_bf16
-    # def test_bf16_jit(self): 
+    # def test_bf16_jit(self):
     #     x_bf16 = self.x.to(torch.bfloat16)
     #     for activation_fn in GLU_ACTIVATIONS.values():
     #         output = activation_fn(x_bf16)
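
The arguments.py hunk looks like the standard recipe for breaking a Python import cycle (the commit message isn't shown here, so the motivation is inferred): the eager top-level `from megatron.model.glu_activations import GLU_ACTIVATIONS` is replaced by a bare `import megatron`, and the fully qualified attribute lookup is deferred until `_add_network_size_args` actually runs inside `parse_args`. Below is a minimal sketch of that pattern, with hypothetical module names (`pkg`, `pkg/a.py`, `pkg/b.py`) standing in for the real Megatron layout:

    # pkg/__init__.py
    from pkg import a
    from pkg import b

    # pkg/a.py  (plays the role of megatron/arguments.py)
    import pkg  # safe even mid-cycle: "pkg" is already in sys.modules

    def helper():
        return "a utility that pkg.b wants at import time"

    def choices():
        # Deferred, fully qualified lookup: by the time a caller gets
        # here, pkg.b has finished importing and pkg.b.REGISTRY resolves.
        # An eager `from pkg.b import REGISTRY` at the top of this file
        # would start importing b while a is half-initialized, and b's
        # own `from pkg.a import helper` would then raise ImportError.
        return list(pkg.b.REGISTRY.keys())

    # pkg/b.py  (plays the role of megatron/model/glu_activations.py)
    from pkg.a import helper  # the back-edge that makes eager imports circular

    REGISTRY = {"geglu": object(), "swiglu": object()}

With this layout, `import pkg` succeeds and `pkg.a.choices()` returns `['geglu', 'swiglu']`. One caveat carries over to the diff: `import megatron` alone does not import submodules, so the `choices=` lookup still assumes `megatron.model.glu_activations` has been imported (e.g. via the package's `__init__` chain) before `parse_args` builds the argument group.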