@tomaarsen tomaarsen commented Nov 8, 2023

Hello!

Pull Request overview

  • Implement SetFit ABSA from Intel Labs into SetFit.
  • Primary new classes:
    • AbsaModel:
      • predict
      • from_pretrained
      • save_pretrained
      • push_to_hub
      • to
      • device
    • AbsaTrainer:
      • train
      • evaluate
      • add_callback
      • pop_callback
      • remove_callback
      • push_to_hub
  • Add device property to SetFitModel.
  • Modernize SetFitModel.from_pretrained with token=... instead of use_auth_token=...
  • Throw a ValueError if the args passed to Trainer is of the wrong type, e.g. if it's a transformers TrainingArguments.
  • Allow partial column_mapping, move column mapping behaviour into a Mixin.
  • Add test suite for AbsaModel: ~95% test coverage on new behaviour, only push_to_hub is untested.
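The partial column_mapping behaviour from the overview can be sketched roughly as follows. This is a minimal standalone illustration of the idea, not SetFit's actual mixin; the class and method names here are hypothetical:

```python
class ColumnMappingMixin:
    """Toy stand-in for a column-mapping mixin that accepts *partial* mappings."""

    REQUIRED_COLUMNS = {"text", "label"}

    def validate_column_mapping(self, dataset_columns, column_mapping=None):
        column_mapping = dict(column_mapping or {})
        # Column names as they would look after applying the (partial) mapping.
        mapped = {column_mapping.get(col, col) for col in dataset_columns}
        missing = self.REQUIRED_COLUMNS - mapped
        if missing:
            raise ValueError(
                f"Dataset is missing required columns {sorted(missing)}; "
                "pass a `column_mapping` that renames existing columns to them."
            )
        # Only keep mapping entries that actually apply to this dataset.
        return {src: dst for src, dst in column_mapping.items() if src in dataset_columns}


# Partial mapping: only "sentence" needs renaming, "label" is already correct.
mapping = ColumnMappingMixin().validate_column_mapping(
    {"sentence", "label"}, {"sentence": "text"}
)
```

The point of accepting a partial mapping is that columns which already carry the expected names pass through untouched, so users only rename what differs.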

Usage

Training (Basic)

from setfit import AbsaModel, AbsaTrainer
from datasets import load_dataset

# You can initialize an AbsaModel using one or two SentenceTransformer models, or two ABSA models
# model = AbsaModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AbsaModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2")

# The training/eval dataset must have `text`, `span`, `label`, and `ordinal` columns
raw_dataset = load_dataset("data", data_files="example_training_file.csv")
train_dataset = raw_dataset["train"].rename_columns({"sentence": "text", "aspect": "span", "polarity": "label"})

# The minimal Trainer instantiation
trainer = AbsaTrainer(model, train_dataset=train_dataset)
trainer.train()

Training (Advanced)

from setfit import AbsaModel, AbsaTrainer, TrainingArguments
from transformers import EarlyStoppingCallback
from datasets import load_dataset

# You can initialize an AbsaModel using one or two SentenceTransformer models, or two ABSA models
# model = AbsaModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AbsaModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2")

# The training/eval dataset must have `text`, `span`, `label`, and `ordinal` columns
raw_dataset = load_dataset("data", data_files="example_training_file.csv")["train"]
raw_dataset = raw_dataset.rename_columns({"sentence": "text", "aspect": "span", "polarity": "label"})
raw_dataset = raw_dataset.train_test_split(test_size=10)
train_dataset, eval_dataset = raw_dataset["train"], raw_dataset["test"]

# Training arguments for aspect and polarity training
aspect_args = TrainingArguments(
    output_dir="aspect",
    num_epochs=2,
    body_learning_rate=5e-5,
    head_learning_rate=1e-2,
    use_amp=True,
    warmup_proportion=0.2,
    evaluation_strategy="steps",
    eval_steps=20,
    save_steps=20,
    load_best_model_at_end=True,
)
polarity_args = TrainingArguments(
    output_dir="polarity",
    num_epochs=3,
    max_steps=1000,
    body_learning_rate=2e-5,
    head_learning_rate=3e-2,
    evaluation_strategy="steps",
    eval_steps=20,
    save_steps=20,
    load_best_model_at_end=True,
)

trainer = AbsaTrainer(
    model,
    args=aspect_args,
    polarity_args=polarity_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()

metrics = trainer.evaluate()
print(metrics)

Inference

# Predicting is as easy as with SetFit
predictions = model.predict([
    "Best pizza outside of Italy and really tasty.",
    "The variations are great and the prices are absolutely fair.",
    "Unfortunately, you have to expect some waiting time and get a note with a waiting number if it should be very full."
])
print(predictions)
"""
[
    [{'span': 'pizza', 'polarity': 'positive'}],
    [{'span': 'variations', 'polarity': 'positive'}, {'span': 'prices', 'polarity': 'positive'}],
    [{'span': 'waiting time', 'polarity': 'negative'}, {'span': 'note', 'polarity': 'positive'}, {'span': 'number', 'polarity': 'negative'}]
]
"""

Note: The model on display here was trained with a whopping 43 aspects. Not 43 aspects per class, mind you: just 43 aspects across only 24 sentences (!).

Saving/Pushing to the Hub

# You can push to the Hub/save models using one or two repo_ids:
trainer.push_to_hub("tomaarsen/setfit-absa-restaurant-review", private=True)
# trainer.push_to_hub("tomaarsen/setfit-absa-restaurant-review-aspect", "tomaarsen/setfit-absa-restaurant-review-polarity", private=True)
# Or directly on the model:
# model.push_to_hub("tomaarsen/setfit-absa-restaurant-review-aspect", "tomaarsen/setfit-absa-restaurant-review-polarity", private=True)
# model.save_pretrained("absa-model")
# model.save_pretrained("absa-model-aspect", "absa-model-polarity", private=True)

TODO

  • Better model cards of saved models.

cc: @rlaperdo


- Tom Aarsen
