# test_models_gpt2.py
import pytest
import numpy as np
import mxnet as mx
import tempfile
from numpy.testing import assert_allclose
from gluonnlp.models.gpt2 import GPT2Model, GPT2ForLM, \
    list_pretrained_gpt2, get_pretrained_gpt2
from gluonnlp.loss import LabelSmoothCrossEntropyLoss

mx.npx.set_np()


def test_list_pretrained_gpt2():
    assert len(list_pretrained_gpt2()) > 0


@pytest.mark.parametrize('compute_layout', ['auto', 'TN', 'NT'])
def test_gpt2_small_config(compute_layout, ctx):
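    """Outputs under the TN data layout should match the NT layout once the
    batch and time axes are swapped."""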
    cfg = GPT2Model.get_cfg()
    cfg.defrost()
    cfg.MODEL.vocab_size = 1000
    cfg.MODEL.units = 128
    cfg.MODEL.num_layers = 2
    cfg.MODEL.num_heads = 2
    cfg.MODEL.compute_layout = compute_layout
    cfg.freeze()
    # Generate TN layout
    cfg_tn = cfg.clone()
    cfg_tn.defrost()
    cfg_tn.MODEL.layout = 'TN'
    cfg_tn.freeze()
    with ctx:
        batch_size = 4
        sequence_length = 16
        inputs = mx.np.random.randint(0, 1000, (batch_size, sequence_length), ctx=ctx)
        gpt2_model = GPT2Model.from_cfg(cfg)
        gpt2_model.initialize(ctx=ctx)
        gpt2_model.hybridize()
        hiddens, _ = gpt2_model(
            inputs,
            gpt2_model.init_states(batch_size, ctx)
        )
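        # The TN-layout model shares parameters with the NT-layout model, so its
        # outputs should match after swapping the batch and time axes.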
        gpt2_model_tn = GPT2Model.from_cfg(cfg_tn)
        gpt2_model_tn.share_parameters(gpt2_model.collect_params())
        gpt2_model_tn.hybridize()
        hiddens_tn, _ = gpt2_model_tn(
            inputs.T,
            gpt2_model_tn.init_states(batch_size, ctx)
        )
        assert_allclose(np.swapaxes(hiddens_tn.asnumpy(), 0, 1),
                        hiddens.asnumpy(), 1E-4, 1E-4)

        # Test for GPT2ForLM
        gpt2_lm_model = GPT2ForLM(cfg)
        gpt2_lm_model.initialize(ctx=ctx)
        gpt2_lm_model.hybridize()
        logits, states = gpt2_lm_model(
            inputs,
            gpt2_lm_model.init_states(batch_size, ctx)
        )
        gpt2_lm_model_tn = GPT2ForLM(cfg_tn)
        gpt2_lm_model_tn.share_parameters(gpt2_lm_model.collect_params())
        gpt2_lm_model_tn.hybridize()
        logits_tn, states_tn = gpt2_lm_model_tn(
            inputs.T,
            gpt2_lm_model_tn.init_states(batch_size, ctx)
        )
        assert_allclose(np.swapaxes(logits_tn.asnumpy(), 0, 1),
                        logits.asnumpy(), 1E-4, 1E-4)
        assert_allclose(np.swapaxes(states_tn.asnumpy(), 2, 3),
                        states.asnumpy(), 1E-4, 1E-4)


def test_gpt2_incremental_states(ctx):
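    """Decoding one token at a time with cached states should match a single
    full-sequence forward pass."""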
    with ctx:
        batch_size = 4
        sequence_length = 5
        inputs = mx.np.random.randint(0, 1000, (batch_size, sequence_length), ctx=ctx)
        cfg = GPT2Model.get_cfg()
        gpt2_model = GPT2Model.from_cfg(cfg)
        gpt2_model.initialize(ctx=ctx)
        gpt2_model.hybridize()
        one_time_hiddens, one_time_states = gpt2_model(
            inputs,
            gpt2_model.init_states(batch_size, ctx)
        )
        states = gpt2_model.init_states(batch_size, ctx)
        hiddens_l = []
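        # Feed one token per step, reusing the returned states as the cache.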
        for i in range(sequence_length):
            hiddens, states = gpt2_model(
                inputs[:, i:i+1],
                states
            )
            hiddens_l.append(hiddens)
        hiddens_concat = mx.np.concatenate(hiddens_l, axis=1)
        assert_allclose(one_time_states.asnumpy(),
                        states.asnumpy(), 1E-4, 1E-4)
        assert_allclose(one_time_hiddens.asnumpy(),
                        hiddens_concat.asnumpy(), 1E-4, 1E-4)


@pytest.mark.slow
@pytest.mark.remote_required
@pytest.mark.parametrize('model_name', ['gpt2_124M', 'gpt2_355M', 'gpt2_774M'])
def test_gpt2(model_name, ctx):
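    """Pretrained GPT-2 checkpoints should download, load, and support a
    forward and backward pass."""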
    # test from pretrained
    assert len(list_pretrained_gpt2()) > 0
    with tempfile.TemporaryDirectory() as root, ctx:
        cfg, tokenizer, params_path, lm_params_path =\
            get_pretrained_gpt2(model_name, load_backbone=True, load_lm=True, root=root)
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
        # test backbone
        gpt2_model = GPT2Model.from_cfg(cfg)
        gpt2_model.load_parameters(params_path)
        # test lm model
        gpt2_lm_model = GPT2ForLM(cfg)
        gpt2_lm_model.load_parameters(lm_params_path)
        # test forward
        batch_size = 3
        seq_length = 32
        vocab_size = len(tokenizer.vocab)
        input_ids = mx.np.array(
            np.random.randint(
                2,
                vocab_size,
                (batch_size, seq_length)
            ),
            dtype=np.int32,
            ctx=ctx
        )
        logits, _ = gpt2_lm_model(
            input_ids,
            gpt2_lm_model.init_states(batch_size, ctx)
        )
        mx.npx.waitall()
        # test backward
        label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=vocab_size)
        with mx.autograd.record():
            logits, _ = gpt2_lm_model(
                input_ids,
                gpt2_lm_model.init_states(batch_size, ctx)
            )
            loss = label_smooth_loss(logits, input_ids)
            loss.backward()
        mx.npx.waitall()