
First version of Zero-Shot Sequence Labeler #2260

Merged: 36 commits merged into master from tars_tagger on May 1, 2021
Conversation

@alanakbik (Collaborator) commented on May 1, 2021

First version of the TARS few-shot sequence tagger. Train it like this:

from flair.datasets import WNUT_17
from flair.models import TARSTagger
from flair.trainers import ModelTrainer

# init corpus and map each label name to a natural-language description
corpus = WNUT_17(label_name_map={
    "location": "location name",
    "corporation": "corporation name",
    "person": "person name",
    "creative-work": "name of song, movie, book or other creative work",
    "product": "name of product or consumer good",
    "group": "name of music band, sports team or non-corporate organization",
})

# make the label dictionary for the 'ner' tag type
dictionary = corpus.make_label_dictionary('ner')
print(dictionary)

# init the TARS sequence tagger
tars_tagger = TARSTagger(
    'ner_wnut',
    dictionary,
    tag_type='ner',
    embeddings='bert-base-uncased',
    num_negative_labels_to_sample=1,
    prefix=True,
)

# train the model
trainer = ModelTrainer(tars_tagger, corpus)

trainer.train('resources/taggers/few-shot-sequence-tagger',
              learning_rate=0.02,
              mini_batch_size=16,
              mini_batch_chunk_size=1,
              max_epochs=20,
              monitor_test=True,
              embeddings_storage_mode="none",
              )
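
For completeness (prediction is not shown in the PR description), a minimal usage sketch assuming Flair's standard Sentence/predict API; the example sentence is made up:

from flair.data import Sentence

# annotate a made-up sentence in place with the trained tagger
sentence = Sentence("ACME Corp released a new smartphone in New York.")
tars_tagger.predict(sentence)
print(sentence.to_tagged_string())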

This PR also makes a number of smaller changes:

  • Changes the best-model logic so that the best-model.pt file is saved again. The epoch that produced the best model is no longer encoded in the checkpoint filename but is instead written explicitly to the logs (a sketch follows this list).
  • Changes loss averaging for consistency between training and testing, particularly for TARS-like approaches: the sequence tagger no longer returns a loss averaged over all words in the mini-batch, but a summed loss together with the number of words over which it was summed (see the second sketch after this list).
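
The best-model change amounts to checkpointing under a fixed filename and logging the best epoch. A hedged sketch with hypothetical names (checkpoint_if_best is not the PR's actual code):

import logging
from pathlib import Path

log = logging.getLogger("flair")

def checkpoint_if_best(model, dev_score, best_score, epoch, base_path: Path):
    # save under the fixed name best-model.pt; record the epoch in the log, not the filename
    if dev_score > best_score:
        model.save(base_path / "best-model.pt")
        log.info(f"saving best model: epoch {epoch} (dev score {dev_score:.4f})")
        return dev_score
    return best_score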
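
And a sketch of the summed-loss convention (average_loss is a hypothetical helper, not the PR's actual code): each forward pass returns a (summed_loss, word_count) pair, and the average is taken once over the whole mini-batch, so splitting a batch into chunks via mini_batch_chunk_size yields the same loss:

def average_loss(chunk_results):
    # chunk_results: list of (summed_loss, word_count) pairs, one per chunk
    total_loss, total_words = 0.0, 0
    for summed_loss, word_count in chunk_results:
        total_loss += summed_loss
        total_words += word_count
    return total_loss / max(total_words, 1)

# one chunk of 8 words averages the same as the batch split into two chunks
print(average_loss([(16.0, 8)]))            # -> 2.0
print(average_loss([(12.0, 6), (4.0, 2)]))  # -> 2.0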

@alanakbik alanakbik merged commit d5dd0a2 into master May 1, 2021
@alanakbik alanakbik deleted the tars_tagger branch May 1, 2021 09:53