From 3ba7882db10e8ec54d32176113f460033856ef65 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Thu, 19 Aug 2021 03:01:19 +0900 Subject: [PATCH 01/13] feat: expose glu activations as argument --- megatron/arguments.py | 5 +++++ megatron/model/transformer.py | 5 ++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 326c948ee..d7368ea5e 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -313,6 +313,11 @@ def _add_network_size_args(parser): default=PositionEmbeddingType.absolute, help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.' ) + group.add_argument('--glu-activation', type=str, + choices=["liglu", "geglu", "reglu", "swiglu"], + default="", + help='GLU activations to use.' + ) return parser diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index e7612b76f..fed3862a7 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -30,6 +30,7 @@ import deepspeed +import activations as glu_activations from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb # flags required to enable jit fusion kernels @@ -76,7 +77,9 @@ def __init__(self, init_method, output_layer_init_method): self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu - if args.openai_gelu: + if args.glu_activation: + self.activation_func = getattr(glu_activations, args.glu_activation) + elif args.openai_gelu: self.activation_func = openai_gelu elif args.onnx_safe: self.activation_func = erf_gelu From f004a5aa955a2da041b5ea383a486dfc49ef8471 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Thu, 19 Aug 2021 08:22:04 +0900 Subject: [PATCH 02/13] chore: rename activations -> glu_activations --- megatron/model/{activations.py => glu_activations.py} | 0 megatron/model/transformer.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename megatron/model/{activations.py => glu_activations.py} (100%) diff --git a/megatron/model/activations.py b/megatron/model/glu_activations.py similarity index 100% rename from megatron/model/activations.py rename to megatron/model/glu_activations.py diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index fed3862a7..8874bca6d 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -30,7 +30,7 @@ import deepspeed -import activations as glu_activations +import glu_activations from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb # flags required to enable jit fusion kernels From c064e2e09efa87ea9fa0495d102f77459fda640a Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Thu, 19 Aug 2021 08:26:04 +0900 Subject: [PATCH 03/13] refactor: use lookup dict instead of `getattr()` --- megatron/model/transformer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 8874bca6d..eadc4d624 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -30,7 +30,7 @@ import deepspeed -import glu_activations +from .glu_activations import geglu, liglu, reglu, swiglu from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb # flags required to enable jit fusion kernels @@ -78,7 +78,13 @@ def __init__(self, init_method, output_layer_init_method): self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu if args.glu_activation: - 
self.activation_func = getattr(glu_activations, args.glu_activation) + glu_lookup = { + "gegelu": geglu, + "liglu": liglu, + "reglu": reglu, + "swiglu": swiglu, + } + self.activation_func = glu_lookup[args.glu_activation] elif args.openai_gelu: self.activation_func = openai_gelu elif args.onnx_safe: From 2944fe436338eb74979ae8b4f51976cb5b9dac78 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Sun, 22 Aug 2021 00:35:18 +0900 Subject: [PATCH 04/13] refactor: mv lookup dict to `glu_activations.py` --- megatron/model/glu_activations.py | 12 ++++++++++-- megatron/model/transformer.py | 10 ++-------- tests/test_activations.py | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/megatron/model/glu_activations.py b/megatron/model/glu_activations.py index 82ccdf098..8705c8d17 100644 --- a/megatron/model/glu_activations.py +++ b/megatron/model/glu_activations.py @@ -7,10 +7,10 @@ class _GLUBaseModule(nn.Module): def __init__(self, activation_fn): super().__init__() self.activation_fn = activation_fn - + def forward(self, x): # dim=-1 breaks in jit for pt<1.10 - x1, x2 = x.chunk(2, dim=(x.ndim-1)) + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) return x1 * self.activation_fn(x2) @@ -38,3 +38,11 @@ def __init__(self): geglu = torch.jit.script(GEGLU()) reglu = torch.jit.script(ReGLU()) swiglu = torch.jit.script(SwiGLU()) + + +GLU_ACTIVATIONS = { + "gegelu": geglu, + "liglu": liglu, + "reglu": reglu, + "swiglu": swiglu, +} diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index eadc4d624..473b8e06b 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -30,7 +30,7 @@ import deepspeed -from .glu_activations import geglu, liglu, reglu, swiglu +from .glu_activations import GLU_ACTIVATIONS from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb # flags required to enable jit fusion kernels @@ -78,13 +78,7 @@ def __init__(self, init_method, output_layer_init_method): self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu if args.glu_activation: - glu_lookup = { - "gegelu": geglu, - "liglu": liglu, - "reglu": reglu, - "swiglu": swiglu, - } - self.activation_func = glu_lookup[args.glu_activation] + self.activation_func = GLU_ACTIVATIONS[args.glu_activation] elif args.openai_gelu: self.activation_func = openai_gelu elif args.onnx_safe: diff --git a/tests/test_activations.py b/tests/test_activations.py index 85c949f4a..4d9225a91 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -4,7 +4,7 @@ import torch from torch.nn import functional as F -from megatron.model.activations import liglu, geglu, reglu, swiglu +from megatron.model.glu_activations import liglu, geglu, reglu, swiglu from megatron.testing_utils import set_seed From 3729ebc22ad59404f98fa22a972d5e57ba67a4a6 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Sun, 22 Aug 2021 02:04:27 +0900 Subject: [PATCH 05/13] chore: rm unnecessary default arg --- megatron/arguments.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index d7368ea5e..91913f11b 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -315,7 +315,6 @@ def _add_network_size_args(parser): ) group.add_argument('--glu-activation', type=str, choices=["liglu", "geglu", "reglu", "swiglu"], - default="", help='GLU activations to use.' 
) From b3c6bbe3cac7820f9a127b557c49bfa6ed1a7724 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Mon, 23 Aug 2021 02:24:55 +0900 Subject: [PATCH 06/13] test: add bf16 test; gelu in `test_training_all()` --- megatron/testing_utils.py | 18 ++++++++++++++++++ tests/test_activations.py | 19 +++++++++++++------ tests/test_training.py | 1 + 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py index 62991c044..f1fa7187d 100644 --- a/megatron/testing_utils.py +++ b/megatron/testing_utils.py @@ -25,6 +25,7 @@ import random from distutils.util import strtobool from io import StringIO +from packaging import version from pathlib import Path from typing import Iterator, Union from unittest import mock @@ -212,6 +213,23 @@ def torch_assert_equal(actual, expected): torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0) +def is_torch_bf16_available(): + # from https://github.com/huggingface/transformers/blob/26eb566e43148c80d0ea098c76c3d128c0281c16/src/transformers/file_utils.py#L301 + if is_torch_available(): + import torch + if not torch.cuda.is_available() or torch.version.cuda is None: + return False + if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8: + return False + if int(torch.version.cuda.split(".")[0]) < 11: + return False + if not version.parse(torch.__version__) > version.parse("1.09"): + return False + return True + else: + return False + + def get_tests_dir(append_path=None): """ Args: diff --git a/tests/test_activations.py b/tests/test_activations.py index be50a3eb0..fc34c88b9 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -4,8 +4,8 @@ import torch from torch.nn import functional as F -from megatron.model.glu_activations import liglu, geglu, reglu, swiglu -from megatron.testing_utils import set_seed, torch_assert_equal +from megatron.model.glu_activations import GLU_ACTIVATIONS, geglu, liglu, reglu, swiglu +from megatron.testing_utils import set_seed, torch_assert_equal, is_torch_bf16_available class TestActivations(unittest.TestCase): @@ -17,13 +17,13 @@ def setUp(self): self.num_channels = random.randint(1, 384) * 2 self.x = torch.randn(self.batch_size, self.seq_len, self.num_channels) self.x1, self.x2 = self.x.chunk(2, dim=-1) + # glu should halve the last dimension + self.output_shape = [self.batch_size, self.seq_len, self.num_channels // 2] def test_shapes(self): - # glu should halve the last dimension - output_shape = [self.batch_size, self.seq_len, self.num_channels // 2] - for activation_fn in [liglu, geglu, reglu, swiglu]: + for activation_fn in GLU_ACTIVATIONS.values(): output = activation_fn(self.x) - self.assertEqual(list(output.shape), output_shape) + self.assertEqual(list(output.shape), self.output_shape) def test_liglu(self): expected = self.x1 * self.x2 @@ -40,3 +40,10 @@ def test_reglu(self): def test_swiglu(self): expected = self.x1 * F.silu(self.x2) torch_assert_equal(swiglu(self.x), expected) + + def test_bf16_jit(self): + if is_torch_bf16_available(): + x_bf16 = self.x.to(torch.bfloat16) + for activation_fn in GLU_ACTIVATIONS.values(): + output = activation_fn(x_bf16) + self.assertEqual(list(output.shape), self.output_shape) diff --git a/tests/test_training.py b/tests/test_training.py index 7306615f1..43b748141 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -102,6 +102,7 @@ def test_training_all(self): --log-timers-to-tensorboard --log-batch-size-to-tensorboard --log-validation-ppl-to-tensorboard + --glu-activation=geglu 
""".split() ds_args = f""" From ff248f0a5d1781cb48a0c2c37da049ff96773b9d Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Mon, 23 Aug 2021 02:39:01 +0900 Subject: [PATCH 07/13] Update megatron/testing_utils.py Co-authored-by: Stas Bekman --- megatron/testing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py index f1fa7187d..d1dbc05dd 100644 --- a/megatron/testing_utils.py +++ b/megatron/testing_utils.py @@ -223,7 +223,7 @@ def is_torch_bf16_available(): return False if int(torch.version.cuda.split(".")[0]) < 11: return False - if not version.parse(torch.__version__) > version.parse("1.09"): + if not version.parse(torch.__version__) >= version.parse("1.09"): return False return True else: From 0056ab7ded2102df9da899875d3ac3e162340b57 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Mon, 23 Aug 2021 02:45:11 +0900 Subject: [PATCH 08/13] refactor: use `require_torch_bf16` decorator --- megatron/testing_utils.py | 8 ++++++++ tests/test_activations.py | 14 +++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py index d1dbc05dd..801d74d9b 100644 --- a/megatron/testing_utils.py +++ b/megatron/testing_utils.py @@ -230,6 +230,14 @@ def is_torch_bf16_available(): return False +def require_torch_bf16(test_case): + """Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.9.""" + if not is_torch_bf16_available(): + return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.9")(test_case) + else: + return test_case + + def get_tests_dir(append_path=None): """ Args: diff --git a/tests/test_activations.py b/tests/test_activations.py index fc34c88b9..5e224cbe1 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -5,7 +5,7 @@ from torch.nn import functional as F from megatron.model.glu_activations import GLU_ACTIVATIONS, geglu, liglu, reglu, swiglu -from megatron.testing_utils import set_seed, torch_assert_equal, is_torch_bf16_available +from megatron.testing_utils import set_seed, torch_assert_equal, require_torch_bf16 class TestActivations(unittest.TestCase): @@ -41,9 +41,9 @@ def test_swiglu(self): expected = self.x1 * F.silu(self.x2) torch_assert_equal(swiglu(self.x), expected) - def test_bf16_jit(self): - if is_torch_bf16_available(): - x_bf16 = self.x.to(torch.bfloat16) - for activation_fn in GLU_ACTIVATIONS.values(): - output = activation_fn(x_bf16) - self.assertEqual(list(output.shape), self.output_shape) + @require_torch_bf16 + def test_bf16_jit(self): + x_bf16 = self.x.to(torch.bfloat16) + for activation_fn in GLU_ACTIVATIONS.values(): + output = activation_fn(x_bf16) + self.assertEqual(list(output.shape), self.output_shape) From 4097af6f39b4a9926a13c3220247c6040264cb01 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Mon, 23 Aug 2021 03:51:08 +0900 Subject: [PATCH 09/13] chore: comment out bf16 test uncomment in the future when torch supports gelu kernels for bf16 --- tests/test_activations.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_activations.py b/tests/test_activations.py index 5e224cbe1..a1763d7b4 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -5,7 +5,7 @@ from torch.nn import functional as F from megatron.model.glu_activations import GLU_ACTIVATIONS, geglu, liglu, reglu, swiglu -from megatron.testing_utils import set_seed, torch_assert_equal, require_torch_bf16 +from megatron.testing_utils 
import set_seed, torch_assert_equal class TestActivations(unittest.TestCase): @@ -41,9 +41,10 @@ def test_swiglu(self): expected = self.x1 * F.silu(self.x2) torch_assert_equal(swiglu(self.x), expected) - @require_torch_bf16 - def test_bf16_jit(self): - x_bf16 = self.x.to(torch.bfloat16) - for activation_fn in GLU_ACTIVATIONS.values(): - output = activation_fn(x_bf16) - self.assertEqual(list(output.shape), self.output_shape) + # from megatron.testing_utils import require_torch_bf16 + # @require_torch_bf16 + # def test_bf16_jit(self): + # x_bf16 = self.x.to(torch.bfloat16) + # for activation_fn in GLU_ACTIVATIONS.values(): + # output = activation_fn(x_bf16) + # self.assertEqual(list(output.shape), self.output_shape) From dc61f89eadf589a8bc6cf805cf8695a5b1ed648d Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sun, 22 Aug 2021 12:00:49 -0700 Subject: [PATCH 10/13] consistent style --- tests/test_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_training.py b/tests/test_training.py index 43b748141..ca4977826 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -102,7 +102,7 @@ def test_training_all(self): --log-timers-to-tensorboard --log-batch-size-to-tensorboard --log-validation-ppl-to-tensorboard - --glu-activation=geglu + --glu-activation geglu """.split() ds_args = f""" From 6798e819a4bf346222209eb1bab1af076ff36e39 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sun, 22 Aug 2021 12:09:17 -0700 Subject: [PATCH 11/13] fix look up table --- megatron/model/glu_activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/model/glu_activations.py b/megatron/model/glu_activations.py index 8705c8d17..9e0eb5b29 100644 --- a/megatron/model/glu_activations.py +++ b/megatron/model/glu_activations.py @@ -41,7 +41,7 @@ def __init__(self): GLU_ACTIVATIONS = { - "gegelu": geglu, + "geglu": geglu, "liglu": liglu, "reglu": reglu, "swiglu": swiglu, From ba8f4041e79fa4f587c03215afc3ca5beeffdc7b Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sun, 22 Aug 2021 12:11:26 -0700 Subject: [PATCH 12/13] better grouping --- tests/test_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_training.py b/tests/test_training.py index ca4977826..f0e45beaa 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -90,6 +90,7 @@ def test_training_all(self): --eval-interval 10 --eval-iters 5 --checkpoint-activations + --glu-activation geglu --exit-interval {exit_interval} --merge-file {data_dir}/gpt2-tiny-merges.txt @@ -102,7 +103,6 @@ def test_training_all(self): --log-timers-to-tensorboard --log-batch-size-to-tensorboard --log-validation-ppl-to-tensorboard - --glu-activation geglu """.split() ds_args = f""" From 0f36f705a648441b552c1a8e6f400dc6eded5ea4 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Mon, 23 Aug 2021 04:18:32 +0900 Subject: [PATCH 13/13] fix: replace hard coded options with `GLU_ACTIVATIONS` --- megatron/arguments.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 91913f11b..ba1d0c9a1 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -22,6 +22,7 @@ import deepspeed from megatron.enums import PositionEmbeddingType +from megatron.model.glu_activations import GLU_ACTIVATIONS def parse_args(extra_args_provider=None, defaults={}, @@ -314,7 +315,7 @@ def _add_network_size_args(parser): help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.' 
) group.add_argument('--glu-activation', type=str, - choices=["liglu", "geglu", "reglu", "swiglu"], + choices=GLU_ACTIVATIONS.keys(), help='GLU activations to use.' )
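
For reference, the standalone sketch below reconstructs what megatron/model/glu_activations.py provides at the end of this series and how the selected function is used once --glu-activation is set: each GLU variant splits the last dimension of its input in half and gates one half with an activation of the other, so the output has half as many channels as the input. This is an approximation rather than a verbatim copy of the patched module: the four variant class bodies are inferred from the expected outputs in tests/test_activations.py (identity, GELU, ReLU and SiLU gates), and the demo block with its toy shapes is illustrative only, not part of the patches.

    # Standalone sketch of the GLU activations introduced by this series.
    # Class bodies are inferred from tests/test_activations.py; the demo at
    # the bottom is illustrative and does not appear in the patches.
    import torch
    from torch import nn
    from torch.nn import functional as F


    class _GLUBaseModule(nn.Module):
        def __init__(self, activation_fn):
            super().__init__()
            self.activation_fn = activation_fn

        def forward(self, x):
            # Split the last dimension in half and gate one half with the other.
            # dim=-1 breaks in jit for pt<1.10, hence x.ndim - 1 (as in the patch).
            x1, x2 = x.chunk(2, dim=(x.ndim - 1))
            return x1 * self.activation_fn(x2)


    class LiGLU(_GLUBaseModule):          # linear gate: x1 * x2
        def __init__(self):
            super().__init__(nn.Identity())


    class GEGLU(_GLUBaseModule):          # x1 * gelu(x2)
        def __init__(self):
            super().__init__(F.gelu)


    class ReGLU(_GLUBaseModule):          # x1 * relu(x2)
        def __init__(self):
            super().__init__(F.relu)


    class SwiGLU(_GLUBaseModule):         # x1 * silu(x2)
        def __init__(self):
            super().__init__(F.silu)


    liglu = torch.jit.script(LiGLU())
    geglu = torch.jit.script(GEGLU())
    reglu = torch.jit.script(ReGLU())
    swiglu = torch.jit.script(SwiGLU())

    # Lookup consumed by ParallelMLP (PATCH 04) and by the --glu-activation
    # choices list (PATCH 13).
    GLU_ACTIVATIONS = {
        "geglu": geglu,
        "liglu": liglu,
        "reglu": reglu,
        "swiglu": swiglu,
    }

    if __name__ == "__main__":
        # Toy demonstration: select geglu the way transformer.py does after
        # PATCH 04, i.e. activation_func = GLU_ACTIVATIONS[args.glu_activation].
        x = torch.randn(2, 5, 8)          # (batch, seq, channels); channels must be even
        out = GLU_ACTIVATIONS["geglu"](x)
        assert out.shape == (2, 5, 4)     # GLU halves the last dimension
        x1, x2 = x.chunk(2, dim=-1)
        torch.testing.assert_close(out, x1 * F.gelu(x2))

One practical consequence of the gating, visible in the shape assertion above, is that the tensor coming out of the activation has half the channels of the projection that feeds it, which is why tests/test_activations.py checks num_channels // 2 for every variant.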