diff --git a/examples/token-classification/run_ner.py b/examples/token-classification/run_ner.py index 3eed7098a5aa..7054dea78ca8 100644 --- a/examples/token-classification/run_ner.py +++ b/examples/token-classification/run_ner.py @@ -344,7 +344,7 @@ def compute_metrics(p): if training_args.do_predict: logger.info("*** Predict ***") - test_dataset = datasets["test"] + test_dataset = tokenized_datasets["test"] predictions, labels, metrics = trainer.predict(test_dataset) predictions = np.argmax(predictions, axis=2)