From da50d9b6ebffd67f2a86d5941365fd478efca541 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 2 Feb 2021 21:40:15 +0100 Subject: [PATCH 1/5] simplify examples --- flash/setup_tools.py | 4 +- .../finetuning/image_classification.py | 38 +++++----- flash_examples/finetuning/summarization.py | 41 ++++++----- .../finetuning/tabular_classification.py | 44 ++++++------ .../finetuning/text_classification.py | 44 ++++++------ flash_examples/finetuning/translation.py | 41 ++++++----- flash_examples/predict/classify_image.py | 32 ++++----- flash_examples/predict/classify_tabular.py | 18 ++--- flash_examples/predict/classify_text.py | 42 ++++++----- flash_examples/predict/image_embedder.py | 32 ++++----- flash_examples/predict/summarize.py | 71 +++++++++---------- flash_examples/predict/translate.py | 36 +++++----- 12 files changed, 212 insertions(+), 231 deletions(-) diff --git a/flash/setup_tools.py b/flash/setup_tools.py index c447d1f043f..1812713d9fa 100644 --- a/flash/setup_tools.py +++ b/flash/setup_tools.py @@ -69,7 +69,9 @@ def _load_readme_description(path_dir: str, homepage: str = __homepage__, ver: s github_source_url = os.path.join(homepage, "raw", ver) # replace relative repository path to absolute link to the release # do not replace all "docs" as in the readme we reger some other sources with particular path to docs - text = text.replace("docs/source/_static/images/", f"{os.path.join(github_source_url, 'docs/source/_static/images/')}") + text = text.replace( + "docs/source/_static/images/", f"{os.path.join(github_source_url, 'docs/source/_static/images/')}" + ) # readthedocs badge text = text.replace('badge/?version=stable', f'badge/?version={ver}') diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index b5202c16611..3bb7b84c892 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -3,29 +3,27 @@ from flash.core.finetuning import FreezeUnfreeze from flash.vision import ImageClassificationData, ImageClassifier -if __name__ == "__main__": +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') +# 2. Load the data +datamodule = ImageClassificationData.from_folders( + train_folder="data/hymenoptera_data/train/", + valid_folder="data/hymenoptera_data/val/", + test_folder="data/hymenoptera_data/test/", +) - # 2. Load the data - datamodule = ImageClassificationData.from_folders( - train_folder="data/hymenoptera_data/train/", - valid_folder="data/hymenoptera_data/val/", - test_folder="data/hymenoptera_data/test/", - ) +# 3. Build the model +model = ImageClassifier(num_classes=datamodule.num_classes) - # 3. Build the model - model = ImageClassifier(num_classes=datamodule.num_classes) +# 4. Create the trainer. Run twice on data +trainer = flash.Trainer(max_epochs=2) - # 4. Create the trainer. Run twice on data - trainer = flash.Trainer(max_epochs=2) +# 5. Train the model +trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) - # 5. Train the model - trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) +# 6. Test the model +trainer.test() - # 6. Test the model - trainer.test() - - # 7. Save it! - trainer.save_checkpoint("image_classification_model.pt") +# 7. Save it! +trainer.save_checkpoint("image_classification_model.pt") diff --git a/flash_examples/finetuning/summarization.py b/flash_examples/finetuning/summarization.py index 806ce8bfcfc..5565bdcc70f 100644 --- a/flash_examples/finetuning/summarization.py +++ b/flash_examples/finetuning/summarization.py @@ -2,30 +2,29 @@ from flash import download_data from flash.text import SummarizationData, SummarizationTask -if __name__ == "__main__": - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/') +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/') - # 2. Load the data - datamodule = SummarizationData.from_files( - train_file="data/xsum/train.csv", - valid_file="data/xsum/valid.csv", - test_file="data/xsum/test.csv", - input="input", - target="target" - ) +# 2. Load the data +datamodule = SummarizationData.from_files( + train_file="data/xsum/train.csv", + valid_file="data/xsum/valid.csv", + test_file="data/xsum/test.csv", + input="input", + target="target" +) - # 3. Build the model - model = SummarizationTask() +# 3. Build the model +model = SummarizationTask() - # 4. Create the trainer. Run once on data - trainer = flash.Trainer(max_epochs=1) +# 4. Create the trainer. Run once on data +trainer = flash.Trainer(max_epochs=1) - # 5. Fine-tune the model - trainer.finetune(model, datamodule=datamodule) +# 5. Fine-tune the model +trainer.finetune(model, datamodule=datamodule) - # 6. Test model - trainer.test() +# 6. Test model +trainer.test() - # 7. Save it! - trainer.save_checkpoint("summarization_model_xsum.pt") +# 7. Save it! +trainer.save_checkpoint("summarization_model_xsum.pt") diff --git a/flash_examples/finetuning/tabular_classification.py b/flash_examples/finetuning/tabular_classification.py index 69e6496e399..836ebd6b7a7 100644 --- a/flash_examples/finetuning/tabular_classification.py +++ b/flash_examples/finetuning/tabular_classification.py @@ -4,32 +4,30 @@ from flash.core.data import download_data from flash.tabular import TabularClassifier, TabularData -if __name__ == "__main__": +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') +# 2. Load the data +datamodule = TabularData.from_csv( + "./data/titanic/titanic.csv", + test_csv="./data/titanic/test.csv", + categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], + numerical_input=["Fare"], + target="Survived", + val_size=0.25, +) - # 2. Load the data - datamodule = TabularData.from_csv( - "./data/titanic/titanic.csv", - test_csv="./data/titanic/test.csv", - categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], - numerical_input=["Fare"], - target="Survived", - val_size=0.25, - ) +# 3. Build the model +model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()]) - # 3. Build the model - model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()]) +# 4. Create the trainer. Run 10 times on data +trainer = flash.Trainer(max_epochs=10) - # 4. Create the trainer. Run 10 times on data - trainer = flash.Trainer(max_epochs=10) +# 5. Train the model +trainer.fit(model, datamodule=datamodule) - # 5. Train the model - trainer.fit(model, datamodule=datamodule) +# 6. Test model +trainer.test() - # 6. Test model - trainer.test() - - # 7. Save it! - trainer.save_checkpoint("tabular_classification_model.pt") +# 7. Save it! +trainer.save_checkpoint("tabular_classification_model.pt") diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index dd07f46bf90..c37f9091581 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -2,32 +2,30 @@ from flash.core.data import download_data from flash.text import TextClassificationData, TextClassifier -if __name__ == "__main__": +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/') - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/') +# 2. Load the data +datamodule = TextClassificationData.from_files( + train_file="data/imdb/train.csv", + valid_file="data/imdb/valid.csv", + test_file="data/imdb/test.csv", + input="review", + target="sentiment", + batch_size=512 +) - # 2. Load the data - datamodule = TextClassificationData.from_files( - train_file="data/imdb/train.csv", - valid_file="data/imdb/valid.csv", - test_file="data/imdb/test.csv", - input="review", - target="sentiment", - batch_size=512 - ) +# 3. Build the model +model = TextClassifier(num_classes=datamodule.num_classes) - # 3. Build the model - model = TextClassifier(num_classes=datamodule.num_classes) +# 4. Create the trainer. Run once on data +trainer = flash.Trainer(max_epochs=1) - # 4. Create the trainer. Run once on data - trainer = flash.Trainer(max_epochs=1) +# 5. Fine-tune the model +trainer.finetune(model, datamodule=datamodule, strategy='freeze') - # 5. Fine-tune the model - trainer.finetune(model, datamodule=datamodule, strategy='freeze') +# 6. Test model +trainer.test() - # 6. Test model - trainer.test() - - # 7. Save it! - trainer.save_checkpoint("text_classification_model.pt") +# 7. Save it! +trainer.save_checkpoint("text_classification_model.pt") diff --git a/flash_examples/finetuning/translation.py b/flash_examples/finetuning/translation.py index e7f1debce3e..78599a0ed19 100644 --- a/flash_examples/finetuning/translation.py +++ b/flash_examples/finetuning/translation.py @@ -2,30 +2,29 @@ from flash import download_data from flash.text import TranslationData, TranslationTask -if __name__ == "__main__": - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/') +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/') - # 2. Load the data - datamodule = TranslationData.from_files( - train_file="data/wmt_en_ro/train.csv", - valid_file="data/wmt_en_ro/valid.csv", - test_file="data/wmt_en_ro/test.csv", - input="input", - target="target", - ) +# 2. Load the data +datamodule = TranslationData.from_files( + train_file="data/wmt_en_ro/train.csv", + valid_file="data/wmt_en_ro/valid.csv", + test_file="data/wmt_en_ro/test.csv", + input="input", + target="target", +) - # 3. Build the model - model = TranslationTask() +# 3. Build the model +model = TranslationTask() - # 4. Create the trainer. Run once on data - trainer = flash.Trainer(max_epochs=1, precision=16, gpus=1) +# 4. Create the trainer. Run once on data +trainer = flash.Trainer(max_epochs=1, precision=16, gpus=1) - # 5. Fine-tune the model - trainer.finetune(model, datamodule=datamodule) +# 5. Fine-tune the model +trainer.finetune(model, datamodule=datamodule) - # 6. Test model - trainer.test() +# 6. Test model +trainer.test() - # 7. Save it! - trainer.save_checkpoint("translation_model_en_ro.pt") +# 7. Save it! +trainer.save_checkpoint("translation_model_en_ro.pt") diff --git a/flash_examples/predict/classify_image.py b/flash_examples/predict/classify_image.py index abed822a873..f3bff574c1f 100644 --- a/flash_examples/predict/classify_image.py +++ b/flash_examples/predict/classify_image.py @@ -2,23 +2,21 @@ from flash.core.data import download_data from flash.vision import ImageClassificationData, ImageClassifier -if __name__ == "__main__": +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') +# 2. Load the model from a checkpoint +model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") - # 2. Load the model from a checkpoint - model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") +# 3a. Predict what's on a few images! ants or bees? +predictions = model.predict([ + "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", + "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", + "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", +]) +print(predictions) - # 3a. Predict what's on a few images! ants or bees? - predictions = model.predict([ - "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", - "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", - "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", - ]) - print(predictions) - - # 3b. Or generate predictions with a whole folder! - datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/") - predictions = Trainer().predict(model, datamodule=datamodule) - print(predictions) +# 3b. Or generate predictions with a whole folder! +datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/") +predictions = Trainer().predict(model, datamodule=datamodule) +print(predictions) diff --git a/flash_examples/predict/classify_tabular.py b/flash_examples/predict/classify_tabular.py index 0b7ed795ca4..46c57249524 100644 --- a/flash_examples/predict/classify_tabular.py +++ b/flash_examples/predict/classify_tabular.py @@ -1,16 +1,12 @@ from flash.core.data import download_data from flash.tabular import TabularClassifier -if __name__ == "__main__": +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') +# 2. Load the model from a checkpoint +model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt") - # 2. Load the model from a checkpoint - model = TabularClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt" - ) - - # 3. Generate predictions from a sheet file! Who would survive? - predictions = model.predict("data/titanic/titanic.csv") - print(predictions) +# 3. Generate predictions from a sheet file! Who would survive? +predictions = model.predict("data/titanic/titanic.csv") +print(predictions) diff --git a/flash_examples/predict/classify_text.py b/flash_examples/predict/classify_text.py index e9b4585e926..89d38a1a091 100644 --- a/flash_examples/predict/classify_text.py +++ b/flash_examples/predict/classify_text.py @@ -3,28 +3,26 @@ from flash.core.data import download_data from flash.text import TextClassificationData, TextClassifier -if __name__ == "__main__": +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/') - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/') +# 2. Load the model from a checkpoint +model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt") - # 2. Load the model from a checkpoint - model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt") +# 2a. Classify a few sentences! How was the movie? +predictions = model.predict([ + "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", + "The worst movie in the history of cinema.", + "I come from Bulgaria where it 's almost impossible to have a tornado." + "Very, very afraid" + "This guy has done a great job with this movie!", +]) +print(predictions) - # 2a. Classify a few sentences! How was the movie? - predictions = model.predict([ - "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", - "The worst movie in the history of cinema.", - "I come from Bulgaria where it 's almost impossible to have a tornado." - "Very, very afraid" - "This guy has done a great job with this movie!", - ]) - print(predictions) - - # 2b. Or generate predictions from a sheet file! - datamodule = TextClassificationData.from_file( - predict_file="data/imdb/predict.csv", - input="review", - ) - predictions = Trainer().predict(model, datamodule=datamodule) - print(predictions) +# 2b. Or generate predictions from a sheet file! +datamodule = TextClassificationData.from_file( + predict_file="data/imdb/predict.csv", + input="review", +) +predictions = Trainer().predict(model, datamodule=datamodule) +print(predictions) diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py index 3463258a123..9653316462e 100644 --- a/flash_examples/predict/image_embedder.py +++ b/flash_examples/predict/image_embedder.py @@ -3,26 +3,24 @@ from flash.core.data import download_data from flash.vision import ImageEmbedder -if __name__ == "__main__": +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') +# 2. Create an ImageEmbedder with swav trained on imagenet. +# Check out SWAV: https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav +embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128) - # 2. Create an ImageEmbedder with swav trained on imagenet. - # Check out SWAV: https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav - embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128) +# 3. Generate an embedding from an image path. +embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg') - # 3. Generate an embedding from an image path. - embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg') +# 4. Print embeddings shape +print(embeddings.shape) - # 4. Print embeddings shape - print(embeddings.shape) +# 5. Create a tensor random image +random_image = torch.randn(1, 3, 32, 32) - # 5. Create a tensor random image - random_image = torch.randn(1, 3, 32, 32) +# 6. Generate an embedding from this random image. +embeddings = embedder.predict(random_image) - # 6. Generate an embedding from this random image. - embeddings = embedder.predict(random_image) - - # 7. Print embeddings shape - print(embeddings.shape) +# 7. Print embeddings shape +print(embeddings.shape) diff --git a/flash_examples/predict/summarize.py b/flash_examples/predict/summarize.py index 1cd5e68e4e7..fe4b62180a2 100644 --- a/flash_examples/predict/summarize.py +++ b/flash_examples/predict/summarize.py @@ -3,42 +3,41 @@ from flash.core.data import download_data from flash.text import SummarizationData, SummarizationTask -if __name__ == "__main__": - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/') +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/') - # 2. Load the model from a checkpoint - model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt") +# 2. Load the model from a checkpoint +model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt") - # 2a. Summarize an article! - predictions = model.predict([ - """ - Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local - people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. - They came to Brixton to see work which has started to revitalise the borough. - It was Charles' first visit to the area since 1996, when he was accompanied by the former - South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue - for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. - ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. - She asked me were they ripe and I said yes - they're from the Dominican Republic."" - Mr Chong is one of 170 local retailers who accept the Brixton Pound. - Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market - or in participating shops. - During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children - nearby on an estate off Coldharbour Lane. Mr West said: - ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."" - He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'"" - Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust. - The trust hopes to restore and refurbish the building, - where once Jimi Hendrix and The Clash played, as a new community and business centre." - """ - ]) - print(predictions) +# 2a. Summarize an article! +predictions = model.predict([ + """ + Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local + people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. + They came to Brixton to see work which has started to revitalise the borough. + It was Charles' first visit to the area since 1996, when he was accompanied by the former + South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue + for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. + ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. + She asked me were they ripe and I said yes - they're from the Dominican Republic."" + Mr Chong is one of 170 local retailers who accept the Brixton Pound. + Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market + or in participating shops. + During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children + nearby on an estate off Coldharbour Lane. Mr West said: + ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."" + He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'"" + Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust. + The trust hopes to restore and refurbish the building, + where once Jimi Hendrix and The Clash played, as a new community and business centre." + """ +]) +print(predictions) - # 2b. Or generate summaries from a sheet file! - datamodule = SummarizationData.from_file( - predict_file="data/xsum/predict.csv", - input="input", - ) - predictions = Trainer().predict(model, datamodule=datamodule) - print(predictions) +# 2b. Or generate summaries from a sheet file! +datamodule = SummarizationData.from_file( + predict_file="data/xsum/predict.csv", + input="input", +) +predictions = Trainer().predict(model, datamodule=datamodule) +print(predictions) diff --git a/flash_examples/predict/translate.py b/flash_examples/predict/translate.py index 4003b689d0d..a0f020052e9 100644 --- a/flash_examples/predict/translate.py +++ b/flash_examples/predict/translate.py @@ -3,25 +3,23 @@ from flash import download_data from flash.text import TranslationData, TranslationTask -if __name__ == "__main__": +# 1. Download the data +download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/') - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/') +# 2. Load the model from a checkpoint +model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") - # 2. Load the model from a checkpoint - model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") +# 2a. Translate a few sentences! +predictions = model.predict([ + "BBC News went to meet one of the project's first graduates.", + "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", +]) +print(predictions) - # 2a. Translate a few sentences! - predictions = model.predict([ - "BBC News went to meet one of the project's first graduates.", - "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", - ]) - print(predictions) - - # 2b. Or generate translations from a sheet file! - datamodule = TranslationData.from_file( - predict_file="data/wmt_en_ro/predict.csv", - input="input", - ) - predictions = Trainer().predict(model, datamodule=datamodule) - print(predictions) +# 2b. Or generate translations from a sheet file! +datamodule = TranslationData.from_file( + predict_file="data/wmt_en_ro/predict.csv", + input="input", +) +predictions = Trainer().predict(model, datamodule=datamodule) +print(predictions) From 0177af8ca466b7c33ee0efbb93fa759dbca3f4e8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 3 Feb 2021 00:56:24 +0100 Subject: [PATCH 2/5] . --- flash/core/data/datamodule.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flash/core/data/datamodule.py b/flash/core/data/datamodule.py index 7833e0c396c..8907b591187 100644 --- a/flash/core/data/datamodule.py +++ b/flash/core/data/datamodule.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from typing import Any, Optional import pytorch_lightning as pl @@ -63,11 +63,11 @@ def __init__( self.batch_size = batch_size # TODO: figure out best solution for setting num_workers - # if num_workers is None: - # num_workers = os.cpu_count() if num_workers is None: - # warnings.warn("Could not infer cpu count automatically, setting it to zero") - num_workers = 0 + num_workers = os.cpu_count() + # if num_workers is None: + # # warnings.warn("Could not infer cpu count automatically, setting it to zero") + # num_workers = 0 self.num_workers = num_workers self._data_pipeline = None From 7974acde49b24a64fa0d31f61a297d32a8dba569 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 3 Feb 2021 00:58:15 +0100 Subject: [PATCH 3/5] . --- flash/core/data/datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/core/data/datamodule.py b/flash/core/data/datamodule.py index 8907b591187..5597a05512a 100644 --- a/flash/core/data/datamodule.py +++ b/flash/core/data/datamodule.py @@ -64,7 +64,7 @@ def __init__( # TODO: figure out best solution for setting num_workers if num_workers is None: - num_workers = os.cpu_count() + num_workers = os.cpu_count() # if num_workers is None: # # warnings.warn("Could not infer cpu count automatically, setting it to zero") # num_workers = 0 From 41ac58c2994b324d86d2f79b8b4cd2b9aa5a065e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 3 Feb 2021 01:09:56 +0100 Subject: [PATCH 4/5] fix --- tests/core/test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_data.py b/tests/core/test_data.py index 65f69cec7c6..02fa763f412 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -55,7 +55,7 @@ def test_cpu_count_none(): train_ds = DummyDataset() # with patch("os.cpu_count", return_value=None), pytest.warns(UserWarning, match="Could not infer"): dm = DataModule(train_ds, num_workers=None) - assert dm.num_workers == 0 + assert dm.num_workers > 0 def test_pipeline(): From 513c951b5a484fbd4631e970ee81df8ea947a0df Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 3 Feb 2021 09:10:54 +0100 Subject: [PATCH 5/5] w --- flash/core/data/datamodule.py | 9 +++++---- tests/core/test_data.py | 7 ++++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/flash/core/data/datamodule.py b/flash/core/data/datamodule.py index 5597a05512a..d32699d2ebf 100644 --- a/flash/core/data/datamodule.py +++ b/flash/core/data/datamodule.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import platform from typing import Any, Optional import pytorch_lightning as pl @@ -64,10 +65,10 @@ def __init__( # TODO: figure out best solution for setting num_workers if num_workers is None: - num_workers = os.cpu_count() - # if num_workers is None: - # # warnings.warn("Could not infer cpu count automatically, setting it to zero") - # num_workers = 0 + if platform.system() == "Darwin": + num_workers = 0 + else: + num_workers = os.cpu_count() self.num_workers = num_workers self._data_pipeline = None diff --git a/tests/core/test_data.py b/tests/core/test_data.py index 02fa763f412..ef0740a3d01 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import platform + import torch from flash import DataModule @@ -55,7 +57,10 @@ def test_cpu_count_none(): train_ds = DummyDataset() # with patch("os.cpu_count", return_value=None), pytest.warns(UserWarning, match="Could not infer"): dm = DataModule(train_ds, num_workers=None) - assert dm.num_workers > 0 + if platform.system() == "Darwin": + assert dm.num_workers == 0 + else: + assert dm.num_workers > 0 def test_pipeline():