diff --git a/CHANGELOG.md b/CHANGELOG.md index 4461ceff74..812b64f5f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,12 +34,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added Flash Zero, a zero code command line ML platform built with flash ([#611](https://github.com/PyTorchLightning/lightning-flash/pull/611)) +- Added support for `.npy` and `.npz` files to `ImageClassificationData` and `AudioClassificationData` ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) + +- Added support for `from_csv` to the `AudioClassificationData` ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) + +- Added option to pass a `resolver` to the `from_csv` and `from_pandas` methods of `ImageClassificationData`, which is used to resolve filenames given IDs ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) - Removed bolts pretrained weights for SSL from ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) +- Changed the behaviour of the `sampler` argument of the `DataModule` to take a `Sampler` type rather than instantiated object ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) + ### Fixed - Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493)) @@ -50,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where some tasks were not compatible with PyTorch 1.7 due to use of `torch.jit.isinstance` ([#611](https://github.com/PyTorchLightning/lightning-flash/pull/611)) +- Fixed a bug where custom samplers would not be properly forwarded to the data loader ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) + ## [0.4.0] - 2021-06-22 ### Added diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 940d60ff96..99bf240646 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -548,29 +548,6 @@ def test_from_csv_multi_target(multi_target_csv): assert labels.shape == (2, 2) -# @pytest.fixture -# def bad_csv_multi_image(image_tmpdir): -# with open(image_tmpdir / "metadata.csv", "w") as csvfile: -# fieldnames = ["image", "target"] -# writer = csv.DictWriter(csvfile, fieldnames) -# writer.writeheader() -# writer.writerow({"image": "image", "target": "Ants"}) -# return str(image_tmpdir / "metadata.csv") - - -# @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -# def test_from_bad_csv_multi_image(bad_csv_multi_image): -# with pytest.raises(ValueError, match="Found multiple matches"): -# img_data = ImageClassificationData.from_csv( -# "image", -# ["target"], -# train_file=bad_csv_multi_image, -# batch_size=1, -# num_workers=0, -# ) -# _ = next(iter(img_data.train_dataloader())) - - @pytest.fixture def bad_csv_no_image(image_tmpdir): with open(image_tmpdir / "metadata.csv", "w") as csvfile: