272 changes: 140 additions & 132 deletions tests/test_model.py
@@ -6,7 +6,7 @@
 import torch
 
 from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
-from megatron.testing_utils import TestCasePlus
+from megatron.testing_utils import TestCasePlus, mockenv_context
 from megatron.training import setup_model_and_optimizer
 from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
 from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
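The only change to the imports is pulling mockenv_context out of megatron.testing_utils alongside TestCasePlus. Its implementation is not part of this diff; as a rough sketch of what such a helper usually is (an assumption, not the actual source), it is a context manager that temporarily overlays variables onto os.environ so each test can fake its own single-GPU distributed environment:

import os
from contextlib import contextmanager
from unittest import mock

@contextmanager
def mockenv_context(**env_vars):
    # Temporarily overlay the given variables onto os.environ for the duration of the block.
    with mock.patch.dict(os.environ, env_vars, clear=False):
        yield

With something like this, MASTER_ADDR, RANK and friends only exist inside the with block, so one test's distributed settings cannot leak into another.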
@@ -69,10 +69,6 @@ def equal_vectors(tensor1, tensor2, dim=-1):
 
 
 class MyTestCase(TestCasePlus):
-    @classmethod
-    def setUpClass(cls) -> None:
-        deepspeed.init_distributed()
-
     def setUp(self) -> None:
         super().setUp()
 
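The hunk header above points at the module-level helper equal_vectors(tensor1, tensor2, dim=-1), which the assertions below rely on but whose body is collapsed out of the diff. A plausible implementation, offered purely as an assumption consistent with how the tests wrap it in torch.all and torch.any, returns one boolean per position by comparing the two tensors along dim:

import torch

def equal_vectors(tensor1, tensor2, dim=-1):
    # True wherever the two tensors coincide along `dim`, one boolean per remaining index.
    return torch.linalg.norm((tensor1 - tensor2).float(), dim=dim) == 0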
@@ -84,45 +80,51 @@ def setUp(self) -> None:
         global_vars._GLOBAL_ADLR_AUTORESUME = None
         global_vars._GLOBAL_TIMERS = None
 
+        self.dist_env_1_gpu = dict(
+            MASTER_ADDR="localhost", MASTER_PORT="9994", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
+        )
+
     def test_gpt(self):
         """Test causal invariance, i.e. past tokens don't depend on future tokens."""
         command_args = get_default_args()
 
         with patch('sys.argv', flatten_arguments(command_args)):
-            initialize_megatron()
-            args = get_args()
-            tokenizer = get_tokenizer()
+            with mockenv_context(**self.dist_env_1_gpu):
+                deepspeed.init_distributed()
+                initialize_megatron()
+                args = get_args()
+                tokenizer = get_tokenizer()
 
-            model, _, _ = setup_model_and_optimizer(gpt_model_provider)
-            model = model[0]
+                model, _, _ = setup_model_and_optimizer(gpt_model_provider)
+                model = model[0]
 
-            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
+                token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
 
-            # eod is a special token
-            token_ids[token_ids == tokenizer.eod] += 1
-            token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
+                # eod is a special token
+                token_ids[token_ids == tokenizer.eod] += 1
+                token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
 
-            # process batch
-            input_batch = get_gpt_batch_pipe({"text": token_ids})[0]
+                # process batch
+                input_batch = get_gpt_batch_pipe({"text": token_ids})[0]
 
-            # get a modified version of the first batch, where we change a specific index
-            changed_index = randint(0, args.seq_length - 2)
-            input_token_ids_changed = input_batch[0].clone()
-            # We increment the token_id by one at that index in order to artificially change the sequence.
-            input_token_ids_changed[:, changed_index] = \
-                (input_token_ids_changed[:, changed_index] + 1) % args.padded_vocab_size
+                # get a modified version of the first batch, where we change a specific index
+                changed_index = randint(0, args.seq_length - 2)
+                input_token_ids_changed = input_batch[0].clone()
+                # We increment the token_id by one at that index in order to artificially change the sequence.
+                input_token_ids_changed[:, changed_index] = \
+                    (input_token_ids_changed[:, changed_index] + 1) % args.padded_vocab_size
 
-            output = model(*input_batch)
-            output_changed = model(input_token_ids_changed, *input_batch[1:])
+                output = model(*input_batch)
+                output_changed = model(input_token_ids_changed, *input_batch[1:])
 
-            # All tokens in the past should be unchanged
-            self.assertTrue(
-                torch.all(equal_vectors(output[:, :changed_index], output_changed[:, :changed_index]))
-            )
-            # All tokens in the future should have changed
-            self.assertFalse(
-                torch.any(equal_vectors(output[:, changed_index:], output_changed[:, changed_index:]))
-            )
+                # All tokens in the past should be unchanged
+                self.assertTrue(
+                    torch.all(equal_vectors(output[:, :changed_index], output_changed[:, :changed_index]))
+                )
+                # All tokens in the future should have changed
+                self.assertFalse(
+                    torch.any(equal_vectors(output[:, changed_index:], output_changed[:, changed_index:]))
+                )
 
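The test_gpt change above only wraps the pre-existing causal-invariance check in the mocked environment; the check itself is untouched: perturb one input position, then assert that outputs strictly before it are identical while outputs from that position onward all differ. A self-contained toy version of the same pattern, using a cumulative sum as a stand-in for the causal model (purely illustrative, not Megatron code):

import torch

def toy_causal_model(x):
    # Output at position t depends only on inputs at positions <= t.
    return torch.cumsum(x, dim=1)

torch.manual_seed(0)
x = torch.randn(2, 16, 8)           # (batch, seq_len, hidden)
changed_index = 5
x_changed = x.clone()
x_changed[:, changed_index] += 1.0  # perturb a single position

out, out_changed = toy_causal_model(x), toy_causal_model(x_changed)

# Everything strictly before the perturbed position is untouched ...
assert torch.equal(out[:, :changed_index], out_changed[:, :changed_index])
# ... while the perturbed position and everything after it differ.
assert not torch.isclose(out[:, changed_index:], out_changed[:, changed_index:]).any()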
     def test_prefix_lm_reset_attention_mask(self):
         """
@@ -137,90 +139,92 @@ def test_prefix_lm_reset_attention_mask(self):
         command_args["--loss-on-targets-only"] = ""
 
         with patch('sys.argv', flatten_arguments(command_args)):
-            initialize_megatron()
-            args = get_args()
-            tokenizer = get_tokenizer()
+            with mockenv_context(**self.dist_env_1_gpu):
+                deepspeed.init_distributed()
+                initialize_megatron()
+                args = get_args()
+                tokenizer = get_tokenizer()
 
-            model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
-            model = model[0]
+                model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
+                model = model[0]
 
-            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
+                token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
 
-            # eod is a special token; this also guarantees that the whole row is considered a single document.
-            token_ids[token_ids == tokenizer.eod] += 1
-            token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
+                # eod is a special token; this also guarantees that the whole row is considered a single document.
+                token_ids[token_ids == tokenizer.eod] += 1
+                token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
 
-            # process batch to have a non-empty prefix
-            for i in range(9, -1, -1):
-                input_batch, _, prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
-                if (prefix_indices[0][0] != 0):
-                    break
-                if i == 0:
-                    # FIXME: find a better way to avoid an empty prefix
-                    raise ValueError("Could not obtain a non-pathological case where the prefix is not empty")
+                # process batch to have a non-empty prefix
+                for i in range(9, -1, -1):
+                    input_batch, _, prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
+                    if (prefix_indices[0][0] != 0):
+                        break
+                    if i == 0:
+                        # FIXME: find a better way to avoid an empty prefix
+                        raise ValueError("Could not obtain a non-pathological case where the prefix is not empty")
 
-            output = model(*input_batch)
+                output = model(*input_batch)
 
-            ## --------------- CHANGE A TARGET TOKEN ---------------------------
-            # get a modified version of the first batch
-            # guaranteed to exist as each row has at least one partial document
-            changed_target_index = prefix_indices[0][0]
-            token_ids_changed_target = input_batch[0].clone()
-            # We increment the token id at the changed index.
-            token_ids_changed_target[0, changed_target_index] = \
-                (token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
-            # make sure we're not changing a token to eod as it's a special token
-            token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
-            token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size
+                ## --------------- CHANGE A TARGET TOKEN ---------------------------
+                # get a modified version of the first batch
+                # guaranteed to exist as each row has at least one partial document
+                changed_target_index = prefix_indices[0][0]
+                token_ids_changed_target = input_batch[0].clone()
+                # We increment the token id at the changed index.
+                token_ids_changed_target[0, changed_target_index] = \
+                    (token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
+                # make sure we're not changing a token to eod as it's a special token
+                token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
+                token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size
 
-            # Test change
-            output_changed_target = model(token_ids_changed_target, *input_batch[1:])
+                # Test change
+                output_changed_target = model(token_ids_changed_target, *input_batch[1:])
 
-            # All tokens in the past should be unchanged
-            self.assertTrue(
-                torch.all(
-                    equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
-                )
-            )
-            # All tokens in the future should have changed
-            self.assertFalse(
-                torch.any(
-                    equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
-                )
-            )
-            # Unchanged rows should not change either
-            self.assertTrue(
-                torch.all(
-                    equal_vectors(output[1, :], output_changed_target[1, :])
-                )
-            )
+                # All tokens in the past should be unchanged
+                self.assertTrue(
+                    torch.all(
+                        equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
+                    )
+                )
+                # All tokens in the future should have changed
+                self.assertFalse(
+                    torch.any(
+                        equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
+                    )
+                )
+                # Unchanged rows should not change either
+                self.assertTrue(
+                    torch.all(
+                        equal_vectors(output[1, :], output_changed_target[1, :])
+                    )
+                )
 
-            ## --------------- CHANGE AN INPUT TOKEN ---------------------------
-            # Let's change the last prefix token and make sure that the first token changed
-            # guaranteed to be positive as we avoided the pathological case above
-            last_prefix_index = prefix_indices[0][0] - 1
-            token_ids_changed_input = input_batch[0].clone()
-            # We increment the token id at the changed index.
-            token_ids_changed_input[0, last_prefix_index] = \
-                (token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
-            # make sure we're not changing a token to eod as it's a special token
-            token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
-            token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size
+                ## --------------- CHANGE AN INPUT TOKEN ---------------------------
+                # Let's change the last prefix token and make sure that the first token changed
+                # guaranteed to be positive as we avoided the pathological case above
+                last_prefix_index = prefix_indices[0][0] - 1
+                token_ids_changed_input = input_batch[0].clone()
+                # We increment the token id at the changed index.
+                token_ids_changed_input[0, last_prefix_index] = \
+                    (token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
+                # make sure we're not changing a token to eod as it's a special token
+                token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
+                token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size
 
-            output_changed_input = model(token_ids_changed_input, *input_batch[1:])
+                output_changed_input = model(token_ids_changed_input, *input_batch[1:])
 
-            # All tokens should be changed
-            self.assertFalse(
-                torch.any(
-                    equal_vectors(output[0, :], output_changed_input[0, :])
-                )
-            )
-            # Unchanged rows should not change either
-            self.assertTrue(
-                torch.all(
-                    equal_vectors(output[1, :], output_changed_input[1, :])
-                )
-            )
+                # All tokens should be changed
+                self.assertFalse(
+                    torch.any(
+                        equal_vectors(output[0, :], output_changed_input[0, :])
+                    )
+                )
+                # Unchanged rows should not change either
+                self.assertTrue(
+                    torch.all(
+                        equal_vectors(output[1, :], output_changed_input[1, :])
+                    )
+                )
 
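test_prefix_lm_reset_attention_mask above exercises both directions of a prefix LM: changing a target (post-prefix) token must leave earlier outputs alone, while changing a prefix token is expected to change every output of the row, because attention within the prefix is bidirectional. A minimal sketch of such a prefix-LM attention mask, written here only to illustrate the idea and not taken from Megatron's mask construction:

import torch

def prefix_lm_mask(seq_length, prefix_length):
    # Boolean (query, key) mask: True where attention is allowed.
    allowed = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool))  # causal everywhere
    allowed[:, :prefix_length] = True  # every position may also see the whole (bidirectional) prefix
    return allowed

print(prefix_lm_mask(6, 3).int())
# Rows 0-2 (the prefix) see positions 0-2 regardless of order; rows 3-5 stay causal over the suffix.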
     def test_prefix_lm_wo_reset_attention_mask(self):
         """
@@ -234,18 +238,20 @@ def test_prefix_lm_wo_reset_attention_mask(self):
         command_args["--loss-on-targets-only"] = ""
 
         with patch('sys.argv', flatten_arguments(command_args)):
-            initialize_megatron()
-            args = get_args()
+            with mockenv_context(**self.dist_env_1_gpu):
+                deepspeed.init_distributed()
+                initialize_megatron()
+                args = get_args()
 
-            model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
-            model = model[0]
+                model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
+                model = model[0]
 
-            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
-            input_batch, _, prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
+                token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
+                input_batch, _, prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
 
-            model(*input_batch)
+                model(*input_batch)
 
-            #TODO: Check all invariants
+                #TODO: Check all invariants
 
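Throughout this diff, deepspeed.init_distributed() moves from setUpClass into each test body, presumably because it reads its rendezvous settings (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) from the environment, and those variables now only exist inside the mockenv_context block. Outside of unittest, the equivalent single-process setup would look roughly like this (a sketch; the port is simply the one the tests use):

import os
import deepspeed

# Fake a one-process "distributed" run the same way dist_env_1_gpu does.
os.environ.update(dict(MASTER_ADDR="localhost", MASTER_PORT="9994",
                       RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"))
deepspeed.init_distributed()  # builds the process group from the variables above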
     def test_gpt_rotary_embeddings(self):
         """Test rotary embeddings"""
@@ -255,25 +261,27 @@ def test_gpt_rotary_embeddings(self):
         command_args["--position-embedding-type"] = "rotary"
 
         with patch('sys.argv', flatten_arguments(command_args)):
-            initialize_megatron()
-            args = get_args()
-            tokenizer = get_tokenizer()
+            with mockenv_context(**self.dist_env_1_gpu):
+                deepspeed.init_distributed()
+                initialize_megatron()
+                args = get_args()
+                tokenizer = get_tokenizer()
 
-            model, _, _ = setup_model_and_optimizer(gpt_model_provider)
-            model = model[0]
+                model, _, _ = setup_model_and_optimizer(gpt_model_provider)
+                model = model[0]
 
-            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
+                token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
 
-            # eod is a special token
-            token_ids[token_ids == tokenizer.eod] += 1
-            token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
+                # eod is a special token
+                token_ids[token_ids == tokenizer.eod] += 1
+                token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
 
-            # process batch
-            input_batch = get_gpt_batch_pipe({"text": token_ids})[0]
+                # process batch
+                input_batch = get_gpt_batch_pipe({"text": token_ids})[0]
 
-            model(*input_batch)
+                model(*input_batch)
 
-            #TODO: Check all invariants
+                #TODO: Check all invariants
 
 
 if __name__ == '__main__':
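test_gpt_rotary_embeddings only switches --position-embedding-type to rotary and checks that the forward pass runs; the actual invariants are still a TODO. For readers unfamiliar with the term, a minimal rotary-position-embedding application in the rotate-half style, given as an illustrative sketch rather than Megatron's implementation:

import torch

def apply_rotary(x, base=10000.0):
    # Rotate pairs of feature dimensions of x (seq_len, dim) by position-dependent angles.
    seq_len, dim = x.shape
    half = dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)  # (seq_len, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half], x[:, half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

# Queries and keys rotated this way make attention scores depend on relative positions.
q = torch.randn(16, 8)
print(apply_rotary(q).shape)  # torch.Size([16, 8])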