diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py
index 4bbad3d5..f0191ce9 100644
--- a/wtpsplit/evaluation/intrinsic.py
+++ b/wtpsplit/evaluation/intrinsic.py
@@ -10,8 +10,9 @@
 from tqdm.auto import tqdm
 from transformers import AutoModelForTokenClassification, HfArgumentParser
 
+import wtpsplit.models
 from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture
-from wtpsplit.extract import extract
+from wtpsplit.extract import PyTorchWrapper, extract
 from wtpsplit.utils import Constants
 
 
@@ -91,7 +92,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
                         block_size=args.block_size,
                         batch_size=args.batch_size,
                         pad_last_batch=True,
-                    )[0].numpy()
+                    )[0]
 
                     test_labels = get_labels(lang_code, test_sentences, after_space=False)
                     dset_group.create_dataset("test_logits", data=test_logits)
@@ -110,7 +111,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
                         block_size=args.block_size,
                         batch_size=args.batch_size,
                         pad_last_batch=False,
-                    )[0].numpy()
+                    )[0]
 
                     train_labels = get_labels(lang_code, train_sentences, after_space=False)
                     dset_group.create_dataset("train_logits", data=train_logits)
@@ -128,7 +129,7 @@
     else:
         valid_data = None
 
-    model = AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device)
+    model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))
 
     # first, logits for everything.
     f = load_or_compute_logits(args, model, eval_data, valid_data)
diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py
index 6e3515b0..96718bb0 100644
--- a/wtpsplit/extract.py
+++ b/wtpsplit/extract.py
@@ -149,8 +149,8 @@ def extract(
 
         if len(batch_input_hashes) < batch_size and pad_last_batch:
             n_missing = batch_size - len(batch_input_hashes)
 
-            batch_input_hashes = np.pad(batch_input_hashes, (0, n_missing, 0, 0, 0, 0))
-            batch_attention_mask = np.pad(batch_attention_mask, (0, n_missing, 0, 0))
+            batch_input_hashes = np.pad(batch_input_hashes, ((0, n_missing), (0, 0), (0, 0)))
+            batch_attention_mask = np.pad(batch_attention_mask, ((0, n_missing), (0, 0)))
 
             kwargs = {"language_ids": language_ids[: len(batch_input_hashes)]} if uses_lang_adapters else {}
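Two notes on this change. The `.numpy()` calls can be dropped because the model is now wrapped in `PyTorchWrapper`, which, judging from this diff, already returns NumPy arrays from `extract`. The `np.pad` change is a genuine bug fix: for a multi-dimensional array, NumPy expects `pad_width` as one `(before, after)` pair per axis, so the old flat tuple raises a `ValueError`. A minimal sketch of the corrected call (the shapes below are illustrative, not taken from the diff):

```python
import numpy as np

# Illustrative shapes: a batch of 3 sequences, block size 4, hash dim 2.
batch_input_hashes = np.zeros((3, 4, 2), dtype=np.int64)
n_missing = 5  # pad the batch up to batch_size = 8

# Old call: a flat 6-tuple cannot be broadcast to one (before, after)
# pair per axis of a 3-D array, so np.pad raises a ValueError.
# np.pad(batch_input_hashes, (0, n_missing, 0, 0, 0, 0))

# Fixed call: pad only axis 0 (the batch axis), leaving the others intact.
padded = np.pad(batch_input_hashes, ((0, n_missing), (0, 0), (0, 0)))
print(padded.shape)  # (8, 4, 2)
```

The pair form also makes the intent explicit: only axis 0, the batch dimension, is padded out to `batch_size`, while the block and hash dimensions are left untouched.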