Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

switch to torchmetrics #169

Merged
merged 9 commits into from
Mar 12, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Callable, List, Optional, Tuple, Type

import torch
from pytorch_lightning.metrics import Metric
from pytorch_tabnet.tab_network import TabNet
from torch.nn import functional as F
from torchmetrics import Metric

from flash.core.classification import ClassificationTask
from flash.core.data import DataPipeline
Expand Down
2 changes: 1 addition & 1 deletion flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Callable, Mapping, Sequence, Type, Union

import torch
from pytorch_lightning.metrics.classification import Accuracy
from torchmetrics.classification import Accuracy
from transformers import BertForSequenceClassification

from flash.core.classification import ClassificationDataPipeline, ClassificationTask
Expand Down
2 changes: 1 addition & 1 deletion flash/text/seq2seq/summarization/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

import numpy as np
import torch
from pytorch_lightning.metrics import Metric
from rouge_score import rouge_scorer, scoring
from rouge_score.scoring import AggregateScore, Score
from torchmetrics import Metric

from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence

Expand Down
2 changes: 1 addition & 1 deletion flash/text/seq2seq/translation/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import List

import torch
from pytorch_lightning.metrics import Metric
from torchmetrics import Metric


def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
Expand Down
4 changes: 2 additions & 2 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Callable, Mapping, Sequence, Type, Union

import torch
from pytorch_lightning.metrics import Accuracy
from torch import nn
from torch.nn import functional as F
from torchmetrics import Accuracy

from flash.core.classification import ClassificationTask
from flash.vision.backbones import backbone_and_num_features
Expand All @@ -33,7 +33,7 @@ class ImageClassifier(ClassificationTask):
loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`.
metrics: Metrics to compute for training and evaluation,
defaults to :class:`pytorch_lightning.metrics.Accuracy`.
defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
"""

Expand Down
2 changes: 1 addition & 1 deletion flash/vision/embedding/image_embedder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import torch
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.nn import functional as F
from torchmetrics import Accuracy

from flash.core import Task
from flash.core.data import TaskDataPipeline
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/finetuning/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
from torchmetrics.classification import Accuracy, Precision, Recall

import flash
from flash.core.data import download_data
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pytorch-lightning==1.2.0rc0 # todo: we shall align with real 1.2
torch>=1.7 # TODO: regenerate weights with lewer PT version
PyYAML>=5.1
Pillow>=7.2
torchmetrics>=0.2
torchvision>=0.8 # lower to 0.7 after PT 1.6
transformers>=4.0
pytorch-tabnet==3.1
Expand Down