diff --git a/tests/test_model.py b/tests/test_model.py
index 3079a429a..d77e2776e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -6,7 +6,7 @@
 import torch

 from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
-from megatron.testing_utils import TestCasePlus
+from megatron.testing_utils import TestCasePlus, mockenv_context
 from megatron.training import setup_model_and_optimizer
 from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
 from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
@@ -69,10 +69,6 @@ def equal_vectors(tensor1, tensor2, dim=-1):


 class MyTestCase(TestCasePlus):
-    @classmethod
-    def setUpClass(cls) -> None:
-        deepspeed.init_distributed()
-
     def setUp(self) -> None:
         super().setUp()

@@ -84,45 +80,51 @@ def setUp(self) -> None:
         global_vars._GLOBAL_ADLR_AUTORESUME = None
         global_vars._GLOBAL_TIMERS = None

+        self.dist_env_1_gpu = dict(
+            MASTER_ADDR="localhost", MASTER_PORT="9994", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
+        )
+
     def test_gpt(self):
         """Test causal invariance, ie past token don't depend on future tokens."""
         command_args = get_default_args()

         with patch('sys.argv', flatten_arguments(command_args)):
-            initialize_megatron()
-            args = get_args()
-            tokenizer = get_tokenizer()
+            with mockenv_context(**self.dist_env_1_gpu):
+                deepspeed.init_distributed()
+                initialize_megatron()
+                args = get_args()
+                tokenizer = get_tokenizer()

-            model, _, _ = setup_model_and_optimizer(gpt_model_provider)
-            model = model[0]
+                model, _, _ = setup_model_and_optimizer(gpt_model_provider)
+                model = model[0]

-            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
+                token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

-            # eod is a special token
-            token_ids[token_ids == tokenizer.eod] += 1
-            token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
+                # eod is a special token
+                token_ids[token_ids == tokenizer.eod] += 1
+                token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

-            # process batch
-            input_batch = get_gpt_batch_pipe({"text": token_ids})[0]
+                # process batch
+                input_batch = get_gpt_batch_pipe({"text": token_ids})[0]

-            # get a modified version of the first batch, we change a specific index
-            changed_index = randint(0, args.seq_length - 2)
-            input_token_ids_changed = input_batch[0].clone()
-            # We increment the token_id by one for that index in order to artificially change the sequence.
-            input_token_ids_changed[:, changed_index] = \
-                (input_token_ids_changed[:,changed_index] + 1) % args.padded_vocab_size
+                # get a modified version of the first batch, we change a specific index
+                changed_index = randint(0, args.seq_length - 2)
+                input_token_ids_changed = input_batch[0].clone()
+                # We increment the token_id by one for that index in order to artificially change the sequence.
+                input_token_ids_changed[:, changed_index] = \
+                    (input_token_ids_changed[:,changed_index] + 1) % args.padded_vocab_size

-            output = model(*input_batch)
-            output_changed = model(input_token_ids_changed, *input_batch[1:])
+                output = model(*input_batch)
+                output_changed = model(input_token_ids_changed, *input_batch[1:])

-            # All token in past should be unchanged
-            self.assertTrue(
-                torch.all(equal_vectors(output[:, :changed_index], output_changed[:, :changed_index]))
-            )
-            # All tokens in the future should have changed
-            self.assertFalse(
-                torch.any(equal_vectors(output[:, changed_index:], output_changed[:, changed_index:]))
-            )
+                # All token in past should be unchanged
+                self.assertTrue(
+                    torch.all(equal_vectors(output[:, :changed_index], output_changed[:, :changed_index]))
+                )
+                # All tokens in the future should have changed
+                self.assertFalse(
+                    torch.any(equal_vectors(output[:, changed_index:], output_changed[:, changed_index:]))
+                )

     def test_prefix_lm_reset_attention_mask(self):
         """
@@ -137,90 +139,92 @@ def test_prefix_lm_reset_attention_mask(self):
         command_args["--loss-on-targets-only"] = ""

         with patch('sys.argv', flatten_arguments(command_args)):
-            initialize_megatron()
-            args = get_args()
-            tokenizer = get_tokenizer()
-
-            model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
-            model = model[0]
-
-            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
-
-            # eod is a special token, this also guarantees that the whole row is considered as a document.
-            token_ids[token_ids == tokenizer.eod] += 1
-            token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
-
-            # process batch to have non empty prefix
-            for i in range(9, -1, -1):
-                input_batch, _, prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
-                if (prefix_indices[0][0] != 0):
-                    break
-                if i == 0:
-                    # FIXME: find a better way to not obtain empty prefix
-                    raise ValueError("Could not obtain non pathological case where prefix is not empty")
-
-            output = model(*input_batch)
-
-            ## --------------- CHANGE A TARGET TOKEN ---------------------------
-            # get a modified version of the first batch
-            # guaranteed to exist as each row has at least one partial document
-            changed_target_index = prefix_indices[0][0]
-            token_ids_changed_target = input_batch[0].clone()
-            # We increment the token id on the changed index.
-            token_ids_changed_target[0, changed_target_index] = \
-                (token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
-            # make sure we're not changing a token to eod as it's a special token
-            token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
-            token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size
-
-            # Test change
-            output_changed_target = model(token_ids_changed_target, *input_batch[1:])
-
-            # All token in past should be unchanged
-            self.assertTrue(
-                torch.all(
-                    equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
+            with mockenv_context(**self.dist_env_1_gpu):
+                deepspeed.init_distributed()
+                initialize_megatron()
+                args = get_args()
+                tokenizer = get_tokenizer()
+
+                model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
+                model = model[0]
+
+                token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
+
+                # eod is a special token, this also guarantees that the whole row is considered as a document.
+                token_ids[token_ids == tokenizer.eod] += 1
+                token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
+
+                # process batch to have non empty prefix
+                for i in range(9, -1, -1):
+                    input_batch, _, prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
+                    if (prefix_indices[0][0] != 0):
+                        break
+                    if i == 0:
+                        # FIXME: find a better way to not obtain empty prefix
+                        raise ValueError("Could not obtain non pathological case where prefix is not empty")
+
+                output = model(*input_batch)
+
+                ## --------------- CHANGE A TARGET TOKEN ---------------------------
+                # get a modified version of the first batch
+                # guaranteed to exist as each row has at least one partial document
+                changed_target_index = prefix_indices[0][0]
+                token_ids_changed_target = input_batch[0].clone()
+                # We increment the token id on the changed index.
+                token_ids_changed_target[0, changed_target_index] = \
+                    (token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
+                # make sure we're not changing a token to eod as it's a special token
+                token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
+                token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size
+
+                # Test change
+                output_changed_target = model(token_ids_changed_target, *input_batch[1:])
+
+                # All token in past should be unchanged
+                self.assertTrue(
+                    torch.all(
+                        equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
+                    )
                 )
-            )
-            # All tokens in the future should have changed
-            self.assertFalse(
-                torch.any(
-                    equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
+                # All tokens in the future should have changed
+                self.assertFalse(
+                    torch.any(
+                        equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
+                    )
                 )
-            )
-            # Unchanged changed rows should not change either
-            self.assertTrue(
-                torch.all(
-                    equal_vectors(output[1, :], output_changed_target[1, :])
+                # Unchanged changed rows should not change either
+                self.assertTrue(
+                    torch.all(
+                        equal_vectors(output[1, :], output_changed_target[1, :])
+                    )
                 )
-            )
-
-            ## --------------- CHANGE AN INPUT TOKEN ---------------------------
-            # Let's change the the last prefix token and make sure that the first token changed
-            # guaranteed to be positive as we avoid pathological case previously
-            last_prefix_index = prefix_indices[0][0] - 1
-            token_ids_changed_input = input_batch[0].clone()
-            # We increment the token id on the changed index.
-            token_ids_changed_input[0, last_prefix_index] = \
-                (token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
-            # make sure we're not changing a token to eod as it's a special token
-            token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
-            token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size
-
-            output_changed_input = model(token_ids_changed_input, *input_batch[1:])
-
-            # All tokens should be changed
-            self.assertFalse(
-                torch.any(
-                    equal_vectors(output[0, :], output_changed_input[0, :])
+
+                ## --------------- CHANGE AN INPUT TOKEN ---------------------------
+                # Let's change the the last prefix token and make sure that the first token changed
+                # guaranteed to be positive as we avoid pathological case previously
+                last_prefix_index = prefix_indices[0][0] - 1
+                token_ids_changed_input = input_batch[0].clone()
+                # We increment the token id on the changed index.
+                token_ids_changed_input[0, last_prefix_index] = \
+                    (token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
+                # make sure we're not changing a token to eod as it's a special token
+                token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
+                token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size
+
+                output_changed_input = model(token_ids_changed_input, *input_batch[1:])
+
+                # All tokens should be changed
+                self.assertFalse(
+                    torch.any(
+                        equal_vectors(output[0, :], output_changed_input[0, :])
+                    )
                 )
-            )
-            # Unchanged changed rows should not change either
-            self.assertTrue(
-                torch.all(
-                    equal_vectors(output[1, :], output_changed_input[1, :])
+                # Unchanged changed rows should not change either
+                self.assertTrue(
+                    torch.all(
+                        equal_vectors(output[1, :], output_changed_input[1, :])
+                    )
                 )
-            )

     def test_prefix_lm_wo_reset_attention_mask(self):
         """
@@ -234,18 +238,20 @@ def test_prefix_lm_wo_reset_attention_mask(self):
         command_args["--loss-on-targets-only"] = ""

         with patch('sys.argv', flatten_arguments(command_args)):
-            initialize_megatron()
-            args = get_args()
+            with mockenv_context(**self.dist_env_1_gpu):
+                deepspeed.init_distributed()
+                initialize_megatron()
+                args = get_args()

-            model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
-            model = model[0]
+                model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
+                model = model[0]

-            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
-            input_batch, _, prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
+                token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
+                input_batch, _, prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

-            model(*input_batch)
+                model(*input_batch)

-            #TODO: Check all invariants
+                #TODO: Check all invariants

     def test_gpt_rotary_embeddings(self):
         """Test rotary embeddings"""
@@ -255,25 +261,27 @@ def test_gpt_rotary_embeddings(self):
         command_args["--position-embedding-type"] = "rotary"

         with patch('sys.argv', flatten_arguments(command_args)):
-            initialize_megatron()
-            args = get_args()
-            tokenizer = get_tokenizer()
+            with mockenv_context(**self.dist_env_1_gpu):
+                deepspeed.init_distributed()
+                initialize_megatron()
+                args = get_args()
+                tokenizer = get_tokenizer()

-            model, _, _ = setup_model_and_optimizer(gpt_model_provider)
-            model = model[0]
+                model, _, _ = setup_model_and_optimizer(gpt_model_provider)
+                model = model[0]

-            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
+                token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

-            # eod is a special token
-            token_ids[token_ids == tokenizer.eod] += 1
-            token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
+                # eod is a special token
+                token_ids[token_ids == tokenizer.eod] += 1
+                token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

-            # process batch
-            input_batch = get_gpt_batch_pipe({"text": token_ids})[0]
+                # process batch
+                input_batch = get_gpt_batch_pipe({"text": token_ids})[0]

-            model(*input_batch)
+                model(*input_batch)

-            #TODO: Check all invariants
+                #TODO: Check all invariants


 if __name__ == '__main__':
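Note on mockenv_context (not part of the patch itself): the diff replaces the class-level setUpClass call to deepspeed.init_distributed() with a per-test block that scopes a fake single-GPU launcher environment (MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK, WORLD_SIZE) to each test. The helper's implementation is not shown in this diff; below is a minimal sketch of how such a context manager typically works, assuming megatron.testing_utils.mockenv_context mirrors the helper of the same name in transformers.testing_utils. It is illustrative only, not code added by this PR.

# Illustrative sketch only: assumes megatron.testing_utils.mockenv_context behaves
# like transformers.testing_utils.mockenv_context; not part of the diff above.
import contextlib
import os

@contextlib.contextmanager
def mockenv_context(*remove, **update):
    """Temporarily update os.environ in place and restore it on exit."""
    env = os.environ
    update = update or {}
    remove = remove or []

    # Existing variables that will be overwritten or deleted, so they can be restored.
    stomped = (set(update) | set(remove)) & set(env)
    update_after = {k: env[k] for k in stomped}
    # Variables introduced by `update` that must be removed again on exit.
    remove_after = frozenset(k for k in update if k not in env)

    try:
        env.update(update)
        for k in remove:
            env.pop(k, None)
        yield
    finally:
        env.update(update_after)
        for k in remove_after:
            env.pop(k)

With a helper of this shape, `with mockenv_context(**self.dist_env_1_gpu):` makes the distributed variables visible only inside the block, so deepspeed.init_distributed() can run per test and the process environment is clean again once the block exits.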