Merged

31 commits
516df5e
WIP: test
thomasw21 Aug 6, 2021
eececbb
Still trying to figure out deepspeed
thomasw21 Aug 6, 2021
b78dfaa
WIP
thomasw21 Aug 6, 2021
0bfc2f4
Test test
thomasw21 Aug 6, 2021
b801637
Test how to setup deepspeed in unit tests
thomasw21 Aug 6, 2021
2b81d40
Test something else
thomasw21 Aug 6, 2021
aeca8c1
Empty strings might be problematic
thomasw21 Aug 6, 2021
520ef72
Remove unnecessary arguments
thomasw21 Aug 6, 2021
37522b4
Woops
thomasw21 Aug 6, 2021
76f01fe
Remove global variables at the end of each test and init deepspeed
thomasw21 Aug 6, 2021
188b33b
Woops
thomasw21 Aug 6, 2021
57191c4
Maybe adding classmethod
thomasw21 Aug 6, 2021
1389e6d
Woops
thomasw21 Aug 7, 2021
e458540
Add debug print to check that tear down happens
thomasw21 Aug 7, 2021
d7f331f
Reset global variables before
thomasw21 Aug 7, 2021
af9a716
Let's test this
thomasw21 Aug 7, 2021
b908124
Try something else
thomasw21 Aug 7, 2021
28cea95
WIP
thomasw21 Aug 7, 2021
642ef91
More fix
thomasw21 Aug 7, 2021
5143ce6
More fix
thomasw21 Aug 7, 2021
8cfb92c
More stuff to fix
thomasw21 Aug 7, 2021
9dbd939
We really want to compare vectors and not coordinates
thomasw21 Aug 7, 2021
82c6ca1
Reformat
thomasw21 Aug 7, 2021
7c6ea15
check something out
thomasw21 Aug 7, 2021
076b69f
fix test
thomasw21 Aug 7, 2021
2e0f71a
Remove prefix-lm flag as it's integrated
thomasw21 Aug 9, 2021
4ec9aca
Update test
thomasw21 Sep 16, 2021
18b1c97
Woops
thomasw21 Sep 16, 2021
76aad89
Add test for without reset attention mask
thomasw21 Sep 16, 2021
86f8928
Fix test for non reset attention mask
thomasw21 Sep 16, 2021
fe4a815
Fix test
thomasw21 Sep 16, 2021
8 changes: 5 additions & 3 deletions megatron/utils.py
@@ -320,10 +320,10 @@ def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_

    assert partial_prefix_indices is None or len(partial_prefix_indices) == micro_batch_size, f"partial_prefix_indices has to be None or its length equal to {micro_batch_size}, got {len(partial_prefix_indices)}"
    for batch_id in range(micro_batch_size):
        prefix_indices.append([])

        # Prefix lm per document.
        if reset_attention_mask:
            prefix_indices.append([])

            # Compute the index of all eod tokens in data.
            eod_indices = (data[batch_id] == eod_token).nonzero().squeeze(-1)

@@ -356,12 +356,14 @@ def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_
        assert partial_prefix_indices is None or isinstance(partial_prefix_indices[batch_id], int), \
            f"Per document prefix has to store an int for each row, got {partial_prefix_indices[batch_id]}"

        prefix_index: int
        if partial_prefix_indices is None or partial_prefix_indices[batch_id] is None:
            # We need to randomly generate a prefix index
            prefix_index = randint(0, seq_length - 1)
        else:
            # We get value from partial_prefix_indices, and run validation on that value
            prefix_index = partial_prefix_indices[batch_id]
            assert 0 <= prefix_index < seq_length - 1, f"Prefix index needs to be between documents indices, 0 <= {prefix_index} < {seq_length - 1} should be True."
        prefix_indices[batch_id].append(prefix_index)
        prefix_indices.append(prefix_index)

    return prefix_indices
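
For orientation, here is a minimal sketch of what the return value looks like under the two modes, as inferred from this diff (the surrounding function body is unchanged, and the values below are made up): with reset_attention_mask=True each row still gets its own list of per-document prefix indices, while with reset_attention_mask=False the function now returns a flat list with one integer per row instead of a list of one-element lists.

# Hypothetical illustration only (values made up, micro_batch_size = 2).

# reset_attention_mask=True: one sub-list per row, one prefix index per document.
prefix_indices_per_document = [[3, 7], [5]]

# reset_attention_mask=False: one prefix index per row.
# Before this change the same call returned [[3], [5]].
prefix_indices_per_row = [3, 5]
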
2 changes: 1 addition & 1 deletion pretrain_prefix_lm.py
@@ -145,7 +145,7 @@ def get_batch_pipe(data):
        loss_on_targets_only=args.loss_on_targets_only
    )

    return (tokens, position_ids, attention_mask), (labels, loss_mask)
    return (tokens, position_ids, attention_mask), (labels, loss_mask), prefix_indices
Owner Author

This is okay, as any values from the third element onward are ignored, which makes this a convenient way to expose variables needed for testing.


def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
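
As the review comment above notes, the training pipeline only consumes the first two elements of the returned tuple, so the extra prefix_indices is available to tests without affecting training. A minimal, self-contained sketch of that pattern follows; the stand-in fake_batch_pipe and its data are hypothetical, since the real get_batch_pipe requires an initialized Megatron environment.

def fake_batch_pipe(data):
    # Mirrors the shape returned by pretrain_prefix_lm.get_batch_pipe after this change.
    inputs = (data["tokens"], data["position_ids"], data["attention_mask"])
    targets = (data["labels"], data["loss_mask"])
    return inputs, targets, data["prefix_indices"]

batch = fake_batch_pipe({
    "tokens": [5, 6, 7], "position_ids": [0, 1, 2], "attention_mask": [1, 1, 1],
    "labels": [6, 7, 8], "loss_mask": [1, 1, 1], "prefix_indices": [1],
})

inputs, targets = batch[0], batch[1]   # what the training loop uses
_, _, prefix_indices = batch           # what the tests below additionally unpack
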
1,000 changes: 1,000 additions & 0 deletions tests/data/gpt2/openwebtext-1000.jsonl

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions tests/test_basic.py
@@ -0,0 +1,2 @@
def test_import():
    import megatron
280 changes: 280 additions & 0 deletions tests/test_model.py
@@ -0,0 +1,280 @@
import unittest
from random import randint
from unittest.mock import patch

import deepspeed
import torch

from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
from megatron.testing_utils import TestCasePlus
from megatron.training import setup_model_and_optimizer
from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe


def get_default_args():
    """Return a dictionary with argument names as keys and additional arguments as values."""
    return {
        # GPT_ARGS
        "--num-layers": "2",
        "--hidden-size": "128",
        "--num-attention-heads": "4",
        "--seq-length": "256",
        "--max-position-embeddings": "256",
        "--micro-batch-size": "4",
        "--global-batch-size": "8",
        "--lr-decay-iters": "320000",
        "--lr-decay-style": "cosine",
        "--lr": "0.00015",
        "--min-lr": "1.0e-5",
        "--train-iters": "5000",
        "--tokenizer-type": "PretrainedFromHF",
        "--tokenizer-name-or-path": "gpt2",
        "--data-impl": "mmap",
        "--split": "949,50,1",
        "--distributed-backend": "nccl",
        "--weight-decay": "1e-2",
        "--clip-grad": "1.0",
        "--lr-warmup-fraction": ".01",
        "--fp16": "",

        "--attention-dropout": "0",
        "--hidden-dropout": "0",

        # OUTPUT_ARGS
        "--log-interval": "10",
        "--save-interval": "500",
        "--eval-interval": "100",
        "--eval-iters": "10",
        "--checkpoint-activations": "",

        # DATA_ARGS
    }


def flatten_arguments(args):
    """
    Converts dictionary argument to a list.

    Note: we add "IGNORED" at the beginning as this value is ignored by the argparser.

    Example: {"arg1": "value1", "arg2": "value2"} -> ["IGNORED", "arg1", "value1", "arg2", "value2"]
    """
    return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]


def equal_vectors(tensor1, tensor2, dim=-1):
    """View tensor1 and tensor2 as a list of vectors, and compute equality"""
    return torch.linalg.norm(tensor1 - tensor2, dim=dim) == 0


class MyTestCase(TestCasePlus):
    @classmethod
    def setUpClass(cls) -> None:
        deepspeed.init_distributed()

    def setUp(self) -> None:
        super().setUp()

        # We reset all global variables
        global_vars._GLOBAL_ARGS = None
        global_vars._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
        global_vars._GLOBAL_TOKENIZER = None
        global_vars._GLOBAL_TENSORBOARD_WRITER = None
        global_vars._GLOBAL_ADLR_AUTORESUME = None
        global_vars._GLOBAL_TIMERS = None

    def test_gpt(self):
        """Test causal invariance, i.e. past tokens don't depend on future tokens."""
        command_args = get_default_args()

        with patch('sys.argv', flatten_arguments(command_args)):
            initialize_megatron()
            args = get_args()
            tokenizer = get_tokenizer()

            model, _, _ = setup_model_and_optimizer(gpt_model_provider)
            model = model[0]

            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

            # eod is a special token
            token_ids[token_ids == tokenizer.eod] += 1
            token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

            # process batch
            input_batch = get_gpt_batch_pipe({"text": token_ids})[0]

            # get a modified version of the first batch where we change a specific index
            changed_index = randint(0, args.seq_length - 2)
            input_token_ids_changed = input_batch[0].clone()
            # We increment the token_id by one for that index in order to artificially change the sequence.
            input_token_ids_changed[:, changed_index] = \
                (input_token_ids_changed[:, changed_index] + 1) % args.padded_vocab_size

            output = model(*input_batch)
            output_changed = model(input_token_ids_changed, *input_batch[1:])

            # All tokens in the past should be unchanged
            self.assertTrue(
                torch.all(equal_vectors(output[:, :changed_index], output_changed[:, :changed_index]))
            )
            # All tokens in the future should have changed
            self.assertFalse(
                torch.any(equal_vectors(output[:, changed_index:], output_changed[:, changed_index:]))
            )

    def test_prefix_lm_reset_attention_mask(self):
        """
        Test prefix invariances when `reset_attention_mask=True`:
            - Past target tokens don't depend on future target tokens.
            - Target tokens depend on input tokens.
            - Input tokens depend on all other input tokens, but never on target tokens.
        """
        command_args = get_default_args()

        command_args["--reset-attention-mask"] = ""
        command_args["--loss-on-targets-only"] = ""

        with patch('sys.argv', flatten_arguments(command_args)):
            initialize_megatron()
            args = get_args()
            tokenizer = get_tokenizer()

            model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
            model = model[0]

            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

            # eod is a special token; this also guarantees that the whole row is considered as a document.
            token_ids[token_ids == tokenizer.eod] += 1
            token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

            # process batch to have a non-empty prefix
            for i in range(9, -1, -1):
                input_batch, _, prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
                if (prefix_indices[0][0] != 0):
                    break
                if i == 0:
                    # FIXME: find a better way to not obtain an empty prefix
                    raise ValueError("Could not obtain a non-pathological case where the prefix is not empty")

            output = model(*input_batch)

            ## --------------- CHANGE A TARGET TOKEN ---------------------------
            # get a modified version of the first batch
            # guaranteed to exist as each row has at least one partial document
            changed_target_index = prefix_indices[0][0]
            token_ids_changed_target = input_batch[0].clone()
            # We increment the token id on the changed index.
            token_ids_changed_target[0, changed_target_index] = \
                (token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
            # make sure we're not changing a token to eod as it's a special token
            token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
            token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size

            # Test change
            output_changed_target = model(token_ids_changed_target, *input_batch[1:])

            # All tokens in the past should be unchanged
            self.assertTrue(
                torch.all(
                    equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
                )
            )
            # All tokens in the future should have changed
            self.assertFalse(
                torch.any(
                    equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
                )
            )
            # Unchanged rows should not change either
            self.assertTrue(
                torch.all(
                    equal_vectors(output[1, :], output_changed_target[1, :])
                )
            )

            ## --------------- CHANGE AN INPUT TOKEN ---------------------------
            # Let's change the last prefix token and make sure that the first token changed
            # guaranteed to be positive as we avoided the pathological case above
            last_prefix_index = prefix_indices[0][0] - 1
            token_ids_changed_input = input_batch[0].clone()
            # We increment the token id on the changed index.
            token_ids_changed_input[0, last_prefix_index] = \
                (token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
            # make sure we're not changing a token to eod as it's a special token
            token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
            token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size

            output_changed_input = model(token_ids_changed_input, *input_batch[1:])

            # All tokens should be changed
            self.assertFalse(
                torch.any(
                    equal_vectors(output[0, :], output_changed_input[0, :])
                )
            )
            # Unchanged rows should not change either
            self.assertTrue(
                torch.all(
                    equal_vectors(output[1, :], output_changed_input[1, :])
                )
            )

    def test_prefix_lm_wo_reset_attention_mask(self):
        """
        Test prefix invariances when `reset_attention_mask=False`:
            - Past target tokens don't depend on future target tokens.
            - Target tokens depend on input tokens.
            - Input tokens depend on all other input tokens, but never on target tokens.
        """
        command_args = get_default_args()

        command_args["--loss-on-targets-only"] = ""

        with patch('sys.argv', flatten_arguments(command_args)):
            initialize_megatron()
            args = get_args()

            model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
            model = model[0]

            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
            input_batch, _, prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

            model(*input_batch)

            # TODO: Check all invariants

    def test_gpt_rotary_embeddings(self):
        """Test rotary embeddings"""
        command_args = get_default_args()

        del command_args["--max-position-embeddings"]
        command_args["--position-embedding-type"] = "rotary"

        with patch('sys.argv', flatten_arguments(command_args)):
            initialize_megatron()
            args = get_args()
            tokenizer = get_tokenizer()

            model, _, _ = setup_model_and_optimizer(gpt_model_provider)
            model = model[0]

            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

            # eod is a special token
            token_ids[token_ids == tokenizer.eod] += 1
            token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

            # process batch
            input_batch = get_gpt_batch_pipe({"text": token_ids})[0]

            model(*input_batch)

            # TODO: Check all invariants


if __name__ == '__main__':
    unittest.main()
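
A quick aside on equal_vectors above (cf. the commit "We really want to compare vectors and not coordinates"): because the norm is reduced over the last dimension, two positions compare equal only when their full hidden vectors match. Below is a small self-contained sketch with made-up tensors.

import torch

def equal_vectors(tensor1, tensor2, dim=-1):
    # Same definition as in tests/test_model.py above.
    return torch.linalg.norm(tensor1 - tensor2, dim=dim) == 0

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[1.0, 2.0], [3.0, 5.0]])
print(equal_vectors(a, b))  # tensor([ True, False]) -- only the second vector differs
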
3 changes: 1 addition & 2 deletions tests/test_preprocessing.py
@@ -95,14 +95,13 @@ def compare_meg_data_files(self, tgt, ref):
        self.assertTrue(Path(tgt_path).exists(), )
        self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False))

    @unittest.skip("Skip test until data is fixed.")
    def test_process_data_microsoft(self):
        """We want to be stable to Microsoft version."""
        src_dir = self.src_dir
        data_dir = f"{self.data_dir}/gpt2"
        output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False)

        input_path = f"{self.tests_dir}/tools/openwebtext-1000.jsonl"
        input_path = f"{self.tests_dir}/data/gpt2/openwebtext-1000.jsonl"

        output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext"

Expand Down