From 510d991a770487d7d6aff1cd6ae8a4dcf7c87ed6 Mon Sep 17 00:00:00 2001 From: ZheyuYe Date: Thu, 30 Jul 2020 02:33:22 +0800 Subject: [PATCH] test --- src/gluonnlp/models/bart.py | 8 +++++--- src/gluonnlp/models/transformer.py | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/gluonnlp/models/bart.py b/src/gluonnlp/models/bart.py index 2fb154a2ae..fe61fab37e 100644 --- a/src/gluonnlp/models/bart.py +++ b/src/gluonnlp/models/bart.py @@ -64,7 +64,7 @@ def bart_base(): cfg.MODEL.tie_weights = True cfg.MODEL.attention_dropout = 0.1 cfg.MODEL.activation_dropout = 0.0 - cfg.MODEL.dropout = 0.0 + cfg.MODEL.dropout = 0.1 cfg.MODEL.layer_norm_eps = 1E-5 cfg.MODEL.pooler_activation = 'tanh' cfg.MODEL.data_norm = True @@ -270,13 +270,15 @@ def get_cfg(cls, key=None): return bart_cfg_reg.create(key) @classmethod - def from_cfg(cls, cfg, + def from_cfg(cls, cfg, dtype=None, use_pooler=False, classifier_activation=False): cfg = cls.get_cfg().clone_merge(cfg) embed_initializer = mx.init.create(*cfg.INITIALIZER.embed) weight_initializer = mx.init.create(*cfg.INITIALIZER.weight) bias_initializer = mx.init.create(*cfg.INITIALIZER.bias) + if dtype is None: + dtype = cfg.MODEL.dtype return cls(src_vocab_size=cfg.MODEL.vocab_size, tgt_vocab_size=cfg.MODEL.vocab_size, max_src_length=cfg.MODEL.max_src_length, @@ -310,7 +312,7 @@ def from_cfg(cls, cfg, embed_initializer=embed_initializer, weight_initializer=weight_initializer, bias_initializer=bias_initializer, - dtype=cfg.MODEL.dtype) + dtype=dtype) def list_pretrained_bart(): diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py index 41286f81c7..8e55111e19 100644 --- a/src/gluonnlp/models/transformer.py +++ b/src/gluonnlp/models/transformer.py @@ -483,6 +483,9 @@ def __init__(self, units: int = 512, hidden_size=hidden_size, dropout=dropout, activation_dropout=activation_dropout, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + layer_norm_eps=layer_norm_eps, activation=activation, pre_norm=pre_norm, dtype=dtype)