[FEATURE] update backtranslation and add multinomial sampler (dmlc#1259)
* back translation bash

* split "lang-pair" para in clean_tok_para_corpus

* added clean_tok_mono_corpus

* fix

* add num_process parameter

* fix

* fix

* add yml

* rm yml

* update cfg name

* update evaluate

* added max_update / save_interval_update params

* fix

* fix

* multi gpu inference

* fix

* update

* update multi gpu inference

* fix

* fix

* split evaluate and parallel infer

* fix

* test

* fix

* update

* add comments

* fix

* remove todo comment

* revert remove todo comment

* remove duplicated '\n' from raw lines

* update multinomial sampler

* fix

* fix

* fix

* fix

* sampling

* update script

* fix

* add test_case with k > 1 in topk sampling

* fix multinomial sampler

* update docs

* comment on the case where eos_id = None

* fix

Co-authored-by: Hu <[email protected]>
hutao965 and Hu authored Jul 11, 2020
1 parent 83e1f13 commit a646c34
Showing 10 changed files with 749 additions and 88 deletions.
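
The headline addition is a multinomial sampler for sequence generation: instead of always expanding the highest-scoring candidates the way beam or top-k search does, the decoder draws the next token from its softmax distribution, the kind of stochastic decoding commonly used in back-translation to produce more varied synthetic data. The snippet below is only a minimal sketch of that idea, not the GluonNLP implementation (the sampler itself lives in files whose diffs are not rendered on this page); the names sample_next_token and step_log_probs are illustrative.

    import numpy as np

    def sample_next_token(step_log_probs, temperature=1.0, rng=None):
        """Draw one token id per sequence from the softmax of the decoder scores.

        step_log_probs: array of shape (batch_size, vocab_size) holding the
        decoder's unnormalized log-probabilities for a single time step.
        Unlike beam or top-k search, every token keeps a nonzero chance of
        being picked, so repeated runs produce different translations.
        """
        rng = rng or np.random.default_rng()
        scaled = step_log_probs / temperature
        scaled = scaled - scaled.max(axis=-1, keepdims=True)   # numerical stability
        probs = np.exp(scaled)
        probs = probs / probs.sum(axis=-1, keepdims=True)
        return np.array([rng.choice(len(p), p=p) for p in probs])

    # Toy batch of 2 sequences over a 5-token vocabulary.
    toy_log_probs = np.log(np.array([[0.70, 0.10, 0.10, 0.05, 0.05],
                                     [0.25, 0.25, 0.20, 0.15, 0.15]]))
    print(sample_next_token(toy_log_probs))   # e.g. [0 2]; varies from call to call

With temperature below 1 the draw concentrates on the most likely tokens and behaves more like greedy search; at temperature 1 it follows the model distribution, which is usually what back-translation wants.
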
9 changes: 6 additions & 3 deletions scripts/datasets/machine_translation/wmt2014_ende.sh
@@ -18,15 +18,17 @@ sacrebleu -t wmt14 -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/test.raw.${TGT}

# Clean and tokenize the training + dev corpus
cd ${SAVE_PATH}
nlp_preprocess clean_tok_para_corpus --lang-pair ${SRC}-${TGT} \
nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--tgt-lang ${TGT} \
--src-corpus train.raw.${SRC} \
--tgt-corpus train.raw.${TGT} \
--min-num-words 1 \
--max-num-words 100 \
--src-save-path train.tok.${SRC} \
--tgt-save-path train.tok.${TGT}

nlp_preprocess clean_tok_para_corpus --lang-pair ${SRC}-${TGT} \
nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--tgt-lang ${TGT} \
--src-corpus dev.raw.${SRC} \
--tgt-corpus dev.raw.${TGT} \
--min-num-words 1 \
@@ -35,7 +37,8 @@ nlp_preprocess clean_tok_para_corpus --lang-pair ${SRC}-${TGT} \
--tgt-save-path dev.tok.${TGT}

# For test corpus, we will just tokenize the data
nlp_preprocess clean_tok_para_corpus --lang-pair ${SRC}-${TGT} \
nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--tgt-lang ${TGT} \
--src-corpus test.raw.${SRC} \
--tgt-corpus test.raw.${TGT} \
--src-save-path test.tok.${SRC} \
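
As the hunks above show, the dataset scripts stop passing a combined --lang-pair value to nlp_preprocess clean_tok_para_corpus and instead pass --src-lang and --tgt-lang separately (the 'split the "lang-pair" parameter' item in the commit message). Below is a rough sketch of what that means on the argument-parsing side; it is illustrative only, not the actual clean_tok_para_corpus source.

    import argparse

    parser = argparse.ArgumentParser(prog='clean_tok_para_corpus')

    # Old interface (sketch): a single flag whose value had to be split on '-'.
    #   parser.add_argument('--lang-pair', type=str)          # e.g. "en-de"
    #   src_lang, tgt_lang = args.lang_pair.split('-')

    # New interface (sketch): the two languages are independent options, so the
    # caller never builds or parses an "en-de" style string.
    parser.add_argument('--src-lang', type=str, required=True, help='Source language, e.g. en')
    parser.add_argument('--tgt-lang', type=str, required=True, help='Target language, e.g. de')

    args = parser.parse_args(['--src-lang', 'en', '--tgt-lang', 'de'])
    print(args.src_lang, args.tgt_lang)   # -> en de
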
9 changes: 6 additions & 3 deletions scripts/datasets/machine_translation/wmt2017_zhen.sh
@@ -18,7 +18,8 @@ sacrebleu -t wmt17 -l ${SRC}-${TGT} --echo ref > ${SAVE_PATH}/test.raw.${TGT}

# Clean and tokenize the training + dev corpus
cd ${SAVE_PATH}
nlp_preprocess clean_tok_para_corpus --lang-pair ${SRC}-${TGT} \
nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--tgt-lang ${TGT} \
--src-corpus train.raw.${SRC} \
--tgt-corpus train.raw.${TGT} \
--src-tokenizer jieba \
@@ -29,7 +30,8 @@ nlp_preprocess clean_tok_para_corpus --lang-pair ${SRC}-${TGT} \
--src-save-path train.tok.${SRC} \
--tgt-save-path train.tok.${TGT}

nlp_preprocess clean_tok_para_corpus --lang-pair ${SRC}-${TGT} \
nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--tgt-lang ${TGT} \
--src-corpus dev.raw.${SRC} \
--tgt-corpus dev.raw.${TGT} \
--src-tokenizer jieba \
@@ -41,7 +43,8 @@ nlp_preprocess clean_tok_para_corpus --lang-pair ${SRC}-${TGT} \
--tgt-save-path dev.tok.${TGT}

# For test corpus, we will just tokenize the data
nlp_preprocess clean_tok_para_corpus --lang-pair ${SRC}-${TGT} \
nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--tgt-lang ${TGT} \
--src-corpus test.raw.${SRC} \
--tgt-corpus test.raw.${TGT} \
--src-tokenizer jieba \
113 changes: 62 additions & 51 deletions scripts/machine_translation/evaluate_transformer.py
@@ -5,7 +5,6 @@
from mxnet import gluon
import argparse
import logging
import io
import time
from gluonnlp.utils.misc import logging_config
from gluonnlp.models.transformer import TransformerNMTModel,\
@@ -25,9 +24,9 @@ def parse_args():
parser.add_argument('--seed', type=int, default=100, help='The random seed.')
parser.add_argument('--src_lang', type=str, default='en', help='Source language')
parser.add_argument('--tgt_lang', type=str, default='de', help='Target language')
parser.add_argument('--src_corpus', type=str,
parser.add_argument('--src_corpus', type=str, required=True,
help='The source corpus for evaluation.')
parser.add_argument('--tgt_corpus', type=str,
parser.add_argument('--tgt_corpus', type=str, default=None,
help='The target corpus for evaluation.')
parser.add_argument('--src_tokenizer', choices=['spm',
'subword_nmt',
@@ -68,17 +67,21 @@ def parse_args():
help='The b in the a * x + b formula of beam search')
parser.add_argument('--param_path', type=str, help='The path to the model parameters.')
parser.add_argument('--gpus', type=str, default='0',
help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.'
help='List of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.'
'(using single gpu is suggested)')
parser.add_argument('--save_dir', type=str, default=None,
help='The path to save the log files and predictions.')
parser.add_argument('--stochastic', action='store_true',
help='Whether to use the stochastic beam search')
parser.add_argument('--inference', action='store_true',
help='Whether to inference with your own data')
help='Whether to inference with your own data, '
'when applying inference, tgt_corpus is not needed and will be set to None.')
args = parser.parse_args()
if args.save_dir is None:
args.save_dir = os.path.splitext(args.param_path)[0] + '_evaluation'
assert args.inference or args.tgt_corpus, 'requiring --tgt_corpus when not using --inference'
if args.inference:
args.tgt_corpus = None
logging_config(args.save_dir, console=True)
logging.info(args)
return args
@@ -91,8 +94,8 @@ def process_corpus(corpus_path, sentence_normalizer, bpe_tokenizer,
raw_lines = []
with open(corpus_path, 'r', encoding='utf-8') as f:
for line in f:
raw_lines.append(line)
line = line.strip()
raw_lines.append(line)
line = sentence_normalizer(line)
if base_tokenizer is not None:
line = ' '.join(base_tokenizer.encode(line))
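
The one-line reorder in process_corpus above is the "remove duplicated '\n' from raw lines" item: raw reference lines are now stripped before being stored, so writing them back out with '\n'.join (as the updated evaluation code does further down) no longer produces blank lines between sentences. A small, self-contained illustration of the difference:

    lines_from_file = ['Hello world.\n', 'Second sentence.\n']   # as read from a corpus file

    # Old behaviour (sketch): the line is stored before it is stripped.
    old_raw = []
    for line in lines_from_file:
        old_raw.append(line)          # still carries the trailing '\n'
        line = line.strip()
    print(repr('\n'.join(old_raw)))   # 'Hello world.\n\nSecond sentence.\n' -> doubled newline

    # New behaviour: strip first, then store.
    new_raw = []
    for line in lines_from_file:
        line = line.strip()
        new_raw.append(line)
    print(repr('\n'.join(new_raw)))   # 'Hello world.\nSecond sentence.'
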
@@ -171,24 +174,25 @@ def evaluate(args):
max_length_b=args.max_length_b)

logging.info(beam_search_sampler)
ctx = ctx_l[0]
avg_nll_loss = 0
ntokens = 0
pred_sentences = []
processed_sent = 0
start_eval_time = time.time()
all_src_token_ids, all_src_lines = process_corpus(args.src_corpus,
sentence_normalizer=src_normalizer,
base_tokenizer=base_src_tokenizer,
bpe_tokenizer=src_tokenizer,
add_bos=False,
add_eos=True)
all_tgt_token_ids, all_tgt_lines = process_corpus(args.tgt_corpus,
sentence_normalizer=tgt_normalizer,
base_tokenizer=base_tgt_tokenizer,
bpe_tokenizer=tgt_tokenizer,
add_bos=True,
add_eos=True)
all_src_token_ids, all_src_lines = process_corpus(
args.src_corpus,
sentence_normalizer=src_normalizer,
base_tokenizer=base_src_tokenizer,
bpe_tokenizer=src_tokenizer,
add_bos=False,
add_eos=True
)
if args.tgt_corpus is not None:
all_tgt_token_ids, all_tgt_lines = process_corpus(
args.tgt_corpus,
sentence_normalizer=tgt_normalizer,
base_tokenizer=base_tgt_tokenizer,
bpe_tokenizer=tgt_tokenizer,
add_bos=True,
add_eos=True
)
else: # when applying inference, populate the fake tgt tokens
all_tgt_token_ids = all_tgt_lines = [[] for i in range(len(all_src_token_ids))]
test_dataloader = gluon.data.DataLoader(
list(zip(all_src_token_ids,
[len(ele) for ele in all_src_token_ids],
@@ -197,8 +201,14 @@ def evaluate(args):
batch_size=32,
batchify_fn=Tuple(Pad(), Stack(), Pad(), Stack()),
shuffle=False)


ctx = ctx_l[0]
pred_sentences = []
start_eval_time = time.time()
# evaluate
if not args.inference:
avg_nll_loss = 0
ntokens = 0
for i, (src_token_ids, src_valid_length, tgt_token_ids, tgt_valid_length)\
in enumerate(test_dataloader):
src_token_ids = mx.np.array(src_token_ids, ctx=ctx, dtype=np.int32)
@@ -226,38 +236,39 @@ def evaluate(args):
end_eval_time = time.time()
avg_nll_loss = avg_nll_loss / ntokens

with io.open(os.path.join(args.save_dir, 'gt_sentences.txt'), 'w', encoding='utf-8') as of:
for line in all_tgt_lines:
of.write(line + '\n')
with io.open(os.path.join(args.save_dir, 'pred_sentences.txt'), 'w', encoding='utf-8') as of:
for line in pred_sentences:
of.write(line + '\n')
with open(os.path.join(args.save_dir, 'gt_sentences.txt'), 'w', encoding='utf-8') as of:
of.write('\n'.join(all_tgt_lines))
of.write('\n')
with open(os.path.join(args.save_dir, 'pred_sentences.txt'), 'w', encoding='utf-8') as of:
of.write('\n'.join(pred_sentences))
of.write('\n')

sacrebleu_out = sacrebleu.corpus_bleu(sys_stream=pred_sentences, ref_streams=[all_tgt_lines])
logging.info('Time Spent: {}, #Sent={}, SacreBLEU={}, Avg NLL={}, Perplexity={}'
.format(end_eval_time - start_eval_time, len(all_tgt_lines),
sacrebleu_out.score, avg_nll_loss, np.exp(avg_nll_loss)))

# inference only
else:
for i, (src_token_ids, src_valid_length, _, _) in tqdm(enumerate(test_dataloader)):
src_token_ids = mx.np.array(src_token_ids, ctx=ctx, dtype=np.int32)
src_valid_length = mx.np.array(src_valid_length, ctx=ctx, dtype=np.int32)
init_input = mx.np.array([tgt_vocab.bos_id for _ in range(src_token_ids.shape[0])], ctx=ctx)
states = inference_model.init_states(src_token_ids, src_valid_length)
samples, scores, valid_length = beam_search_sampler(init_input, states, src_valid_length)
for j in range(samples.shape[0]):
pred_tok_ids = samples[j, 0, :valid_length[j, 0].asnumpy()].asnumpy().tolist()
bpe_decode_line = tgt_tokenizer.decode(pred_tok_ids[1:-1])
pred_sentence = base_tgt_tokenizer.decode(bpe_decode_line.split(' '))
pred_sentences.append(pred_sentence)

with io.open('pred_sentences.txt', 'a', encoding='utf-8') as of:
for line in pred_sentences:
of.write(line + '\n')

processed_sent = processed_sent + len(pred_sentences)
pred_sentences = []

with open(os.path.join(args.save_dir, 'pred_sentences.txt'), 'w', encoding='utf-8') as of:
processed_sentences = 0
for src_token_ids, src_valid_length, _, _ in tqdm(test_dataloader):
src_token_ids = mx.np.array(src_token_ids, ctx=ctx, dtype=np.int32)
src_valid_length = mx.np.array(src_valid_length, ctx=ctx, dtype=np.int32)
init_input = mx.np.array([tgt_vocab.bos_id for _ in range(src_token_ids.shape[0])], ctx=ctx)
states = inference_model.init_states(src_token_ids, src_valid_length)
samples, scores, valid_length = beam_search_sampler(init_input, states, src_valid_length)
for j in range(samples.shape[0]):
pred_tok_ids = samples[j, 0, :valid_length[j, 0].asnumpy()].asnumpy().tolist()
bpe_decode_line = tgt_tokenizer.decode(pred_tok_ids[1:-1])
pred_sentence = base_tgt_tokenizer.decode(bpe_decode_line.split(' '))
pred_sentences.append(pred_sentence)
of.write('\n'.join(pred_sentences))
of.write('\n')
processed_sentences += len(pred_sentences)
pred_sentences = []
end_eval_time = time.time()
logging.info('Time Spent: {}, Inferred sentences: {}'
.format(end_eval_time - start_eval_time, processed_sentences))

if __name__ == '__main__':
os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
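
Taken together, the evaluate_transformer.py changes make --inference a first-class mode: with a target corpus the script scores and reports SacreBLEU and perplexity as before, while with --inference it pairs each source sentence with an empty placeholder target (so the same DataLoader and batchify code can be reused) and only writes translations. On the inference path the rewrite also keeps a single output file in args.save_dir open for the whole loop and flushes one batch of translations at a time, instead of re-opening pred_sentences.txt in append mode for every batch. A runnable miniature of that writing pattern, with fake batches standing in for decoded DataLoader batches:

    import os
    import tempfile

    save_dir = tempfile.mkdtemp()                         # stand-in for args.save_dir
    decoded_batches = [['sentence 1', 'sentence 2'],      # stand-in for the translations
                       ['sentence 3']]                    # produced for each batch

    processed_sentences = 0
    with open(os.path.join(save_dir, 'pred_sentences.txt'), 'w', encoding='utf-8') as of:
        pred_sentences = []
        for batch in decoded_batches:                     # in the script: decode one batch
            pred_sentences.extend(batch)
            of.write('\n'.join(pred_sentences))
            of.write('\n')
            processed_sentences += len(pred_sentences)
            pred_sentences = []                           # reset the buffer per batch

    print(processed_sentences)                            # 3
    with open(os.path.join(save_dir, 'pred_sentences.txt'), encoding='utf-8') as f:
        print(f.read(), end='')                           # one translation per line
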
42 changes: 33 additions & 9 deletions scripts/machine_translation/train_transformer.py
@@ -99,17 +99,23 @@ def parse_args():
parser.add_argument('--tgt_vocab_path', type=str,
help='Path to the target vocab.')
parser.add_argument('--seed', type=int, default=100, help='The random seed.')
parser.add_argument('--epochs', type=int, default=30, help='upper epoch limit')
parser.add_argument('--epochs', type=int, default=30, help='Upper epoch limit, '
'the model will keep training when epochs < 0 and max_update < 0.')
parser.add_argument('--max_update', type=int, default=-1,
help='Max update steps, when max_update > 0, epochs will be set to -1, '
'each update step contains gpu_num * num_accumulated batches.')
parser.add_argument('--save_interval_update', type=int, default=500,
help='Update interval of saving checkpoints while using max_update.')
parser.add_argument('--cfg', type=str, default='transformer_nmt_base',
help='Configuration of the transformer model. '
'You may select a yml file or use the prebuild configurations.')
parser.add_argument('--label_smooth_alpha', type=float, default=0.1,
help='Weight of label smoothing')
parser.add_argument('--batch_size', type=int, default=2700,
help='Batch size. Number of tokens per gpu in a minibatch')
help='Batch size. Number of tokens per gpu in a minibatch.')
parser.add_argument('--val_batch_size', type=int, default=16,
help='Batch size for evaluation.')
parser.add_argument('--num_buckets', type=int, default=20, help='Bucket number')
parser.add_argument('--num_buckets', type=int, default=20, help='Bucket number.')
parser.add_argument('--bucket_scheme', type=str, default='exp',
help='Strategy for generating bucket keys. It supports: '
'"constant": all the buckets have the same width; '
@@ -145,11 +151,12 @@ def parse_args():
parser.add_argument('--gpus', type=str,
help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.')
args = parser.parse_args()
if args.max_update > 0:
args.epochs = -1
logging_config(args.save_dir, console=True)
logging.info(args)
return args


def validation(model, data_loader, ctx_l):
"""Validate the model on the dataset
@@ -355,13 +362,15 @@ def train(args):
log_start_time = time.time()
num_params, num_fixed_params = None, None
# TODO(sxjscience) Add a log metric class

accum_count = 0
loss_denom = 0
n_train_iters = 0
log_wc = 0
log_avg_loss = 0.0
log_loss_denom = 0
for epoch_id in range(args.epochs):
epoch_id = 0
while (args.epochs < 0 or epoch_id < args.epochs): # when args.epochs < 0, the model will keep training
n_epoch_train_iters = 0
processed_batch_num = 0
train_multi_data_loader = grouper(train_data_loader, len(ctx_l))
@@ -414,7 +423,8 @@ def train(args):
accum_count = 0
loss_denom = 0
model.collect_params().zero_grad()
if epoch_id >= (args.epochs - args.num_averages):
if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \
(args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update):
model_averager.step()
if n_epoch_train_iters % args.log_interval == 0:
log_end_time = time.time()
@@ -430,12 +440,26 @@ def train(args):
log_avg_loss = 0
log_loss_denom = 0
log_wc = 0
model.save_parameters(os.path.join(args.save_dir,
'epoch{:d}.params'.format(epoch_id)),
deduplicate=True)
if args.max_update > 0 and n_train_iters % args.save_interval_update == 0:
model.save_parameters(os.path.join(args.save_dir,
'{:d}.params'.format(n_train_iters // args.save_interval_update)),
deduplicate=True)
if args.max_update > 0 and n_train_iters >= args.max_update:
break

if args.epochs > 0:
model.save_parameters(os.path.join(args.save_dir,
'epoch{:d}.params'.format(epoch_id)),
deduplicate=True)

avg_valid_loss = validation(model, val_data_loader, ctx_l)
logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
.format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))

if args.max_update > 0 and n_train_iters >= args.max_update:
break
epoch_id += 1

if args.num_averages > 0:
model_averager.copy_back(model.collect_params()) # TODO(sxjscience) Rewrite using update
model.save_parameters(os.path.join(args.save_dir, 'average.params'),
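
The train_transformer.py changes replace the fixed `for epoch_id in range(args.epochs)` loop with an open-ended loop that can be driven either by --epochs or by --max_update: when max_update > 0 the epoch limit is disabled, a numbered checkpoint is written every save_interval_update steps, and training stops once the update counter reaches max_update; the model-averaging window is likewise re-expressed as the last num_averages * save_interval_update updates. The skeleton below reproduces just that control flow with counters in place of real training (updates_per_epoch and the checkpoint list are stand-ins, not names from the script):

    def training_loop(epochs, max_update, save_interval_update, updates_per_epoch=1000):
        """Counter-only skeleton of the epoch / max_update logic in train_transformer.py."""
        n_train_iters = 0
        epoch_id = 0
        checkpoints = []
        # When max_update > 0 the script forces epochs to -1, so only the
        # update-based stopping criterion applies.
        if max_update > 0:
            epochs = -1
        while epochs < 0 or epoch_id < epochs:       # epochs < 0 means "no epoch limit"
            for _ in range(updates_per_epoch):       # stand-in for one pass over the data
                n_train_iters += 1
                if max_update > 0 and n_train_iters % save_interval_update == 0:
                    checkpoints.append('{:d}.params'.format(n_train_iters // save_interval_update))
                if max_update > 0 and n_train_iters >= max_update:
                    break
            if epochs > 0:
                checkpoints.append('epoch{:d}.params'.format(epoch_id))
            if max_update > 0 and n_train_iters >= max_update:
                break
            epoch_id += 1
        return n_train_iters, checkpoints

    # Epoch-driven run: 2 epochs -> one checkpoint per epoch.
    print(training_loop(epochs=2, max_update=-1, save_interval_update=500))
    # Update-driven run: stops after 2500 updates, saving every 500 updates.
    print(training_loop(epochs=30, max_update=2500, save_interval_update=500)[1])
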
Diffs for the remaining changed files are not rendered on this page.
