Skip to content

Commit ad94185

Browse files
committed
fixed PEP8
1 parent 7e8495c commit ad94185

File tree

3 files changed

+24
-22
lines changed

3 files changed

+24
-22
lines changed

examples/conll2000_chunking_crf.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
vocab = ['<pad>', '<unk>'] + [w for w, f in word_counts.iteritems() if f >= 3]
3434
word2idx = dict((w, i) for i, w in enumerate(vocab))
3535

36+
3637
def process_data(data, maxlen=None, onehot=False):
3738
if maxlen is None:
3839
maxlen = max(len(s) for s in data)
@@ -46,6 +47,7 @@ def process_data(data, maxlen=None, onehot=False):
4647
else:
4748
return x, numpy.expand_dims(y, 2)
4849

50+
4951
train_x, train_y = process_data(train)
5052
test_x, test_y = process_data(test)
5153

@@ -56,7 +58,7 @@ def process_data(data, maxlen=None, onehot=False):
5658
print('==== training CRF ====')
5759

5860
model = Sequential()
59-
model.add(Embedding(len(vocab), 200, mask_zero=True)) # Random embedding
61+
model.add(Embedding(len(vocab), 200, mask_zero=True)) # Random embedding
6062
crf = CRF(len(class_labels), sparse_target=True)
6163
model.add(crf)
6264
model.summary()
@@ -77,7 +79,7 @@ def process_data(data, maxlen=None, onehot=False):
7779
print('==== training BiLSTM-CRF ====')
7880

7981
model = Sequential()
80-
model.add(Embedding(len(vocab), 200, mask_zero=True)) # Random embedding
82+
model.add(Embedding(len(vocab), 200, mask_zero=True)) # Random embedding
8183
model.add(Bidirectional(LSTM(100, return_sequences=True)))
8284
crf = CRF(len(class_labels), sparse_target=True)
8385
model.add(crf)

keras/layers/crf.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ class CRF(Layer):
4545
model = Sequential()
4646
model.add(Embedding(3001, 300, mask_zero=True)(X)
4747
48-
# use learn_mode = 'join', test_mode = 'viterbi'
49-
crf = CRF(10)
50-
model.add(crf(Embed))
48+
# use learn_mode = 'join', test_mode = 'viterbi', sparse_target = True (label indices output)
49+
crf = CRF(10, sparse_target=True)
50+
model.add(crf)
5151
5252
# crf.accuracy defaults to Viterbi accuracy when using join mode (the default).
5353
# One can add crf.marginal_acc if interested, but it may slow down learning
5454
model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
5555
56-
# y can be either onehot representation or label indices (with shape 1 at dim 3)
56+
# y must be label indices (with shape 1 at dim 3) here, since `sparse_target=True`
5757
model.fit(x, y)
5858
5959
# prediction gives the one-hot representation of the Viterbi best path
@@ -331,16 +331,16 @@ def get_logZ(self, in_energy, mask):
331331
def get_energy(self, y_true, in_energy, mask):
332332
'''Energy = a1' y1 + u1' y1 + y1' U y2 + u2' y2 + y2' U y3 + u3' y3 + an' y3
333333
'''
334-
in_energy = K.sum(in_energy * y_true, 2) # (B, T)
334+
in_energy = K.sum(in_energy * y_true, 2) # (B, T)
335335
chain_energy = K.sum(K.dot(y_true[:, :-1, :], self.U) * y_true[:, 1:, :], 2) # (B, T-1)
336336
chain_energy = self.chain_activation(chain_energy)
337337

338338
if mask is not None:
339339
mask = K.cast(mask, K.floatx())
340-
chain_mask = mask[:, :-1] * mask[:, 1:] # (B, T-1), mask[:,:-1]*mask[:,1:] makes it work with any padding
340+
chain_mask = mask[:, :-1] * mask[:, 1:] # (B, T-1), mask[:,:-1]*mask[:,1:] makes it work with any padding
341341
in_energy = in_energy * mask
342342
chain_energy = chain_energy * chain_mask
343-
total_energy = K.sum(in_energy, -1) + K.sum(chain_energy, -1) # (B, )
343+
total_energy = K.sum(in_energy, -1) + K.sum(chain_energy, -1) # (B, )
344344

345345
return total_energy
346346

@@ -368,19 +368,19 @@ def step(self, in_energy_t, states, return_logZ=True):
368368
t = K.cast(i[0, 0], dtype='int32')
369369
if len(states) > 3:
370370
if K._BACKEND == 'theano':
371-
m = states[3][:, t:(t+2)]
371+
m = states[3][:, t:(t + 2)]
372372
else:
373373
m = tf.slice(states[3], [0, t], [-1, 2])
374374
in_energy_t = in_energy_t * K.expand_dims(m[:, 0])
375375
chain_energy = chain_energy * K.expand_dims(K.expand_dims(m[:, 0] * m[:, 1])) # (1, F, F)*(B, 1, 1) -> (B, F, F)
376376
if return_logZ:
377-
energy = chain_energy + K.expand_dims(in_energy_t - prev_target_val, 2) # shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
378-
new_target_val = self.log_sum_exp(-energy, 1) # shapes: (B, F)
377+
energy = chain_energy + K.expand_dims(in_energy_t - prev_target_val, 2) # shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
378+
new_target_val = self.log_sum_exp(-energy, 1) # shapes: (B, F)
379379
return new_target_val, [new_target_val, i + 1]
380380
else:
381381
energy = chain_energy + K.expand_dims(in_energy_t + prev_target_val, 2)
382382
min_energy = K.min(energy, 1)
383-
argmin_table = K.cast(K.argmin(energy, 1), K.floatx()) # cast for tf-version `K.rnn`
383+
argmin_table = K.cast(K.argmin(energy, 1), K.floatx()) # cast for tf-version `K.rnn`
384384
return argmin_table, [min_energy, i + 1]
385385

386386
def recursion(self, in_energy, mask=None, go_backwards=False, return_sequences=True, return_logZ=True):
@@ -403,8 +403,8 @@ def recursion(self, in_energy, mask=None, go_backwards=False, return_sequences=T
403403
If `return_logZ = False`, compute the Viterbi's best path lookup table.
404404
'''
405405
chain_energy = self.chain_activation(self.U)
406-
chain_energy = K.expand_dims(chain_energy, 0) # shape=(1, F, F): F=num of output features. 1st F is for t-1, 2nd F for t
407-
prev_target_val = K.zeros_like(in_energy[:, 0, :]) # shape=(B, F), dtype=float32
406+
chain_energy = K.expand_dims(chain_energy, 0) # shape=(1, F, F): F=num of output features. 1st F is for t-1, 2nd F for t
407+
prev_target_val = K.zeros_like(in_energy[:, 0, :]) # shape=(B, F), dtype=float32
408408

409409
if go_backwards:
410410
in_energy = K.reverse(in_energy, 1)
@@ -458,7 +458,7 @@ def viterbi_decoding(self, X, mask=None):
458458

459459
# backward pass to find the best path; `initial_best_idx` can be any index, as all elements in the last argmin_table are the same
460460
argmin_tables = K.reverse(argmin_tables, 1)
461-
initial_best_idx = [K.expand_dims(argmin_tables[:, 0, 0])] # matrix instead of vector is required by tf `K.rnn`
461+
initial_best_idx = [K.expand_dims(argmin_tables[:, 0, 0])] # matrix instead of vector is required by tf `K.rnn`
462462

463463
def gather_each_row(params, indices):
464464
n = K.shape(indices)[0]

tests/keras/layers/test_crf.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
def test_CRF():
1313
# data
1414
x = np.random.randint(1, embedding_num, nb_samples * timesteps).reshape((nb_samples, timesteps))
15-
x[0, -4:] = 0 # right padding
16-
x[1, :5] = 0 # left padding
15+
x[0, -4:] = 0 # right padding
16+
x[1, :5] = 0 # left padding
1717
y = np.random.randint(0, output_dim, nb_samples * timesteps).reshape((nb_samples, timesteps))
1818
y_onehot = np.eye(output_dim)[y]
19-
y = np.expand_dims(y, 2) # .astype('float32')
19+
y = np.expand_dims(y, 2) # .astype('float32')
2020

2121
# test with no masking, onehot, fix length
2222
model = Sequential()
@@ -37,8 +37,8 @@ def test_CRF():
3737

3838
# check mask
3939
y_pred = model.predict(x).argmax(-1)
40-
assert (y_pred[0, -4:] == 0).all() # right padding
41-
assert (y_pred[1, :5] == 0).all() # left padding
40+
assert (y_pred[0, -4:] == 0).all() # right padding
41+
assert (y_pred[1, :5] == 0).all() # left padding
4242

4343
# test `viterbi_acc`
4444
_, v_acc, _ = model.evaluate(x, y)
@@ -48,7 +48,7 @@ def test_CRF():
4848
# test config
4949
model.get_config()
5050
model = model_from_json(model.to_json())
51-
51+
5252
# test marginal learn mode, fix length, unroll
5353

5454
model = Sequential()

0 commit comments

Comments
 (0)