
Commit 510d991 ("test")
zheyuye committed Jul 29, 2020
1 parent: 1b5fa7b
Showing 2 changed files with 8 additions and 3 deletions.
src/gluonnlp/models/bart.py (5 additions, 3 deletions)

@@ -64,7 +64,7 @@ def bart_base():
     cfg.MODEL.tie_weights = True
     cfg.MODEL.attention_dropout = 0.1
     cfg.MODEL.activation_dropout = 0.0
-    cfg.MODEL.dropout = 0.0
+    cfg.MODEL.dropout = 0.1
     cfg.MODEL.layer_norm_eps = 1E-5
     cfg.MODEL.pooler_activation = 'tanh'
     cfg.MODEL.data_norm = True
@@ -270,13 +270,15 @@ def get_cfg(cls, key=None):
         return bart_cfg_reg.create(key)

     @classmethod
-    def from_cfg(cls, cfg,
+    def from_cfg(cls, cfg, dtype=None,
                  use_pooler=False,
                  classifier_activation=False):
         cfg = cls.get_cfg().clone_merge(cfg)
         embed_initializer = mx.init.create(*cfg.INITIALIZER.embed)
         weight_initializer = mx.init.create(*cfg.INITIALIZER.weight)
         bias_initializer = mx.init.create(*cfg.INITIALIZER.bias)
+        if dtype is None:
+            dtype = cfg.MODEL.dtype
         return cls(src_vocab_size=cfg.MODEL.vocab_size,
                    tgt_vocab_size=cfg.MODEL.vocab_size,
                    max_src_length=cfg.MODEL.max_src_length,
@@ -310,7 +312,7 @@ def from_cfg(cls, cfg,
                    embed_initializer=embed_initializer,
                    weight_initializer=weight_initializer,
                    bias_initializer=bias_initializer,
-                   dtype=cfg.MODEL.dtype)
+                   dtype=dtype)


 def list_pretrained_bart():
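Taken together, the bart.py hunks make the model dtype overridable at construction time: from_cfg now accepts an optional dtype that falls back to cfg.MODEL.dtype when left as None, and the final constructor call uses the resolved value instead of reading the config directly. A minimal usage sketch, assuming the enclosing class is BartModel (the class name is not visible in this diff):

    from gluonnlp.models.bart import BartModel

    cfg = BartModel.get_cfg()  # default BART config; cfg.MODEL.dtype holds the stored dtype

    # Omitting dtype keeps the pre-commit behavior: the model is
    # built with cfg.MODEL.dtype.
    model = BartModel.from_cfg(cfg)

    # Passing dtype explicitly overrides the config value without
    # mutating cfg itself, e.g. to build a float16 variant:
    model_fp16 = BartModel.from_cfg(cfg, dtype='float16')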
src/gluonnlp/models/transformer.py (3 additions, 0 deletions)

@@ -483,6 +483,9 @@ def __init__(self, units: int = 512,
              hidden_size=hidden_size,
              dropout=dropout,
              activation_dropout=activation_dropout,
+             weight_initializer=weight_initializer,
+             bias_initializer=bias_initializer,
+             layer_norm_eps=layer_norm_eps,
              activation=activation,
              pre_norm=pre_norm,
              dtype=dtype)
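The transformer.py hunk fixes a silent fallback: when the layer built its feed-forward sub-block, it did not forward weight_initializer, bias_initializer, or layer_norm_eps, so the sub-block used its own defaults regardless of what the caller configured. A sketch of the pattern with illustrative Gluon blocks (PositionwiseFFNSketch and DecoderLayerSketch are placeholders, not the actual gluonnlp classes):

    import mxnet as mx
    from mxnet.gluon import nn

    class PositionwiseFFNSketch(nn.HybridBlock):
        """Toy stand-in for the feed-forward sub-block."""
        def __init__(self, units=512, hidden_size=2048,
                     weight_initializer=None, bias_initializer='zeros',
                     layer_norm_eps=1E-5):
            super().__init__()
            # The configured initializers and epsilon are applied here;
            # if the parent does not pass them, the defaults above win.
            self.ffn_1 = nn.Dense(hidden_size, in_units=units, flatten=False,
                                  weight_initializer=weight_initializer,
                                  bias_initializer=bias_initializer)
            self.ffn_2 = nn.Dense(units, in_units=hidden_size, flatten=False,
                                  weight_initializer=weight_initializer,
                                  bias_initializer=bias_initializer)
            self.layer_norm = nn.LayerNorm(epsilon=layer_norm_eps,
                                           in_channels=units)

    class DecoderLayerSketch(nn.HybridBlock):
        def __init__(self, units=512, hidden_size=2048,
                     weight_initializer=None, bias_initializer='zeros',
                     layer_norm_eps=1E-5):
            super().__init__()
            # Before the fix, the sub-block was built without these three
            # keyword arguments and silently fell back to its defaults.
            self.ffn = PositionwiseFFNSketch(
                units=units,
                hidden_size=hidden_size,
                weight_initializer=weight_initializer,
                bias_initializer=bias_initializer,
                layer_norm_eps=layer_norm_eps)

    # Construction-time check: the override now reaches the sub-block.
    layer = DecoderLayerSketch(weight_initializer=mx.init.Xavier(),
                               layer_norm_eps=1E-6)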
