diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 674cbcd6e3..5c1da80a2c 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -57,7 +57,7 @@ class TextClassifier(ClassificationTask): def __init__( self, num_classes: int, - backbone: str = "prajjwal1/bert-medium", + backbone: str = "prajjwal1/bert-tiny", loss_fn: LOSS_FN_TYPE = None, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index 3d62dbb0dc..bdeedbeb94 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -25,11 +25,11 @@ "sentiment", train_file="data/imdb/train.csv", val_file="data/imdb/valid.csv", - backbone="prajjwal1/bert-medium", + backbone="prajjwal1/bert-tiny", ) # 2. Build the task -model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes) +model = TextClassifier(backbone=datamodule.backbone, num_classes=datamodule.num_classes) # 3. Create the trainer and finetune the model trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) diff --git a/flash_notebooks/text_classification.ipynb b/flash_notebooks/text_classification.ipynb index 72f0e9a367..183695e8db 100644 --- a/flash_notebooks/text_classification.ipynb +++ b/flash_notebooks/text_classification.ipynb @@ -110,11 +110,11 @@ "outputs": [], "source": [ "datamodule = TextClassificationData.from_csv(\n", + " \"review\",\n", + " \"sentiment\",\n", " train_file=\"data/imdb/train.csv\",\n", " val_file=\"data/imdb/valid.csv\",\n", " test_file=\"data/imdb/test.csv\",\n", - " input_fields=\"review\",\n", - " target_fields=\"sentiment\",\n", " backbone=\"prajjwal1/bert-tiny\",\n", ")" ] @@ -302,9 +302,9 @@ "outputs": [], "source": [ "datamodule = TextClassificationData.from_csv(\n", + " \"review\",\n", " predict_file=\"data/imdb/predict.csv\",\n", - " input_fields=\"review\",\n", - " backbone=model.backbone,\n", + " backbone=\"prajjwal1/bert-tiny\",\n", ")\n", "predictions = flash.Trainer().predict(model, datamodule=datamodule)\n", "print(predictions)"