diff --git a/relation_extract.py b/relation_extract.py index 0d14f64..171b8a1 100644 --- a/relation_extract.py +++ b/relation_extract.py @@ -211,7 +211,7 @@ def seq_gather(x): return K.tf.gather_nd(seq, idxs) -bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path) +bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None) for l in bert_model.layers: l.trainable = True