Skip to content

Commit c460e3b

Browse files
committed
fixed an error
1 parent c6d0281 commit c460e3b

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

keras/layers/crf.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,9 @@ def __init__(self, output_dim,
119119
weights=None, input_dim=None, input_length=None, unroll=False, **kwargs):
120120
self.supports_masking = True
121121
self.output_dim = output_dim
122-
self.learn_mode = learn_mode.lower()
123-
self.test_mode = test_mode.lower()
122+
self.learn_mode = learn_mode
124123
assert self.learn_mode in ['join', 'marginal']
124+
self.test_mode = test_mode
125125
if self.test_mode is None:
126126
self.test_mode = 'viterbi' if self.learn_mode == 'join' else 'marginal'
127127
else:

tests/keras/layers/test_crf.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
@keras_test
1212
def test_CRF():
13-
1413
# data
1514
x = np.random.randint(1, embedding_num, nb_samples * timesteps).reshape((nb_samples, timesteps))
1615
x[0, -4:] = 0 # right padding
@@ -44,7 +43,7 @@ def test_CRF():
4443
# test `viterbi_acc
4544
_, v_acc, _ = model.evaluate(x, y)
4645
np_acc = (y_pred[x > 0] == y[x > 0]).mean()
47-
assert_allclose([v_acc], [np_acc], atol=1e-6)
46+
assert np.abs(v_acc - np_acc) < 1e-4
4847

4948
# test config
5049
model.get_config()

0 commit comments

Comments
 (0)