Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Use with_seed decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudh2290 committed Jul 11, 2019
1 parent 80af4c4 commit 09563f7
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tests/nightly/estimator/test_sentiment_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import argparse
import os
import sys
import tarfile
import random
import collections
Expand All @@ -30,6 +31,9 @@
from mxnet.contrib import text
from mxnet.gluon import nn, rnn
from mxnet.gluon.contrib.estimator import estimator
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../../python/unittest'))
from common import with_seed


class TextCNN(nn.Block):
Expand Down Expand Up @@ -232,6 +236,9 @@ def test_estimator_cpu(**kwargs):
run(net, train_dataloader, val_dataloader, **kwargs)


# Model
# using fixed seed to reduce flakiness in accuracy assertion
@with_seed(7)
def test_estimator_gpu(**kwargs):
'''
Test estimator by training Bidirectional RNN for 5 epochs on the IMDB dataset
Expand All @@ -252,9 +259,6 @@ def test_estimator_gpu(**kwargs):
train_dataloader = gluon.data.DataLoader(train_set, batch_size, shuffle=True)
test_dataloader = gluon.data.DataLoader(test_set, batch_size)

# Model
# using fixed seed to reduce flakiness in accuracy assertion
mx.random.seed(7)
num_hiddens, num_layers = 100, 2
net = BiRNN(vocab, embed_size, num_hiddens, num_layers)
net.initialize(mx.init.Xavier(), ctx=ctx)
Expand Down

0 comments on commit 09563f7

Please sign in to comment.