Commit
some fixes
markus583 committed Dec 25, 2023
1 parent 2fe2c8a commit 4ec9026
Showing 4 changed files with 62 additions and 41 deletions.
2 changes: 2 additions & 0 deletions wtpsplit/extract.py
@@ -85,6 +85,7 @@ def extract(
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
tokens = tokenizer(batch_of_texts, return_offsets_mapping=True)
# remove CLS and SEP tokens, they are added later anyhow
old_batch_of_texts = batch_of_texts

Check failure on line 88 in wtpsplit/extract.py — GitHub Actions / build (3.8 and 3.11) — Ruff (F841): wtpsplit/extract.py:88:9: Local variable `old_batch_of_texts` is assigned to but never used
batch_of_texts = [text[1:-1] for text in tokens["input_ids"]]
offset_mapping = [offset[1:-1] for offset in tokens["offset_mapping"]]
cls_token_id = tokenizer.cls_token_id
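A minimal, standalone sketch (not part of this commit) of the pattern used above: tokenize with a fast Hugging Face tokenizer, keep the character offsets, and drop the CLS/SEP entries since they are re-added later. It also sidesteps the Ruff F841 failure by not keeping the unused old_batch_of_texts binding; xlm-roberta-base is only an assumed example checkpoint.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # assumed checkpoint
batch_of_texts = ["Hello world.\nSecond line."]

tokens = tokenizer(batch_of_texts, return_offsets_mapping=True)
# drop <s>/</s>; they get re-added when batches are built later
input_ids = [ids[1:-1] for ids in tokens["input_ids"]]
offset_mapping = [offsets[1:-1] for offsets in tokens["offset_mapping"]]

for text, ids, offsets in zip(batch_of_texts, input_ids, offset_mapping):
    for token_id, (start, end) in zip(ids, offsets):
        print(repr(text[start:end]), "->", token_id)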
@@ -219,6 +220,7 @@ def extract(
attention_mask=batch_attention_mask,
**kwargs,
)["logits"]
print(np.max(logits[0, :, 0]))

for i in range(start, end):
original_idx, start_char_idx, end_char_idx = locs[i]
7 changes: 2 additions & 5 deletions wtpsplit/train/evaluate.py
@@ -103,20 +103,17 @@ def evaluate_sentence(
batch_size=batch_size,
verbose=True,
)
logits, offsets_mapping = logits[0], offsets_mapping[0]
logits, offsets_mapping = logits[0], offsets_mapping[0] # FIXME

true_end_indices = np.cumsum(np.array([len(s) for s in sentences])) + np.arange(len(sentences)) * len(separator)
newline_labels = np.zeros(len(text))
newline_labels[true_end_indices - 1] = 1

print("newline_labels", newline_labels.shape)


if "xlm" in model.config.model_type:
tokens = tokenizer.tokenize(text)
char_probs = token_to_char_probs(text, tokens, logits[:, positive_index], tokenizer, offsets_mapping)
else:
char_probs = logits[:, positive_index]
print("char probs", char_probs.shape)
metrics, info = get_metrics(newline_labels, char_probs)

info["newline_labels"] = newline_labels
87 changes: 52 additions & 35 deletions wtpsplit/train/train.py
@@ -71,7 +71,11 @@ def forward(
label_weights=None,
**kwargs,
):
reduced_attention_mask = (input_ids != 0).to(torch.long)
if position_ids is not None:
reduced_attention_mask = (input_ids != 0).to(torch.long)
else:
# XXX: 1 is pad token id
reduced_attention_mask = (input_ids != 1).to(torch.long)

output = dict(
self.backbone.forward(
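The branch above hard-codes the two pad ids this diff relies on: 0 for the character-level inputs (the path where position_ids is passed) and 1 for XLM-R subword inputs. A small sketch of the same masking with the pad id made explicit; the helper name and the example content ids are made up, only 0/1/2 being XLM-R's <s>/<pad>/</s> ids.

import torch

def reduced_attention_mask_from(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # 1 for real tokens, 0 for padding, whatever id the vocabulary uses for <pad>
    return (input_ids != pad_token_id).to(torch.long)

char_ids = torch.tensor([[72, 105, 0, 0]])        # char-level inputs pad with 0
subword_ids = torch.tensor([[0, 1500, 2, 1]])     # XLM-R: <s>=0, </s>=2, <pad>=1
print(reduced_attention_mask_from(char_ids, pad_token_id=0))     # tensor([[1, 1, 0, 0]])
print(reduced_attention_mask_from(subword_ids, pad_token_id=1))  # tensor([[1, 1, 1, 0]])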
@@ -91,6 +95,7 @@ def forward(

# main (newline prediction) objective
if self.do_sentence_training:
# label smoothing
sentence_labels = (0.5 - self.loss_margin) + (labels == Constants.NEWLINE_INDEX + 1).to(
logits.dtype
).view(-1) * self.loss_margin * 2
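A worked numeric example of the margin-based label smoothing the new comment refers to: the binary targets are pulled from {0, 1} to 0.5 ± loss_margin. The margin value and the stand-in for Constants.NEWLINE_INDEX + 1 below are illustrative, not necessarily the repo defaults.

import torch

loss_margin = 0.4      # illustrative value
NEWLINE_LABEL = 1      # stands in for Constants.NEWLINE_INDEX + 1

labels = torch.tensor([0, 1, 0, 1])
sentence_labels = (0.5 - loss_margin) + (labels == NEWLINE_LABEL).float() * loss_margin * 2
print(sentence_labels)  # tensor([0.1000, 0.9000, 0.1000, 0.9000])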
@@ -112,11 +117,13 @@ def forward(
if self.do_auxiliary_training:
loss_fn = nn.CrossEntropyLoss()

# exclude newline and no labels
aux_labels = torch.where(
(labels == 0) | (labels == Constants.NEWLINE_INDEX + 1),
0,
labels - Constants.AUX_OFFSET,
)
# exclude reduced_attention_mask tokens from labels
aux_labels = torch.where(
reduced_attention_mask == 1,
aux_labels,
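A toy version of the auxiliary-label construction above (the second torch.where, cut off by the diff, presumably masks padded positions via reduced_attention_mask). Class 0 means "no auxiliary character"; the NEWLINE_LABEL and AUX_OFFSET values below are assumptions for illustration, mirroring the 1 + AUX_OFFSET + i scheme visible in utils.py.

import torch

NEWLINE_LABEL = 1   # stands in for Constants.NEWLINE_INDEX + 1
AUX_OFFSET = 2      # stands in for Constants.AUX_OFFSET

# 0 = unlabeled, 1 = newline, >= 1 + AUX_OFFSET = auxiliary (punctuation) classes
labels = torch.tensor([0, 1, 3, 4, 5])
aux_labels = torch.where(
    (labels == 0) | (labels == NEWLINE_LABEL),
    torch.zeros_like(labels),    # newline / unlabeled positions -> "none" aux class
    labels - AUX_OFFSET,         # shift punctuation classes down
)
print(aux_labels)                # tensor([0, 0, 1, 2, 3])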
@@ -208,16 +215,16 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
tokenizer=tokenizer if args.use_subwords else None,
)

if input_ids[0] != tokenizer.cls_token_id:
print(input_ids)
print(len(input_ids))
print(tokenizer.cls_token_id)
raise ValueError("CLS token not first token")
if input_ids[-1] != tokenizer.sep_token_id:
print(input_ids)
print(len(input_ids))
print(tokenizer.sep_token_id)
raise ValueError("SEP token not last token")
# if input_ids[0] != tokenizer.cls_token_id:
# print(input_ids)
# print(len(input_ids))
# print(tokenizer.cls_token_id)
# raise ValueError("CLS token not first token")
# if input_ids[-1] != tokenizer.sep_token_id:
# print(input_ids)
# print(len(input_ids))
# print(tokenizer.sep_token_id)
# raise ValueError("SEP token not last token")

if len(input_ids) > args.block_size:
if tokenizer:
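The CLS/SEP sanity checks above are now commented out rather than removed. A hypothetical alternative (not in this commit) keeps them as cheap, switchable assertions so they can be re-enabled while debugging the collator:

DEBUG_CHECKS = False  # flip on while debugging

def check_special_tokens(input_ids, tokenizer):
    if not DEBUG_CHECKS or tokenizer is None:
        return
    assert input_ids[0] == tokenizer.cls_token_id, "CLS token not first token"
    assert input_ids[-1] == tokenizer.sep_token_id, "SEP token not last token"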
@@ -242,21 +249,22 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):

input_ids = torch.tensor(input_ids[: args.block_size], dtype=torch.long)
labels = torch.tensor(labels[: args.block_size], dtype=torch.long)
if input_ids[-1] != tokenizer.sep_token_id:
print(input_ids)
print(tokenizer.sep_token_id)
print(labels)
raise ValueError("SEP token not last token")
if input_ids[0] != tokenizer.cls_token_id:
print(input_ids)
print(tokenizer.cls_token_id)
print(labels)
raise ValueError("CLS token not first token")
if (input_ids == tokenizer.cls_token_id).sum() != 1:
print(input_ids)
print(tokenizer.cls_token_id)
print(labels)
raise ValueError("CLS token not unique")
# if input_ids[-1] != tokenizer.sep_token_id:
# print(input_ids)
# print(tokenizer.sep_token_id)
# print(labels)
# raise ValueError("SEP token not last token")
# if input_ids[0] != tokenizer.cls_token_id:
# print(input_ids)
# print(tokenizer.cls_token_id)
# print(labels)
# raise ValueError("CLS token not first token")
# TODO: check this - why does it occur in train split?
# if (input_ids == tokenizer.cls_token_id).sum() != 1:
# print(input_ids)
# print(tokenizer.cls_token_id)
# print(labels)
# raise ValueError("CLS token not unique")

position_ids = torch.arange(len(input_ids), dtype=torch.long)
label_weights = torch.ones(args.block_size, dtype=torch.float32)
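The plain input_ids[: args.block_size] truncation above can cut off the trailing SEP token, which is presumably why the "SEP token not last token" check is commented out. Purely as an illustration (not the repo's approach), a truncation that re-appends SEP could look like this; the label 0 for the appended SEP is an assumption:

import torch

def truncate_keep_sep(input_ids, labels, block_size, sep_token_id):
    # illustrative only: truncate to block_size but keep SEP as the final token
    if len(input_ids) > block_size:
        input_ids = input_ids[: block_size - 1] + [sep_token_id]
        labels = labels[: block_size - 1] + [0]
    return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)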
@@ -317,9 +325,12 @@ def main():

backbone.config.base_model = args.model_name_or_path
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
# needed since we create labels in collate_fn based on tokens
# TODO: problematic for <UNK> tokens!
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})

else:
tokenizer = None
config = LACanineConfig.from_pretrained(
args.model_name_or_path,
raw_lookahead=args.lookahead,
@@ -357,7 +368,6 @@ def main():

with training_args.main_process_first():
print(summary(model, depth=4))
# also save base model
# backbone.push_to_hub("markus583/xlm-token-untrained", private=True)

def prepare_dataset(
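Registering "\n" as an additional special token is what makes the collate_fn labels line up with tokens; the TODO notes that characters falling back to <UNK> remain problematic. A quick standalone check (assuming xlm-roberta-base) that the newline maps to a single, non-UNK id after registration:

from transformers import AddedToken, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # assumed checkpoint
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})

newline_id = tokenizer.convert_tokens_to_ids("\n")
print(newline_id, newline_id == tokenizer.unk_token_id)  # <new id> False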
@@ -582,9 +592,9 @@ def maybe_pad(text):
batched=True,
num_proc=num_workers,
# a bit hacky but oh well, only drop if sentence
remove_columns=["ends_with_punctuation"]
remove_columns=["ends_with_punctuation", args.text_column]
if args.text_column == "text"
else [],
else [args.text_column],
)

return dataset
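The conditional above now always drops args.text_column and additionally drops "ends_with_punctuation" when the column is literally named "text" (sentence-level data). Spelled out as a tiny runnable helper covering both cases:

def columns_to_drop(text_column: str) -> list:
    return ["ends_with_punctuation", text_column] if text_column == "text" else [text_column]

print(columns_to_drop("text"))       # ['ends_with_punctuation', 'text']
print(columns_to_drop("sentences"))  # ['sentences']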
@@ -599,14 +609,21 @@ def maybe_pad(text):
num_workers=args.preprocessing_num_workers,
include_languages=args.include_languages,
shuffle=args.shuffle,
split="train",
split="valid",
)

# print some samples from the dataset
for index in random.sample(range(len(train_dataset)), 5):
print(f"Sample {index} of the training set: {train_dataset[index]}.")
print(tokenizer.decode(train_dataset[index]["input_ids"]))
print()
count = 0
while count < 5:
index = random.choice(range(len(train_dataset)))
sample = train_dataset[index]

if sample.get('lang') == "de":
print(f"Sample {index} of the training set: {sample}.")
if tokenizer:
print(tokenizer.decode(sample["input_ids"]))
print()
count += 1

# dataset we use is in cached now
# m_c4 files are test/valid splits of already downloaded data
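The new sampling loop above keeps drawing random indices until it has printed five German ("de") rows, so it spins forever if the loaded split contains fewer than five of them. A more defensive, hypothetical variant (names made up; scanning the whole dataset this way may be slow for large splits):

import random

def sample_rows_by_lang(dataset, lang="de", k=5, seed=0):
    rng = random.Random(seed)
    candidates = [i for i in range(len(dataset)) if dataset[i].get("lang") == lang]
    picked = rng.sample(candidates, min(k, len(candidates)))
    return [(i, dataset[i]) for i in picked]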
@@ -628,7 +645,7 @@ def compute_metrics(trainer):

model = trainer._wrap_model(trainer.model, training=False)

for lang_code, lang_data in eval_data.items():
for lang_code, lang_data in eval_data.items(): # TODO: tqdm integration
if args.include_languages is not None and lang_code not in args.include_languages:
continue

7 changes: 6 additions & 1 deletion wtpsplit/utils.py
@@ -84,6 +84,7 @@ def get_subword_label_dict(label_args, tokenizer):
for i, c in enumerate(label_args.auxiliary_chars):
token_id = tokenizer.convert_tokens_to_ids(c)
label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
# TODO: remove UNKs?
print(f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}")
if token_id == tokenizer.unk_token_id:
n_unks += 1
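One possible answer to the "# TODO: remove UNKs?" above (a sketch, not the repo's code): skip auxiliary characters the tokenizer cannot represent, so that several unknown characters do not all collapse onto the single <unk> id and thus onto one label:

def build_subword_label_dict(auxiliary_chars, tokenizer, aux_offset):
    label_dict = {}
    for i, c in enumerate(auxiliary_chars):
        token_id = tokenizer.convert_tokens_to_ids(c)
        if token_id == tokenizer.unk_token_id:
            continue  # drop characters that would only map to <unk>
        label_dict[token_id] = 1 + aux_offset + i
    return label_dict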
@@ -177,7 +178,11 @@ def corrupt(
if random.random() < label_args.newline_remove_prob:
if separator == " " and random.random() < label_args.newline_whitespace_prob:
if tokenizer:
input_ids[i + 1] = tokenizer.convert_tokens_to_ids(" ")
# inserting " " leaks \n information
# the token is never there naturally, so it is a 1:1 proxy for \n
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
else:
input_ids[i + 1] = ord(" ")
else:
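Toy illustration of the new deletion path in corrupt(): because a standalone " " token never occurs naturally, inserting it would be a 1:1 giveaway for a removed "\n", so the position is dropped instead, with input_ids, labels and block_ids deleted in lockstep to stay aligned. All ids and the index below are made up.

input_ids = [0, 17, 6, 42, 2]   # made-up ids; position i + 1 is the one being dropped
labels    = [0,  1, 0,  0, 0]
block_ids = [0,  0, 0,  0, 0]

i = 1
del input_ids[i + 1]; del labels[i + 1]; del block_ids[i + 1]
print(input_ids)  # [0, 17, 42, 2]
print(labels)     # [0, 1, 0, 0]
print(block_ids)  # [0, 0, 0, 0]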
