Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into refactor/data_source
Browse files Browse the repository at this point in the history
  • Loading branch information
ananyahjha93 authored Nov 4, 2021
2 parents 9e452ff + 642e63f commit c25526b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class TextClassifier(ClassificationTask):
def __init__(
self,
num_classes: int,
backbone: str = "prajjwal1/bert-medium",
backbone: str = "prajjwal1/bert-tiny",
loss_fn: LOSS_FN_TYPE = None,
optimizer: OPTIMIZER_TYPE = "Adam",
lr_scheduler: LR_SCHEDULER_TYPE = None,
Expand Down
4 changes: 2 additions & 2 deletions flash_examples/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
"sentiment",
train_file="data/imdb/train.csv",
val_file="data/imdb/valid.csv",
backbone="prajjwal1/bert-medium",
backbone="prajjwal1/bert-tiny",
)

# 2. Build the task
model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes)
model = TextClassifier(backbone=datamodule.backbone, num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
Expand Down
8 changes: 4 additions & 4 deletions flash_notebooks/text_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@
"outputs": [],
"source": [
"datamodule = TextClassificationData.from_csv(\n",
" \"review\",\n",
" \"sentiment\",\n",
" train_file=\"data/imdb/train.csv\",\n",
" val_file=\"data/imdb/valid.csv\",\n",
" test_file=\"data/imdb/test.csv\",\n",
" input_fields=\"review\",\n",
" target_fields=\"sentiment\",\n",
" backbone=\"prajjwal1/bert-tiny\",\n",
")"
]
Expand Down Expand Up @@ -302,9 +302,9 @@
"outputs": [],
"source": [
"datamodule = TextClassificationData.from_csv(\n",
" \"review\",\n",
" predict_file=\"data/imdb/predict.csv\",\n",
" input_fields=\"review\",\n",
" backbone=model.backbone,\n",
" backbone=\"prajjwal1/bert-tiny\",\n",
")\n",
"predictions = flash.Trainer().predict(model, datamodule=datamodule)\n",
"print(predictions)"
Expand Down

0 comments on commit c25526b

Please sign in to comment.